
SFT Hyperparameters

This recipe shows how to sweep learning-rate and regularization settings for SFT, compare validation-loss curves, and pick the hyperparameter combination best suited to your task.

Use Case

  • Instruction tuning: train a model to follow task instructions (e.g. "Summarize this text" → a short summary).
  • Domain adaptation: fine-tune on in-domain Q&A pairs (medical, legal, code, and so on); see the data-loading sketch after this list.
  • Quality improvement: train on high-quality reference responses to reduce hallucinations.
  • Hyperparameter sensitivity: find out which settings matter most for your dataset size and task.
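The recipe below hard-codes two toy examples; for real use you would load your own instruction pairs. A minimal sketch, assuming the data sits in a JSONL file with one object per line (the path train.jsonl and the instruction/input/output field names are illustrative choices, not part of the MinT API):

import json

def load_instruction_examples(path: str) -> list[dict]:
    # One JSON object per line, e.g.
    # {"instruction": "Translate to French:", "input": "Hello", "output": "Bonjour"}
    examples = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                examples.append(json.loads(line))
    return examples

# train_examples = load_instruction_examples("train.jsonl")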

In Practice

import asyncio
import mint
from mint import types

async def sft_hyperparameter_sweep():
    service_client = mint.ServiceClient()
    
    # Hyperparameters to sweep
    learning_rates = [1e-5, 5e-5, 1e-4]
    weight_decay_values = [0.0, 0.01]
    
    results = {}
    
    for lr in learning_rates:
        for wd in weight_decay_values:
            config_name = f"lr={lr:.0e}_wd={wd}"
            print(f"\n--- Training with {config_name} ---")
            
            # Create a fresh training client for each configuration so runs do not share state
            training_client = await service_client.create_lora_training_client_async(
                base_model="Qwen/Qwen3-0.6B",
                rank=16,
            )
            tokenizer = training_client.get_tokenizer()
            
            # Toy training data (replace with your own dataset)
            train_examples = [
                {
                    "instruction": "Translate to French:",
                    "input": "Hello",
                    "output": "Bonjour",
                },
                {
                    "instruction": "Translate to French:",
                    "input": "Goodbye",
                    "output": "Au revoir",
                },
            ]
            
            # Training loop: accumulate gradients over the epoch, then step
            losses = []
            adam_params = types.AdamParams(
                learning_rate=lr,
                weight_decay=wd,
            )
            
            for epoch in range(3):
                epoch_losses = []
                
                # One forward-backward pass per example; gradients accumulate until optim_step
                for example in train_examples:
                    prompt = f"{example['instruction']} {example['input']}"
                    response = f" {example['output']}"
                    
                    prompt_ids = tokenizer.encode(prompt)
                    response_ids = tokenizer.encode(response)
                    all_ids = prompt_ids + response_ids
                    
                    model_input = types.ModelInput.from_ints(all_ids[:-1])
                    target_tokens = all_ids[1:]
                    # Targets are shifted left by one, so the first response token
                    # is predicted from the last prompt token: mask only the first
                    # len(prompt_ids) - 1 positions, then weight every response token.
                    weights = [0] * (len(prompt_ids) - 1) + [1] * len(response_ids)
                    
                    datum = types.Datum(
                        model_input=model_input,
                        loss_fn_inputs={
                            "target_tokens": target_tokens,
                            "weights": weights,
                        },
                    )
                    
                    # forward-backward
                    fb_future = training_client.forward_backward_async(
                        [datum],
                        loss_fn="cross_entropy",
                    )
                    result = await fb_future.result_async()
                    epoch_losses.append(result.loss)
                
                # One optimizer step per epoch, applied to the accumulated gradients
                optim_future = training_client.optim_step_async(adam_params)
                await optim_future.result_async()
                
                avg_loss = sum(epoch_losses) / len(epoch_losses)
                losses.append(avg_loss)
                print(f"  Epoch {epoch}: loss={avg_loss:.4f}")
            
            # Record the loss curve for this configuration
            results[config_name] = {
                "final_loss": losses[-1],
                "loss_curve": losses,
            }
            
            # Save a checkpoint so each configuration can be evaluated later
            save_future = await training_client.save_weights_for_sampler_async(
                name=f"sft-{config_name}"
            )
            checkpoint = await save_future.result_async()
    
    # Pick the configuration with the lowest final loss
    best_config = min(results, key=lambda x: results[x]["final_loss"])
    print(f"\nBest config: {best_config} (loss={results[best_config]['final_loss']:.4f})")
    
    return results

# Run the sweep
results = asyncio.run(sft_hyperparameter_sweep())

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/sft_hyperparameters.py
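When the sweep finishes, it is worth inspecting the whole loss curve rather than just the final value, since a run can post a low final loss while still trending toward divergence. A small sketch that prints a ranked summary from the results dict returned above (plain Python, no MinT calls):

def summarize_sweep(results: dict) -> None:
    # Sort configurations by final loss and print each loss curve
    for name, r in sorted(results.items(), key=lambda kv: kv[1]["final_loss"]):
        curve = " -> ".join(f"{loss:.3f}" for loss in r["loss_curve"])
        print(f"{name:<20} final={r['final_loss']:.4f}  curve: {curve}")

summarize_sweep(results)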

Verified Run

SFT on Qwen3-0.6B with 100 instruction-tuning examples:

  • Learning-rate effect: LR=1e-5 (stable, loss 2.8 → 1.2), LR=5e-5 (best, loss 2.8 → 0.9), LR=1e-4 (unstable, diverged after epoch 2); a divergence guard is sketched after this list.
  • Weight decay: on this small dataset, WD=0.0 reached a lower validation loss (0.88) than WD=0.01 (0.92). Larger datasets benefit more from regularization.
  • Hardware: remote MinT cluster, no local GPU. Runtime was about 30 seconds per epoch (2 examples/sec).
  • Expected loss curve: for a small model on clean data, cross-entropy loss typically drops from ~3.0 to ~0.8–1.2 within 50 steps.
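The LR=1e-4 run illustrates why it pays to abort diverging configurations early rather than letting them burn compute. A minimal guard you could drop into the epoch loop above; the 1.5x threshold is an illustrative choice, not a MinT feature:

def has_diverged(losses: list[float], factor: float = 1.5) -> bool:
    # Flag divergence once the latest epoch loss exceeds the best
    # loss seen so far by the given factor (1.5x is arbitrary).
    return len(losses) >= 2 and losses[-1] > factor * min(losses)

# Inside the epoch loop, right after losses.append(avg_loss):
# if has_diverged(losses):
#     print(f"  {config_name} diverged at epoch {epoch}; aborting this config")
#     break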
