CustomizeConcepts
Loss Functions
MinT 为 forward_backward() 内置了几种 loss function,覆盖监督学习和 policy gradient 训练。每种 loss 都对 model 预测 token 的 logprob 算梯度,按 importance、advantage 或自定义 reward 信号加权。
Concept
MinT 的 loss function 分两类:
监督学习(SFT 和 DPO)
loss_fn="cross_entropy"—— 标准的 token 预测 loss。让 model 学会匹配 target token。用在 SFT 和多轮监督数据上。
Policy gradient(RL)
loss_fn="importance_sampling"—— off-policy RL。把轨迹按importance_weight = target_logprob / behavior_logprob重新加权。对分布漂移敏感,最适合搭配小学习率。loss_fn="ppo"—— on-policy RL。把概率比裁剪到[1 - clip_ratio, 1 + clip_ratio]区间,防止单步更新过大。稳,用得最广。loss_fn="cispo"—— 保守 importance sampling。IS 的软裁剪变体;裁剪的是 logprob 差,不是概率比。介于 IS 和 PPO 之间。
每个 RL loss 都接收 advantages(每条 trajectory 一个 advantage 值),可选 behavior_logprobs(行为 policy 下的 logprob)。梯度优化的是新 policy 下的期望回报。
Pattern
import mint
from mint import types
import torch
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
base_model="Qwen/Qwen3-0.6B",
rank=16,
)
tokenizer = training_client.get_tokenizer()
# 示例 1:SFT(监督学习)
prompt_text = "The capital of France is"
target_text = " Paris."
prompt_ids = tokenizer.encode(prompt_text)
target_ids = tokenizer.encode(target_text)
all_ids = prompt_ids + target_ids
model_input_sft = types.ModelInput.from_ints(all_ids[:-1])
sft_target_tokens = all_ids[1:]
sft_weights = [0] * len(prompt_ids) + [1] * len(target_ids)
datum_sft = types.Datum(
model_input=model_input_sft,
loss_fn_inputs={
"target_tokens": sft_target_tokens,
"weights": sft_weights,
},
)
result = training_client.forward_backward([datum_sft], loss_fn="cross_entropy").result()
print(f"SFT loss: {result.loss:.4f}")
# 示例 2:RL(policy gradient)
# 模拟 model 生成的 token 和它们的 logprob
model_tokens = [7741, 34651, 31410]
model_logprobs = [-0.5, -0.3, -0.8]
advantages = [0.15] # 单条 trajectory,正 reward
model_input_rl = types.ModelInput.from_ints(model_tokens)
datum_rl = types.Datum(
model_input=model_input_rl,
loss_fn_inputs={
"target_tokens": model_tokens,
"logprobs": model_logprobs,
"advantages": advantages,
},
)
result_rl = training_client.forward_backward([datum_rl], loss_fn="ppo").result()
print(f"PPO loss: {result_rl.loss:.4f}")
# 走一步 optimizer
adam_params = types.AdamParams(learning_rate=1e-4)
training_client.optim_step(adam_params).result()完整源码:https://github.com/MindLab-Research/mint-quickstart/blob/main/quickstart/custom_loss.py
API Surface
| Loss | loss_fn_inputs | 用途 |
|---|---|---|
"cross_entropy" | target_tokens、weights | SFT、多轮监督 |
"importance_sampling" | target_tokens、logprobs、advantages | off-policy RL、offline RL |
"ppo" | target_tokens、logprobs、advantages | on-policy RL、稳定的梯度更新 |
"cispo" | target_tokens、logprobs、advantages | 保守 RL、平滑裁剪 |
RL loss 的可选参数:
clip_ratio(PPO)—— 裁剪范围[1 - clip_ratio, 1 + clip_ratio]。默认0.2。范围0.0–0.5。beta(CISPO)—— 软裁剪锐度。默认1.0,越大裁剪越锐,越接近硬 PPO。
Caveats & Pitfalls
- 数据格式:
loss_fn_inputs的 key 必须严格匹配 loss function 的要求。传target_logprobs而漏传logprobs会抛 KeyError。 - Advantage 范围:RL 的 advantage 应该在每个 batch 内做零中心化。用 baseline 或 value function 估计 advantage,不要直接拿原始 reward。
- Importance sampling 漂移:
loss_fn="importance_sampling"对分布漂移敏感。如果 policy 偏离 behavior policy 太远,importance weight 会爆炸。改用loss_fn="ppo"(带裁剪)或缩小学习率可以缓解。 - SFT 的 weight mask:SFT 时 prompt token 的
weights=0,completion token 的weights=1。renderer 的build_supervised_example()会自动处理这一步。 - Off-policy vs on-policy:
importance_sampling接受任意 policy 产生的数据(真正的 off-policy)。ppo和cispo假设 trajectory 由当前 model 生成(on-policy 或近似 on-policy)。按数据来源选。