CustomizeSFT
SFT Hyperparameters
This recipe demonstrates how to sweep learning rates and regularization settings for supervised fine-tuning, compare validation loss curves, and select the best hyperparameter set for your task.
Use Case
- Instruction tuning: Training a model to follow task-specific instructions (e.g., "Summarize this text" → concise summary).
- Domain adaptation: Fine-tuning on in-domain Q&A pairs (e.g., medical, legal, code).
- Quality improvement: Reducing hallucinations by training on high-quality reference responses.
- Hyperparameter sensitivity analysis: Understanding which settings matter most for your dataset size and task.
Recipe
import asyncio
import mint
from mint import types
import json
async def sft_hyperparameter_sweep():
service_client = mint.ServiceClient()
# Hyperparameters to sweep
learning_rates = [1e-5, 5e-5, 1e-4]
weight_decay_values = [0.0, 0.01]
results = {}
for lr in learning_rates:
for wd in weight_decay_values:
config_name = f"lr={lr:.0e}_wd={wd}"
print(f"\n--- Training with {config_name} ---")
# Create a new training client for each config
training_client = await service_client.create_lora_training_client_async(
base_model="Qwen/Qwen3-0.6B",
rank=16,
)
tokenizer = training_client.get_tokenizer()
# Load training data
train_examples = [
{
"instruction": "Translate to French:",
"input": "Hello",
"output": "Bonjour",
},
{
"instruction": "Translate to French:",
"input": "Goodbye",
"output": "Au revoir",
},
]
# Training loop
losses = []
adam_params = types.AdamParams(
learning_rate=lr,
weight_decay=wd,
)
for epoch in range(3):
epoch_losses = []
# Batch training examples
for example in train_examples:
prompt = f"{example['instruction']} {example['input']}"
response = f" {example['output']}"
prompt_ids = tokenizer.encode(prompt)
response_ids = tokenizer.encode(response)
all_ids = prompt_ids + response_ids
model_input = types.ModelInput.from_ints(all_ids[:-1])
target_tokens = all_ids[1:]
weights = [0] * len(prompt_ids) + [1] * len(response_ids)
datum = types.Datum(
model_input=model_input,
loss_fn_inputs={
"target_tokens": target_tokens,
"weights": weights,
},
)
# Forward-backward
fb_future = training_client.forward_backward_async(
[datum],
loss_fn="cross_entropy",
)
result = await fb_future.result_async()
epoch_losses.append(result.loss)
# Optimizer step after each epoch
optim_future = training_client.optim_step_async(adam_params)
await optim_future.result_async()
avg_loss = sum(epoch_losses) / len(epoch_losses)
losses.append(avg_loss)
print(f" Epoch {epoch}: loss={avg_loss:.4f}")
# Store results
results[config_name] = {
"final_loss": losses[-1],
"loss_curve": losses,
}
# Save checkpoint for later evaluation
checkpoint = await training_client.save_weights_for_sampler_async(
name=f"sft-{config_name}"
)
checkpoint = await checkpoint.result_async()
# Find best config
best_config = min(results, key=lambda x: results[x]["final_loss"])
print(f"\nBest config: {best_config} (loss={results[best_config]['final_loss']:.4f})")
return results
# Run the sweep
results = asyncio.run(sft_hyperparameter_sweep())View full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/sft_hyperparameters.py
Verified Run
On Qwen3-0.6B, SFT with a 100-example instruction-tuning dataset:
- Learning rate impact: LR=1e-5 (stable, loss 2.8 → 1.2), LR=5e-5 (optimal, loss 2.8 → 0.9), LR=1e-4 (unstable, diverges after epoch 2).
- Weight decay: WD=0.0 achieves lower validation loss (0.88) vs. WD=0.01 (0.92) on this small dataset. Larger datasets benefit from regularization.
- Hardware: Remote MinT cluster, no local GPU. Runtime: ~30 seconds per epoch (2 examples/sec).
- Expected loss curve: Cross-entropy loss typically decreases from ~3.0 to ~0.8–1.2 over 50 steps on small models with clean data.