Mind Lab Toolkit (MinT)
Usage Guide

Loss functions

MinT supports both built-in losses and client-defined custom losses.

Built-in losses

Loss                 | Best for                           | Required inputs
cross_entropy        | SFT / next-token prediction        | target_tokens, weights
importance_sampling  | simple RL policy-gradient loops    | target_tokens, weights, logprobs, advantages
ppo                  | RL updates with clipping           | same as importance_sampling
cispo                | clipped IS variant                 | same as importance_sampling
dro                  | off-policy RL with regularization  | same as importance_sampling, plus a beta config
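
For orientation, here is a minimal sketch of selecting a built-in loss. The forward_backward(data, loss_fn=...) entry point and the Datum construction are assumptions patterned after the forward_backward_custom example later on this page, not confirmed MinT API:

# Assumed API sketch: each Datum carries loss inputs in loss_fn_inputs
# (matching datum.loss_fn_inputs in the custom-loss example below) and the
# built-in loss is picked by name. Model-input fields are omitted; see the
# quickstart for a complete Datum.
datum = types.Datum(
    loss_fn_inputs={
        "target_tokens": target_tokens,  # tokens the model should predict
        "weights": weights,              # 1.0 on completion tokens, 0.0 on prompt tokens
    },
)
future = training_client.forward_backward([datum], loss_fn="cross_entropy")
print(future.result().metrics)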

Core formulas

# cross_entropy
loss = (-target_logprobs * weights).sum()

# importance_sampling
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
loss = -(prob_ratio * advantages).sum()

# ppo
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
loss = -torch.min(prob_ratio * advantages, clipped_ratio * advantages).sum()

# cispo (clipped_ratio is computed as in ppo; the detach stops gradients
# from flowing through the ratio, so only target_logprobs gets gradient)
loss = -(clipped_ratio.detach() * target_logprobs * advantages).sum()

# dro
quadratic_term = (target_logprobs - sampling_logprobs) ** 2
loss = -(target_logprobs * advantages - 0.5 * beta * quadratic_term).sum()
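
To make the ppo clipping concrete, here is a small self-contained demo on toy tensors (the numbers are illustrative only):

import torch

target_logprobs = torch.tensor([-0.1, -2.0], requires_grad=True)
sampling_logprobs = torch.tensor([-0.9, -0.5])
advantages = torch.tensor([1.0, 1.0])
clip_low_threshold, clip_high_threshold = 0.8, 1.2

prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# prob_ratio ≈ [2.23, 0.22]; the first entry is far above the clip ceiling.
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)

is_loss = -(prob_ratio * advantages).sum()
ppo_loss = -torch.min(prob_ratio * advantages, clipped_ratio * advantages).sum()
# With positive advantages, min() swaps the oversized ratio 2.23 for the
# clip ceiling 1.2, which also zeroes its gradient through clamp.
print(f"importance_sampling: {is_loss.item():.3f}, ppo: {ppo_loss.item():.3f}")
# importance_sampling: -2.449, ppo: -1.423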

Custom losses with forward_backward_custom

When your objective function cannot be expressed with a built-in loss, use forward_backward_custom.

Common scenarios:

  • Pairwise preference / Bradley-Terry objectives
  • DPO-style sequence-level losses
  • Comparison objectives spanning multiple sequences

Key points:

  • The custom loss executes on the client, not on the server (a minimal skeleton of the expected callable follows this list).
  • Under the hood, the current implementation first runs a forward pass on the client, then splices it into the forward_backward path.
  • Compared with a single plain forward_backward call, expect roughly 1.5x the FLOPs and up to 3x the wall time.
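
The contract the callable must satisfy can be read off the full example below: it receives the batch and the matching client-side logprobs, and returns a scalar loss plus a metrics dict. A minimal skeleton (my_custom_loss and its placeholder objective are illustrative; types and training_client are as in the quickstart):

import torch

def my_custom_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    # logprobs_list[i] holds the client-computed logprobs for data[i].
    # Placeholder objective: maximize the mean sequence logprob.
    loss = -torch.stack([lp.sum() for lp in logprobs_list]).mean()
    return loss, {"loss": float(loss.detach().cpu())}

training_client.forward_backward_custom(data, my_custom_loss).result()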

Runnable pairwise preference example

A companion runnable example lives in mint-quickstart:

python quickstart/custom_loss.py

It organizes the batch as consecutive (chosen, rejected) pairs and computes a sequence-level Bradley-Terry-style loss.

from typing import Any

import torch
import torch.nn.functional as F

# `types` (providing types.Datum) and `_to_float_tensor` come from the
# quickstart script; see quickstart/custom_loss.py for the full listing.


def sequence_logprob(logprobs: torch.Tensor, weights: Any) -> torch.Tensor:
    # Weighted sum of per-token logprobs: one scalar score per sequence.
    logprob_tensor = logprobs.flatten().float()
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)


def pairwise_preference_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    if len(data) % 2 != 0:
        raise ValueError("Expected (chosen, rejected) pairs.")

    chosen_scores = []
    rejected_scores = []

    # Batch layout is (chosen, rejected, chosen, rejected, ...): even indices
    # are chosen sequences, odd indices are rejected ones.
    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )

    chosen_scores_tensor = torch.stack(chosen_scores)
    rejected_scores_tensor = torch.stack(rejected_scores)
    margins = chosen_scores_tensor - rejected_scores_tensor
    # Bradley-Terry: reward a positive chosen-minus-rejected margin.
    loss = -F.logsigmoid(margins).mean()
    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics


result = training_client.forward_backward_custom(data, pairwise_preference_loss).result()
print(result.metrics)

Data conventions for the pairwise preference loss

  • The batch order is fixed: (chosen, rejected, chosen, rejected, ...).
  • weights keep only completion tokens; prompt tokens must be masked out (a sketch follows this list).
  • For full DPO, you additionally need to compute reference-policy logprobs on the client and use them together inside the custom loss function.
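
For the weights convention, a minimal sketch of such a mask (make_completion_weights is a hypothetical helper, not MinT API):

def make_completion_weights(num_prompt_tokens: int, num_completion_tokens: int) -> list[float]:
    # 0.0 on prompt positions, 1.0 on completion positions, so that
    # sequence_logprob above only accumulates completion tokens.
    return [0.0] * num_prompt_tokens + [1.0] * num_completion_tokens

weights = make_completion_weights(num_prompt_tokens=12, num_completion_tokens=30)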
