Mind Lab Toolkit (MinT)
CustomizeSFT

SFT 概览

监督微调(SFT,supervised fine-tuning)让 language model 学会根据给定的 prompt 输出目标 response。你提供带标签的 (prompt, response) 对,MinT 更新 model weights,让 response token 上的预测 loss 最小化 —— prompt token 的 loss weight 设为 0,不参与梯度计算。

Configuration

SFT 需要 ServiceClient、LoRA training client 和一个 optimizer 配置。最小化设置:

import mint
from mint import types

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)

tokenizer = training_client.get_tokenizer()
adam_params = types.AdamParams(learning_rate=5e-5)

环境变量:

  • MINT_API_KEY:MinT API key(必需)。从 macaron.im/mindlab 申请。
  • MINT_BASE_URL:MinT 服务端 endpoint。默认 https://mint.macaron.xin/(中国大陆:mint-cn.macaron.xin)。
  • MINT_BASE_MODEL:base model 名称。默认 Qwen/Qwen3-0.6B
  • MINT_LORA_RANK:LoRA rank。默认 16

所有训练都在远程 MinT 服务端上跑。你的 Python 脚本调 forward_backward(...),然后在 .result() 上阻塞等 batch 算完。

Prompting Guide

SFT 只在 response 部分算 loss。从 (prompt, response) 对构造 Datum 对象的步骤:

  1. 分别 encode prompt 和 response。
  2. 拼接 token ID。
  3. prompt token 的 loss weight 设为 0,response token 设为 1.0

quickstart.py 里的标准写法:

def process_sft_example(ex: dict, tokenizer) -> types.Datum:
    # ex = {"question": "What is 3 * 4?"}
    a, b = map(int, re.findall(r"\d+", ex["question"]))
    answer = str(a * b)
    prompt = f"Question: {ex['question']}\nAnswer:"
    completion = f" {answer}"

    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(completion, add_special_tokens=False)
    completion_tokens.append(tokenizer.eos_token_id)

    all_tokens = prompt_tokens + completion_tokens
    all_weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

    # teacher forcing 移位:input = all_tokens[:-1], target = all_tokens[1:]
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights = all_weights[1:]

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )

关键点:

  • Prompt 永远被 mask 掉loss_weight=0.0),梯度不流过它。
  • Response 永远参与训练loss_weight=1.0),让 model 学着预测目标 token。
  • 聊天风格的 instruction tuning 用 tokenizer 的 apply_chat_template(...) 方法(见 Rendering)。

Output Format

SFT 不直接生成文本,它更新 model weights。每次 forward_backward step 之后会拿到一个 ForwardBackwardResult,里面包含:

  • loss:加权 cross-entropy loss(标量)。
  • loss_fn_outputs:每个 datum 的元数据(logprob 等)。

要让训练好的 model 生成文本,先保存 LoRA weights 并创建一个 sampling client:

# SFT 训练完成后:
checkpoint = training_client.save_weights_and_get_sampling_client(name="my-sft-v1")

prompt_ids = tokenizer.encode("3 * 7 =")
samples = checkpoint.sample(
    prompt=types.ModelInput.from_ints(prompt_ids),
    sampling_params=types.SamplingParams(max_tokens=16, temperature=0.7),
    num_samples=4,
)

for seq in samples.sequences:
    print(tokenizer.decode(seq.tokens))

系统性的评估 —— hold-out 测试集、benchmark 指标、采样日志 —— 见 Concepts → Evaluations

All Parameters

参数类型默认值含义
base_modelstr"Qwen/Qwen3-0.6B"base model 的 Hugging Face model ID。
rankint16LoRA rank。值越大越有表达力,显存也越多。典型值 8–64。
train_mlpboolTrue训练 MLP(feed-forward)层。
train_attnboolTrue训练 attention 层。
train_unembedboolTrue训练 unembedding(输出)层。
loss_fnstr"cross_entropy"loss function。SFT 永远用 "cross_entropy",没有备选。
learning_ratefloat5e-5Adam 学习率。instruction tuning 典型值 1e-5 到 1e-4。
betastuple[float, float](0.9, 0.999)Adam 一阶 / 二阶矩的指数衰减率。
epsfloat1e-8Adam 数值稳定项。
weight_decayfloat0.0L2 正则化系数。LoRA 典型值 0.0–0.01。

用法:

adam_params = types.AdamParams(
    learning_rate=5e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
)

for step, batch in enumerate(batches):
    result = training_client.forward_backward(batch, loss_fn="cross_entropy").result()
    training_client.optim_step(adam_params).result()
    print(f"Step {step}: loss={result.loss:.4f}")

Tinker 兼容性提醒:

  • 不要调 zero_grad_async() —— MinT 服务端会自动清梯度。
  • loss_fn 参数传给 forward_backward(...),不放在 AdamParams 里。
  • save_weights_for_sampler(...)save_weights_and_get_sampling_client(...) 语义一致:都是把 LoRA weights 序列化。

本页目录