Mind Lab Toolkit (MinT)

DPO Overview

Direct Preference Optimization (DPO) trains a model to prefer the chosen response over the rejected response for the same prompt.

In MinT, this recipe uses the low-level TrainingClient.forward_backward_custom() API; there is no built-in loss_fn="dpo". The Bradley-Terry loss is a plain Python function that runs on the client and receives the model logprobs returned by MinT.

This page tracks recipes/dpo_native.py.

Data Shape

Training data starts as (prompt, chosen, rejected) triples:

from dataclasses import dataclass


@dataclass(frozen=True)
class PreferencePair:
    prompt: str
    chosen: str
    rejected: str

pairs = [
    PreferencePair(
        prompt="Explain why regular backups matter.",
        chosen="Backups protect data by creating copies that can be restored...",
        rejected="Backups are good.",
    ),
]

The recipe flattens each pair into two Datum objects, chosen first and rejected second:

[chosen₀, rejected₀, chosen₁, rejected₁, ...]
   even      odd       even      odd

This ordering is required: the loss assumes even indices are chosen and odd indices are rejected.
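The interleaving can be sketched as follows. This is a minimal illustration rather than the recipe's own code: flatten_pairs and its build callback are hypothetical names, and build stands in for whatever turns a (prompt, completion) pair into a Datum.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class PreferencePair:
    prompt: str
    chosen: str
    rejected: str


def flatten_pairs(
    pairs: Sequence[PreferencePair],
    build: Callable[[str, str], T],
) -> List[T]:
    """Interleave chosen/rejected items so chosen lands on even indices."""
    flat: List[T] = []
    for pair in pairs:
        flat.append(build(pair.prompt, pair.chosen))    # even index
        flat.append(build(pair.prompt, pair.rejected))  # odd index
    return flat


# Stand-in builder: a real pipeline would build a Datum here (assumption).
demo = flatten_pairs(
    [PreferencePair("p0", "good0", "bad0"), PreferencePair("p1", "good1", "bad1")],
    build=lambda prompt, completion: completion,
)
```

Slicing the result with [::2] and [1::2] then recovers the chosen and rejected halves, which is the same slicing the loss function relies on.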

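Building a Datum

Prompt tokens get a loss weight of 0 and completion tokens get a weight of 1.0. Inputs and targets are offset by one token, so the weights line up with target_tokens and the prompt contributes len(prompt_tokens) - 1 zeros: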

def build_datum(prompt_tokens, completion_text, tokenizer):
    completion_tokens = tokenizer.encode(f" {completion_text}", add_special_tokens=False)
    completion_tokens.append(tokenizer.eos_token_id)

    all_tokens = prompt_tokens + completion_tokens
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )
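The one-token shift is easy to check with plain integer lists. The sketch below repeats only the slicing from build_datum; the helper name shift_and_weight and the token ids are made up for illustration, and no MinT types are involved.

```python
def shift_and_weight(prompt_tokens, completion_tokens):
    """Mirror the slicing in build_datum using plain token-id lists."""
    all_tokens = prompt_tokens + completion_tokens
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)
    return input_tokens, target_tokens, weights


# Made-up token ids: 3 prompt tokens, 2 completion tokens plus an EOS id of 99.
inputs, targets, weights = shift_and_weight([10, 11, 12], [20, 21, 99])

# Targets that carry weight 1.0 are exactly the completion tokens.
scored_targets = [t for t, w in zip(targets, weights) if w == 1.0]
```

Every completion token, including the EOS id, ends up as a weighted target, while the prompt tokens only serve as context. The arithmetic assumes a non-empty prompt.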

If the tokenizer supports a chat template, the prompt is formatted with the model's chat template:

def build_prompt_tokens(prompt, tokenizer):
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
    )
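The page does not show what happens when no chat template is available. One possible shape, assuming a Hugging Face style tokenizer that exposes a chat_template attribute, is to fall back to plain encoding. The function name build_prompt_tokens_with_fallback, the fallback itself, and the stub tokenizer are all assumptions made for this sketch.

```python
def build_prompt_tokens_with_fallback(prompt, tokenizer):
    """Use the chat template when one is configured, else encode the raw prompt."""
    if getattr(tokenizer, "chat_template", None):
        messages = [{"role": "user", "content": prompt}]
        return tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True
        )
    # Assumed fallback: plain encoding with the tokenizer's default settings.
    return tokenizer.encode(prompt)


class _StubTokenizer:
    """Tiny stand-in for a tokenizer, used only to exercise both branches."""

    def __init__(self, chat_template=None):
        self.chat_template = chat_template

    def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
        return [1, len(messages[0]["content"]), 2]

    def encode(self, text):
        return [len(text)]


with_template = build_prompt_tokens_with_fallback("hello", _StubTokenizer("{{ messages }}"))
without_template = build_prompt_tokens_with_fallback("hello", _StubTokenizer())
```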

Custom Bradley-Terry Loss

forward_backward_custom() first runs the datums through the model, then calls your Python loss function with (data, logprobs_list).

Core loss:

def sequence_logprob(logprobs, weights):
    # Important: keep logprobs as a Tensor so gradients are not cut off.
    logprob_tensor = logprobs.flatten().float()
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)


def pairwise_preference_loss(data, logprobs_list):
    chosen_scores = []
    rejected_scores = []

    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )

    margins = torch.stack(chosen_scores) - torch.stack(rejected_scores)
    loss = -F.logsigmoid(margins).mean()
    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics
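Written out, the code above computes the following for a batch of N pairs. The symbols are introduced here for illustration: the w terms are the per-token loss weights, the ℓ terms are the per-token logprobs returned for the chosen (c) and rejected (r) sequences of pair i, and σ is the logistic sigmoid.

```latex
s_i^{c} = \sum_t w_{i,t}^{c}\,\ell_{i,t}^{c}, \qquad
s_i^{r} = \sum_t w_{i,t}^{r}\,\ell_{i,t}^{r}, \qquad
m_i = s_i^{c} - s_i^{r}

\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \sigma(m_i)
```

As shown in the snippet, the margin is built only from the logprobs of the model being trained; the code includes no reference-model term and no scaling coefficient on the margin.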

Do not convert logprobs to a Python list before computing the loss. Doing so detaches the values from autograd, and forward_backward_custom() can no longer backpropagate.
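The effect is easy to reproduce with plain PyTorch. The sketch below uses made-up logprobs and weights for a single sequence and does not call MinT; it only contrasts a score computed with tensor ops against one rebuilt from a Python list.

```python
import torch

# Toy stand-ins for one sequence's logprobs and loss weights (made-up values).
logprobs = torch.tensor([-0.5, -1.5, -2.0], requires_grad=True)
weights = torch.tensor([0.0, 1.0, 1.0])

# Staying in tensor ops keeps the score attached to the autograd graph.
score = torch.dot(logprobs, weights)
score.backward()
grad = logprobs.grad.tolist()

# A round trip through a Python list yields a fresh tensor with no graph.
detached_score = torch.dot(torch.tensor(logprobs.tolist()), weights)
```

The first score carries a gradient equal to the weights, while the second has the same value but no connection back to logprobs.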

Training Loop

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)

tokenizer = training_client.get_tokenizer()
data = flatten_preference_pairs(PREFERENCE_PAIRS, tokenizer)

for step in range(1, DPO_STEPS + 1):
    result = training_client.forward_backward_custom(
        data,
        pairwise_preference_loss,
    ).result()
    metrics = result.metrics or {}
    training_client.optim_step(types.AdamParams(learning_rate=1e-5)).result()
    print(
        f"Step {step}: loss={metrics['loss']:.6f}, "
        f"pair_accuracy={metrics['pair_accuracy']:.2f}, "
        f"mean_margin={metrics['mean_margin']:.6f}"
    )

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/dpo_native.py

Verified Run

Verified on MinT with Qwen/Qwen3-0.6B, 4 preference pairs, and 3 DPO steps:

Step 1: loss=34.563499, pair_accuracy=0.00, mean_margin=-34.563488
Step 2: loss=34.331955, pair_accuracy=0.00, mean_margin=-34.331944
Step 3: loss=33.277603, pair_accuracy=0.00, mean_margin=-33.277576

Final checkpoint:

tinker://06770ead-184f-4638-824a-21138820dc4f_0/sampler_weights/dpo-native-final

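This tiny dataset is only meant to validate the API. pair_accuracy=0.00 is a legitimate value, because the base model may initially score the rejected completions higher. The key checks are that the custom loss is finite, gradients flow, the optimizer step completes, and metrics are returned as expected.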
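In the logged run, each loss value is almost exactly the negated mean_margin. That is what -log(sigmoid(m)) does when m is strongly negative: it approaches -m. The standard-library sketch below checks this numerically. Treating each logged mean margin as if it were a single pair's margin is a simplification, since the per-pair margins are not shown, and the small gap to the logged losses is consistent with lower-precision arithmetic in the real run.

```python
import math


def neg_logsigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x)), i.e. softplus(-x)."""
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))


# mean_margin values copied from the three logged steps.
logged_margins = [-34.563488, -34.331944, -33.277576]
losses_if_uniform = [neg_logsigmoid(m) for m in logged_margins]

# Reference point: a zero margin costs log(2).
loss_at_zero = neg_logsigmoid(0.0)
```

A loss that falls below log(2) would indicate that the margins have, on balance, turned positive.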

Parameters Used by This Recipe

Parameter         Default            Meaning
MINT_BASE_MODEL   Qwen/Qwen3-0.6B    Base model to train.
MINT_LORA_RANK    16                 LoRA rank.
MINT_DPO_STEPS    3                  Number of custom-loss training steps.
MINT_DPO_LR       1e-5               Adam learning rate.
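The page lists parameter names and defaults but does not say how the recipe reads them. If they are environment variables, which is an assumption here, loading them could look like the sketch below. RecipeConfig and load_config are hypothetical names.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class RecipeConfig:
    base_model: str
    lora_rank: int
    dpo_steps: int
    learning_rate: float


def load_config(env: Mapping[str, str] = os.environ) -> RecipeConfig:
    """Read the recipe parameters from a mapping, using the documented defaults."""
    return RecipeConfig(
        base_model=env.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B"),
        lora_rank=int(env.get("MINT_LORA_RANK", "16")),
        dpo_steps=int(env.get("MINT_DPO_STEPS", "3")),
        learning_rate=float(env.get("MINT_DPO_LR", "1e-5")),
    )


# An empty mapping yields the defaults; any key can be overridden.
defaults = load_config({})
overridden = load_config({"MINT_DPO_STEPS": "10", "MINT_DPO_LR": "5e-6"})
```

Passing the mapping explicitly keeps the function easy to test without touching the real process environment.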
