Mind Lab Toolkit (MinT)
CustomizeSFT

SFT 概览

监督微调(SFT,supervised fine-tuning)让 language model 学会根据给定的 prompt 输出目标 response。你提供带标签的 (prompt, response) 对,MinT 更新 model weights,让 response token 上的预测 loss 最小化 —— prompt token 的 loss weight 设为 0,不参与梯度计算。

Configuration

SFT 需要 ServiceClient、LoRA training client 和一个 optimizer 配置。最小化设置:

import mint
from mint import types

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)

tokenizer = training_client.get_tokenizer()
adam_params = types.AdamParams(learning_rate=5e-5)

环境变量:

  • MINT_API_KEY:MinT API key(必需)。从 macaron.im/mindlab 申请。
  • MINT_BASE_URL:MinT 服务端 endpoint。默认 https://mint.macaron.xin/(中国大陆:mint-cn.macaron.xin)。
  • MINT_BASE_MODEL:base model 名称。默认 Qwen/Qwen3-0.6B
  • MINT_LORA_RANK:LoRA rank。默认 16

所有训练都在远程 MinT 服务端上跑。你的 Python 脚本调 forward_backward(...),然后在 .result() 上阻塞等 batch 算完。

Prompting Guide

SFT 只在 response 部分算 loss。从 (prompt, response) 对构造 Datum 对象的步骤:

  1. 分别 encode prompt 和 response。
  2. 拼接 token ID。
  3. prompt token 的 loss weight 设为 0,response token 设为 1.0

quickstart.py 里的标准写法:

def process_sft_example(ex: dict, tokenizer) -> types.Datum:
    # ex = {"question": "What is 3 * 4?"}
    a, b = map(int, re.findall(r"\d+", ex["question"]))
    answer = str(a * b)
    prompt = f"Question: {ex['question']}\nAnswer:"
    completion = f" {answer}"

    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(completion, add_special_tokens=False)
    completion_tokens.append(tokenizer.eos_token_id)

    all_tokens = prompt_tokens + completion_tokens
    all_weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

    # teacher forcing 移位:input = all_tokens[:-1], target = all_tokens[1:]
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights = all_weights[1:]

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )

关键点:

  • Prompt 永远被 mask 掉loss_weight=0.0),梯度不流过它。
  • Response 永远参与训练loss_weight=1.0),让 model 学着预测目标 token。
  • 聊天风格的 instruction tuning 用 tokenizer 的 apply_chat_template(...) 方法(见 Rendering)。

使用 Renderer

你也可以用 mint.recipe 的 renderer 把 chat 格式的 message 直接转成 Datum,不需要手动 encode token。Renderer 会帮你处理 chat template、loss mask 和 teacher-forcing 移位:

import mint.recipe as recipe

renderer = recipe.renderers.get_renderer(
    recipe.get_recommended_renderer_name("Qwen/Qwen3-0.6B"),
    tokenizer,
)

messages = [
    {"role": "user", "content": "What is 3 * 4?"},
    {"role": "assistant", "content": "12"},
]

model_input, weights = renderer.build_supervised_example(messages)
datum = recipe.datum_from_model_input_weights(model_input, weights, max_length=2048)

build_supervised_example() 返回 (ModelInput, weights),prompt token 的 weight 为 0,response token 为 1。datum_from_model_input_weights() 把它们包装成带正确 loss_fn_inputs 结构的 Datum

对于多轮对话,renderer 默认只在最后一条 assistant 消息上算 loss。传 train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES 可以在所有 assistant 回合上训练。

Output Format

SFT 不直接生成文本,它更新 model weights。每次 forward_backward step 之后会拿到一个 ForwardBackwardResult,里面包含:

  • loss_fn_outputs:一个 per-datum dict 的列表,包含 logprob 等 loss function 元数据。
  • metrics:聚合的训练指标(如果可用)。

要让训练好的 model 生成文本,先保存 LoRA weights 并创建一个 sampling client:

# SFT 训练完成后:
checkpoint = training_client.save_weights_and_get_sampling_client(name="my-sft-v1")

prompt_ids = tokenizer.encode("3 * 7 =")
samples = checkpoint.sample(
    prompt=types.ModelInput.from_ints(prompt_ids),
    sampling_params=types.SamplingParams(max_tokens=16, temperature=0.7),
    num_samples=4,
)

for seq in samples.sequences:
    print(tokenizer.decode(seq.tokens))

系统性的评估 —— hold-out 测试集、benchmark 指标、采样日志 —— 见 Concepts → Evaluations

All Parameters

参数类型默认值含义
base_modelstr"Qwen/Qwen3-0.6B"base model 的 Hugging Face model ID。
rankint16LoRA rank。值越大越有表达力,显存也越多。典型值 8–64。
train_mlpboolTrue训练 MLP(feed-forward)层。
train_attnboolTrue训练 attention 层。
train_unembedboolTrue训练 unembedding(输出)层。
loss_fnstr"cross_entropy"loss function。SFT 永远用 "cross_entropy",没有备选。
learning_ratefloat5e-5Adam 学习率。instruction tuning 典型值 1e-5 到 1e-4。
betastuple[float, float](0.9, 0.999)Adam 一阶 / 二阶矩的指数衰减率。
epsfloat1e-8Adam 数值稳定项。
weight_decayfloat0.0L2 正则化系数。LoRA 典型值 0.0–0.01。

用法:

adam_params = types.AdamParams(
    learning_rate=5e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
)

for step, batch in enumerate(batches):
    result = training_client.forward_backward(batch, loss_fn="cross_entropy").result()
    training_client.optim_step(adam_params).result()
    print(f"Step {step}: metrics={result.metrics}")

Tinker 兼容性提醒:

  • 不要调 zero_grad_async() —— MinT 服务端会自动清梯度。
  • loss_fn 参数传给 forward_backward(...),不放在 AdamParams 里。
  • save_weights_for_sampler(...)save_weights_and_get_sampling_client(...) 语义一致:都是把 LoRA weights 序列化。

本页目录