Mind Lab Toolkit (MinT)
CustomizeDPO

DPO 概览

Direct Preference Optimization(DPO)训练 model 让它更倾向 chosen response 而不是 rejected response。和 SFT 需要标注好的正确答案不同,DPO 用的是成对偏好信号 —— 适合多个有效答案并存的任务,或者人工评判员只给出相对质量打分的场景。

MinT 里的 DPO 用 forward_backward_custom 配上一个在客户端 Python 里定义的 Bradley-Terry 风格成对 loss。没有内置 loss_fn="dpo",需要直接传 loss 闭包。

Configuration

DPO 用和 SFT 一样的 ServiceClient 和 LoRA training 配置,只是调 forward_backward_custom 而不是标准的 forward_backward

import mint
from mint import types

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)

tokenizer = training_client.get_tokenizer()
adam_params = types.AdamParams(learning_rate=1e-5)  # 比 SFT 低

参考脚本是 quickstart/custom_loss.py,它演示了:

  1. 构造 PreferencePair 对象:(prompt, chosen, rejected)
  2. 通过 flatten_preference_pairs(...) 把成对数据展平成 Datum 列表。
  3. 定义 pairwise_preference_loss(...) 闭包。
  4. forward_backward_custom(data, loss_closure)
from quickstart.custom_loss import (
    PreferencePair,
    flatten_preference_pairs,
    pairwise_preference_loss,
)

pairs = [
    PreferencePair(
        prompt="Explain backups.",
        chosen="Backups reduce recovery time after failures...",
        rejected="Backups are good.",
    ),
    # ... 更多 pair
]

Prompting Guide

DPO 需要 (chosen, rejected) 对。两边都过 chat template 渲染,再构造 Datum:prompt 部分 loss weight 设 0,response 部分设 1.0:

def build_preference_datum(
    prompt_tokens: list[int], completion_text: str, tokenizer
) -> types.Datum:
    completion_tokens = tokenizer.encode(f" {completion_text}", add_special_tokens=False)
    completion_tokens.append(tokenizer.eos_token_id)

    all_tokens = prompt_tokens + completion_tokens
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )


def flatten_preference_pairs(pairs: list[PreferencePair], tokenizer) -> list[types.Datum]:
    """把 (prompt, chosen, rejected) 三元组展平成 (chosen_datum, rejected_datum) 序列。"""
    data: list[types.Datum] = []
    for pair in pairs:
        prompt_tokens = build_prompt_tokens(pair.prompt, tokenizer)
        data.append(build_preference_datum(prompt_tokens, pair.chosen, tokenizer))
        data.append(build_preference_datum(prompt_tokens, pair.rejected, tokenizer))
    return data

返回的是一个扁平 list,每对 pair 占两个连续位置:[chosen₀, rejected₀, chosen₁, rejected₁, ...]。这个排列很关键,因为 pairwise_preference_loss 假设偶数下标是 chosen,奇数下标是 rejected。

Output Format

成对偏好 loss 比较 chosen response 的序列 logprob 和 rejected response 的序列 logprob,套一个 Bradley-Terry sigmoid:

def sequence_logprob(logprobs: torch.Tensor, weights: Any) -> torch.Tensor:
    """对 per-token logprob 做加权点积。"""
    logprob_tensor = logprobs.flatten().float()
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)


def pairwise_preference_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    """Bradley-Terry loss:-log(sigmoid(chosen_score - rejected_score))。"""
    chosen_scores = []
    rejected_scores = []

    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )

    chosen_scores_tensor = torch.stack(chosen_scores)
    rejected_scores_tensor = torch.stack(rejected_scores)
    margins = chosen_scores_tensor - rejected_scores_tensor
    loss = -F.logsigmoid(margins).mean()  # Bradley-Terry:-log(sigmoid(margin))

    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics

返回的指标:

  • loss:所有 pair 上的平均 Bradley-Terry loss。
  • pair_accuracy:chosen_score > rejected_score 的 pair 比例。
  • mean_margin:平均 logprob 差。

收敛信号是 pair_accuracy → 1.0(所有 chosen > rejected)且 loss → 0

All Parameters

参数类型默认值含义
dpo_betafloat0.1Bradley-Terry 温度 β。β 越大,对错误偏好的惩罚越重。范围 0.05–0.5。
learning_ratefloat1e-5Adam 学习率。比 SFT 低,因为偏好梯度更温和。典型值 5e-6 到 5e-5。
betastuple[float, float](0.9, 0.999)Adam 一阶 / 二阶矩的指数衰减率。
epsfloat1e-8Adam 数值稳定项。
weight_decayfloat0.0L2 正则化。
base_modelstr"Qwen/Qwen3-0.6B"base model ID。
rankint16LoRA rank。
train_mlpboolTrue训练 MLP 层。
train_attnboolTrue训练 attention 层。
train_unembedboolTrue训练输出层。
max_lengthint96tokenizer 最大长度,做安全截断。
batch_sizeint8每个 batch 的 pair 数。DPO 数据效率高,小 batch(2–8)一般就够用。

用法:

for step in range(num_steps):
    result = training_client.forward_backward_custom(
        data,  # 已展平的 (chosen_datum, rejected_datum) 序列
        pairwise_preference_loss
    ).result()

    metrics = result.metrics or {}
    training_client.optim_step(
        types.AdamParams(learning_rate=1e-5)
    ).result()

    print(f"Step {step}: loss={metrics.get('loss'):.4f}, "
          f"pair_accuracy={metrics.get('pair_accuracy'):.1%}")

重要: 数据 list 必须是偶数长度,chosen 和 rejected 交错排:[chosen₀, rejected₀, chosen₁, rejected₁, ...]。顺序错了,loss 不会报错,但 margin 算出来是错的。

What's next?

本页目录