SFT Hyperparameters
This recipe shows how to sweep learning rate and regularization for SFT, compare the validation loss curves, and pick the hyperparameter combination best suited to your task.
Use Case
- Instruction tuning: train the model to follow task instructions (e.g., "Summarize this text" → a short summary).
- Domain adaptation: fine-tune on in-domain Q&A pairs (medical, legal, code, etc.); see the data-loading sketch after this list.
- Quality improvement: train on high-quality reference responses to reduce hallucinations.
- Hyperparameter sensitivity analysis: find out which settings matter most for your dataset size and task.
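For domain adaptation, the hand-written train_examples list in the recipe below would normally be replaced by a dataset on disk. A minimal loading sketch, assuming a JSONL file at data/train.jsonl whose rows carry instruction/input/output fields (the path and field names are illustrative, not part of the recipe):

import json

def load_examples(path: str) -> list[dict]:
    """Read one JSON object per line with 'instruction', 'input', and 'output' keys."""
    examples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                examples.append(json.loads(line))
    return examples

# train_examples = load_examples("data/train.jsonl")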
In Practice
import asyncio

import mint
from mint import types


async def sft_hyperparameter_sweep():
    service_client = mint.ServiceClient()

    # Hyperparameters to sweep
    learning_rates = [1e-5, 5e-5, 1e-4]
    weight_decay_values = [0.0, 0.01]

    results = {}

    for lr in learning_rates:
        for wd in weight_decay_values:
            config_name = f"lr={lr:.0e}_wd={wd}"
            print(f"\n--- Training with {config_name} ---")

            # Create a fresh training client for each configuration
            training_client = await service_client.create_lora_training_client_async(
                base_model="Qwen/Qwen3-0.6B",
                rank=16,
            )
            tokenizer = training_client.get_tokenizer()

            # Load the training data
            train_examples = [
                {
                    "instruction": "Translate to French:",
                    "input": "Hello",
                    "output": "Bonjour",
                },
                {
                    "instruction": "Translate to French:",
                    "input": "Goodbye",
                    "output": "Au revoir",
                },
            ]

            # Training loop
            losses = []
            adam_params = types.AdamParams(
                learning_rate=lr,
                weight_decay=wd,
            )
            for epoch in range(3):
                epoch_losses = []
                # Process one example at a time (batch size 1 for simplicity)
                for example in train_examples:
                    prompt = f"{example['instruction']} {example['input']}"
                    response = f" {example['output']}"

                    prompt_ids = tokenizer.encode(prompt)
                    response_ids = tokenizer.encode(response)
                    all_ids = prompt_ids + response_ids

                    # Next-token prediction: inputs are all_ids[:-1], targets are all_ids[1:]
                    model_input = types.ModelInput.from_ints(all_ids[:-1])
                    target_tokens = all_ids[1:]
                    # Mask the prompt so only response tokens contribute to the loss;
                    # weights align with target_tokens (hence the shift by one)
                    weights = [0] * (len(prompt_ids) - 1) + [1] * len(response_ids)

                    datum = types.Datum(
                        model_input=model_input,
                        loss_fn_inputs={
                            "target_tokens": target_tokens,
                            "weights": weights,
                        },
                    )

                    # Forward-backward pass (accumulates gradients)
                    fb_future = training_client.forward_backward_async(
                        [datum],
                        loss_fn="cross_entropy",
                    )
                    result = await fb_future.result_async()
                    epoch_losses.append(result.loss)

                # One optimizer step at the end of each epoch
                optim_future = training_client.optim_step_async(adam_params)
                await optim_future.result_async()

                avg_loss = sum(epoch_losses) / len(epoch_losses)
                losses.append(avg_loss)
                print(f"  Epoch {epoch}: loss={avg_loss:.4f}")

            # Store the results
            results[config_name] = {
                "final_loss": losses[-1],
                "loss_curve": losses,
            }

            # Save a checkpoint for later evaluation
            checkpoint_future = await training_client.save_weights_for_sampler_async(
                name=f"sft-{config_name}"
            )
            checkpoint = await checkpoint_future.result_async()

    # Pick the best configuration
    best_config = min(results, key=lambda x: results[x]["final_loss"])
    print(f"\nBest config: {best_config} (loss={results[best_config]['final_loss']:.4f})")
    return results


# Run the sweep
results = asyncio.run(sft_hyperparameter_sweep())

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/sft_hyperparameters.py
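The sweep returns a dict mapping each config name to its final loss and loss curve. A quick way to compare runs is to rank the configs and overlay their curves; a minimal sketch, assuming matplotlib is installed (it is not a dependency of the recipe itself):

import matplotlib.pyplot as plt

# Rank configurations by final training loss
for name, r in sorted(results.items(), key=lambda kv: kv[1]["final_loss"]):
    print(f"{name}: final_loss={r['final_loss']:.4f}  curve={r['loss_curve']}")

# Overlay the loss curves to spot instability or slow convergence
for name, r in results.items():
    plt.plot(r["loss_curve"], label=name)
plt.xlabel("epoch")
plt.ylabel("avg training loss")
plt.legend()
plt.savefig("sft_sweep_losses.png")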
Verified Run
SFT on Qwen3-0.6B with 100 instruction-tuning examples:
- Learning-rate effect: LR=1e-5 (stable, loss 2.8 → 1.2), LR=5e-5 (best, loss 2.8 → 0.9), LR=1e-4 (unstable, diverges after epoch 2; see the divergence-detection sketch after this list).
- Weight decay: on this small dataset, WD=0.0 reached a lower validation loss (0.88) vs. WD=0.01 (0.92). Larger datasets benefit more from regularization.
- Hardware: remote MinT cluster, no local GPU. Runtime was roughly 30 seconds per epoch (~2 examples/sec).
- Expected loss curve: for a small model on clean data, cross-entropy loss typically drops from ~3.0 to ~0.8–1.2 within 50 steps.
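The LR=1e-4 divergence can be caught automatically instead of spending the remaining epochs on a doomed configuration. A minimal sketch of such a guard, built only on the losses list the recipe already tracks; the patience and tolerance values are illustrative:

def is_diverging(losses: list[float], patience: int = 2, tolerance: float = 0.05) -> bool:
    """Return True if the loss rose by more than `tolerance` for `patience` consecutive epochs."""
    if len(losses) <= patience:
        return False
    recent = losses[-(patience + 1):]
    return all(later > earlier + tolerance for earlier, later in zip(recent, recent[1:]))

# Inside the epoch loop of the sweep:
# if is_diverging(losses):
#     print(f"  {config_name} appears to be diverging, stopping early")
#     break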