# Usage Guide

## Loss Functions

MinT supports both built-in losses and client-side custom losses.

### Built-in losses
| Loss | Suited for | Required inputs |
|---|---|---|
| `cross_entropy` | SFT / next-token prediction | `target_tokens`, `weights` |
| `importance_sampling` | Simple RL policy-gradient loop | `target_tokens`, `weights`, `logprobs`, `advantages` |
| `ppo` | RL updates with clipping | Same as `importance_sampling` |
| `cispo` | Clipped importance-sampling variant | Same as `importance_sampling` |
| `dro` | Off-policy RL with regularization | Same as `importance_sampling`, plus a `beta` config |
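
As a rough illustration of the required-inputs column, here is a minimal sketch of the per-datum `loss_fn_inputs` payload for `cross_entropy` versus `importance_sampling`. The field names follow the table above; how you attach these to a `Datum` and select the loss depends on your MinT client setup, so the dictionaries below are illustrative only.

```python
import torch

# Hypothetical per-datum inputs for `cross_entropy`: one weight per target token,
# with prompt tokens masked out by a weight of 0.
ce_inputs = {
    "target_tokens": torch.tensor([101, 2023, 2003, 102]),
    "weights": torch.tensor([0.0, 1.0, 1.0, 1.0]),
}

# Hypothetical per-datum inputs for `importance_sampling` / `ppo` / `cispo`:
# additionally need the sampling-time logprobs and per-token advantages.
is_inputs = {
    "target_tokens": torch.tensor([101, 2023, 2003, 102]),
    "weights": torch.tensor([0.0, 1.0, 1.0, 1.0]),
    "logprobs": torch.tensor([-1.2, -0.8, -0.5, -0.3]),  # from the sampling policy
    "advantages": torch.tensor([0.7, 0.7, 0.7, 0.7]),    # e.g. one advantage broadcast per token
}
```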
#### Core formulas

```python
# cross_entropy
loss = (-target_logprobs * weights).sum()

# importance_sampling
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
loss = -(prob_ratio * advantages).sum()

# ppo
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
loss = -torch.min(prob_ratio * advantages, clipped_ratio * advantages).sum()

# cispo
loss = -(clipped_ratio.detach() * target_logprobs * advantages).sum()

# dro
quadratic_term = (target_logprobs - sampling_logprobs) ** 2
loss = -(target_logprobs * advantages - 0.5 * beta * quadratic_term).sum()
```
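
A quick, self-contained sanity check of how these formulas relate, written in plain PyTorch with the same variable names as above: when the target policy equals the sampling policy, the probability ratio is 1, so `importance_sampling` reduces to a plain policy-gradient term, and with the clipping inactive `ppo` produces the same loss.

```python
import torch

target_logprobs = torch.tensor([-0.5, -1.0, -2.0])
sampling_logprobs = target_logprobs.clone()  # on-policy: samples came from the same policy
advantages = torch.tensor([1.0, -0.5, 2.0])

prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
print(prob_ratio)  # tensor([1., 1., 1.])

is_loss = -(prob_ratio * advantages).sum()

clip_low_threshold, clip_high_threshold = 0.8, 1.2
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
ppo_loss = -torch.min(prob_ratio * advantages, clipped_ratio * advantages).sum()

print(torch.allclose(is_loss, ppo_loss))  # True: clipping is a no-op when the ratio is 1
```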
### Custom losses with forward_backward_custom

When your objective cannot be expressed with a built-in loss, use `forward_backward_custom`.
Common scenarios:

- Pairwise preference / Bradley-Terry objectives (see the formulas right after this list)
- DPO-style sequence-level losses
- Comparison objectives that span multiple sequences
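
For reference, the pairwise objective implemented by the runnable example below is the standard Bradley-Terry-style loss over sequence log-probabilities; full DPO additionally measures each log-probability against a reference policy. These are the standard textbook forms, not MinT-specific notation:

$$
\mathcal{L}_{\text{pair}} = -\log \sigma\big(\log p_\theta(y_w \mid x) - \log p_\theta(y_l \mid x)\big),
\qquad
\mathcal{L}_{\text{DPO}} = -\log \sigma\!\left(\beta\left[\log \frac{p_\theta(y_w \mid x)}{p_{\text{ref}}(y_w \mid x)} - \log \frac{p_\theta(y_l \mid x)}{p_{\text{ref}}(y_l \mid x)}\right]\right)
$$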
A few key points:

- The custom loss executes on the client, not on the server.
- Under the hood, the current implementation runs an extra client-side `forward` and then stitches it into the `forward_backward` path.
- Compared to a single plain `forward_backward`, expect roughly 1.5x the FLOPs, and wall time can be up to 3x.
#### Runnable pairwise preference example

The companion runnable example lives in mint-quickstart:

```bash
python quickstart/custom_loss.py
```

It organizes the batch as consecutive `(chosen, rejected)` pairs and computes a sequence-level Bradley-Terry-style loss.
```python
from typing import Any

import torch
import torch.nn.functional as F

# `types` and the `_to_float_tensor` helper come from the surrounding quickstart script
# (quickstart/custom_loss.py); its remaining imports are omitted here.


def sequence_logprob(logprobs: torch.Tensor, weights: Any) -> torch.Tensor:
    # Weighted sum of token logprobs; the weights mask out prompt tokens.
    logprob_tensor = logprobs.flatten().float()
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)


def pairwise_preference_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    if len(data) % 2 != 0:
        raise ValueError("Expected (chosen, rejected) pairs.")

    chosen_scores = []
    rejected_scores = []
    # The batch is ordered (chosen, rejected, chosen, rejected, ...).
    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )

    chosen_scores_tensor = torch.stack(chosen_scores)
    rejected_scores_tensor = torch.stack(rejected_scores)
    margins = chosen_scores_tensor - rejected_scores_tensor
    # Bradley-Terry-style loss: push chosen sequence logprobs above rejected ones.
    loss = -F.logsigmoid(margins).mean()

    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics


# `training_client` is the MinT training client and `data` is the batch of Datum objects
# built earlier in the script.
result = training_client.forward_backward_custom(data, pairwise_preference_loss).result()
print(result.metrics)
```

#### Data conventions for the pairwise preference loss
- The batch order is fixed as `(chosen, rejected, chosen, rejected, ...)`.
- `weights` should cover completion tokens only; prompt tokens must be masked out (see the sketch after this list).
- For full DPO, you additionally need to compute reference-policy logprobs on the client and use them inside the custom loss function.
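
A minimal sketch of the masking convention for `weights`, assuming you know each sequence's prompt length; `completion_weights` is a hypothetical helper, and the exact tokenization and `Datum` construction depend on your setup.

```python
import torch

def completion_weights(prompt_len: int, total_len: int) -> torch.Tensor:
    # Hypothetical helper: weight 0 for prompt tokens, 1 for completion tokens.
    weights = torch.zeros(total_len)
    weights[prompt_len:] = 1.0
    return weights

print(completion_weights(prompt_len=3, total_len=7))  # tensor([0., 0., 0., 1., 1., 1., 1.])
```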