Mind Lab Toolkit (MinT)
CustomizeConcepts

Loss Functions

MinT 为 forward_backward() 内置了几种 loss function,覆盖监督学习和 policy gradient 训练。每种 loss 都对 model 预测 token 的 logprob 算梯度,按 importance、advantage 或自定义 reward 信号加权。

Concept

MinT 的 loss function 分两类:

监督学习(SFT 和 DPO)

  • loss_fn="cross_entropy" —— 标准的 token 预测 loss。让 model 学会匹配 target token。用在 SFT 和多轮监督数据上。

Policy gradient(RL)

  • loss_fn="importance_sampling" —— off-policy RL。把轨迹按 importance_weight = target_logprob / behavior_logprob 重新加权。对分布漂移敏感,最适合搭配小学习率。
  • loss_fn="ppo" —— on-policy RL。把概率比裁剪到 [1 - clip_ratio, 1 + clip_ratio] 区间,防止单步更新过大。稳,用得最广。
  • loss_fn="cispo" —— 保守 importance sampling。IS 的软裁剪变体;裁剪的是 logprob 差,不是概率比。介于 IS 和 PPO 之间。

每个 RL loss 都接收 advantages(每条 trajectory 一个 advantage 值),可选 behavior_logprobs(行为 policy 下的 logprob)。梯度优化的是新 policy 下的期望回报。

Pattern

import mint
from mint import types
import torch

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
)
tokenizer = training_client.get_tokenizer()

# 示例 1:SFT(监督学习)
prompt_text = "The capital of France is"
target_text = " Paris."
prompt_ids = tokenizer.encode(prompt_text)
target_ids = tokenizer.encode(target_text)
all_ids = prompt_ids + target_ids

model_input_sft = types.ModelInput.from_ints(all_ids[:-1])
sft_target_tokens = all_ids[1:]
sft_weights = [0] * len(prompt_ids) + [1] * len(target_ids)

datum_sft = types.Datum(
    model_input=model_input_sft,
    loss_fn_inputs={
        "target_tokens": sft_target_tokens,
        "weights": sft_weights,
    },
)

result = training_client.forward_backward([datum_sft], loss_fn="cross_entropy").result()
print(f"SFT loss: {result.loss:.4f}")

# 示例 2:RL(policy gradient)
# 模拟 model 生成的 token 和它们的 logprob
model_tokens = [7741, 34651, 31410]
model_logprobs = [-0.5, -0.3, -0.8]
advantages = [0.15]  # 单条 trajectory,正 reward

model_input_rl = types.ModelInput.from_ints(model_tokens)

datum_rl = types.Datum(
    model_input=model_input_rl,
    loss_fn_inputs={
        "target_tokens": model_tokens,
        "logprobs": model_logprobs,
        "advantages": advantages,
    },
)

result_rl = training_client.forward_backward([datum_rl], loss_fn="ppo").result()
print(f"PPO loss: {result_rl.loss:.4f}")

# 走一步 optimizer
adam_params = types.AdamParams(learning_rate=1e-4)
training_client.optim_step(adam_params).result()

完整源码:https://github.com/MindLab-Research/mint-quickstart/blob/main/quickstart/custom_loss.py

API Surface

Lossloss_fn_inputs用途
"cross_entropy"target_tokensweightsSFT、多轮监督
"importance_sampling"target_tokenslogprobsadvantagesoff-policy RL、offline RL
"ppo"target_tokenslogprobsadvantageson-policy RL、稳定的梯度更新
"cispo"target_tokenslogprobsadvantages保守 RL、平滑裁剪

RL loss 的可选参数:

  • clip_ratio(PPO)—— 裁剪范围 [1 - clip_ratio, 1 + clip_ratio]。默认 0.2。范围 0.0–0.5
  • beta(CISPO)—— 软裁剪锐度。默认 1.0,越大裁剪越锐,越接近硬 PPO。

Caveats & Pitfalls

  • 数据格式loss_fn_inputs 的 key 必须严格匹配 loss function 的要求。传 target_logprobs 而漏传 logprobs 会抛 KeyError。
  • Advantage 范围:RL 的 advantage 应该在每个 batch 内做零中心化。用 baseline 或 value function 估计 advantage,不要直接拿原始 reward。
  • Importance sampling 漂移loss_fn="importance_sampling" 对分布漂移敏感。如果 policy 偏离 behavior policy 太远,importance weight 会爆炸。改用 loss_fn="ppo"(带裁剪)或缩小学习率可以缓解。
  • SFT 的 weight mask:SFT 时 prompt token 的 weights=0,completion token 的 weights=1。renderer 的 build_supervised_example() 会自动处理这一步。
  • Off-policy vs on-policyimportance_sampling 接受任意 policy 产生的数据(真正的 off-policy)。ppocispo 假设 trajectory 由当前 model 生成(on-policy 或近似 on-policy)。按数据来源选。

本页目录