DPO Overview
Direct Preference Optimization (DPO) trains a model to prefer the chosen response over the rejected response for the same prompt.
In MinT, this recipe uses the low-level `TrainingClient.forward_backward_custom()` API. There is no built-in `loss_fn="dpo"`. Instead, the Bradley-Terry loss is an ordinary Python function that runs on the client and receives the model logprobs returned by MinT.
This page follows `recipes/dpo_native.py`.
Data Shape
Training data starts as (prompt, chosen, rejected) triples:
```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PreferencePair:
    prompt: str
    chosen: str
    rejected: str


PREFERENCE_PAIRS = [
    PreferencePair(
        prompt="Explain why regular backups matter.",
        chosen="Backups protect data by creating copies that can be restored...",
        rejected="Backups are good.",
    ),
]
```

The recipe flattens each pair into two Datums:

```
[chosen₀, rejected₀, chosen₁, rejected₁, ...]
 even     odd        even     odd
```

This order is required. The loss assumes that chosen examples sit at even indices and rejected examples sit at odd indices.
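The flattening described above can be sketched in plain Python. This is an illustration, not the recipe's `flatten_preference_pairs`. The name `interleave_pairs` and its `make_datum` callback are hypothetical. The callback stands in for whatever turns a (prompt, completion) pair into a Datum, for example `build_datum` applied to the output of `build_prompt_tokens`, both introduced in the next section. The dataclass is repeated so that the block runs on its own.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class PreferencePair:
    prompt: str
    chosen: str
    rejected: str


def interleave_pairs(
    pairs: Sequence[PreferencePair],
    make_datum: Callable[[str, str], T],  # hypothetical: (prompt, completion) -> Datum
) -> List[T]:
    """Flatten pairs into [chosen0, rejected0, chosen1, rejected1, ...]."""
    data: List[T] = []
    for pair in pairs:
        data.append(make_datum(pair.prompt, pair.chosen))    # even index: chosen
        data.append(make_datum(pair.prompt, pair.rejected))  # odd index: rejected
    return data
```

If you pass a stub such as `lambda prompt, completion: (prompt, completion)`, the order is easy to inspect: chosen completions land on indices 0, 2, 4, and so on.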
Building a Datum
Positions whose target is a prompt token get a loss weight of 0. Positions whose target is a completion token, including the appended EOS, get a weight of 1.0:
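```python
# `types` is MinT's types module, imported as in recipes/dpo_native.py.
def build_datum(prompt_tokens, completion_text, tokenizer):
    # Tokenize the completion (with a leading space), then append EOS.
    completion_tokens = tokenizer.encode(f" {completion_text}", add_special_tokens=False)
    completion_tokens.append(tokenizer.eos_token_id)

    # Next-token prediction: position i reads all_tokens[i] and predicts all_tokens[i + 1].
    all_tokens = prompt_tokens + completion_tokens
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]

    # The first len(prompt_tokens) - 1 targets are prompt tokens (weight 0).
    # Every remaining target is a completion token or EOS (weight 1).
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )
```

If the tokenizer supports a chat template, the prompt is built with the model's chat template: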
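```python
def build_prompt_tokens(prompt, tokenizer):
    messages = [{"role": "user", "content": prompt}]
    # add_generation_prompt=True appends the tokens that open an assistant turn,
    # so the completion tokens follow it directly.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
    )
```

Custom Bradley-Terry Loss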
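`forward_backward_custom()` first runs the datums through the model. It then calls your Python loss function with `(data, logprobs_list)`, where `logprobs_list` holds one logprobs tensor per Datum, in the same order as `data`. The function returns a `(loss, metrics)` tuple.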
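Written out, the loss implemented below is a Bradley-Terry (logistic) loss on weighted sums of per-token logprobs. For pair $i$, let $\ell^{+}$ and $\ell^{-}$ be the per-token logprobs of the chosen and rejected Datums, and let $w^{+}$ and $w^{-}$ be their weights:

$$
s_i^{+} = \sum_t w^{+}_{i,t}\,\ell^{+}_{i,t}, \qquad
s_i^{-} = \sum_t w^{-}_{i,t}\,\ell^{-}_{i,t}, \qquad
m_i = s_i^{+} - s_i^{-}
$$

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\log\sigma(m_i)
$$

`mean_margin` is the average of $m_i$, and `pair_accuracy` is the fraction of pairs with $m_i > 0$. For comparison, the objective in the original DPO paper also subtracts a frozen reference model's log-probabilities and scales the result by $\beta$:

$$
\mathcal{L}_{\text{DPO}} = -\log\sigma\!\Big(\beta\big[(\log\pi_\theta(y^{+}\mid x) - \log\pi_{\text{ref}}(y^{+}\mid x)) - (\log\pi_\theta(y^{-}\mid x) - \log\pi_{\text{ref}}(y^{-}\mid x))\big]\Big)
$$

The loss shown on this page has no reference-model term and corresponds to $\beta = 1$.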
The core loss:
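```python
import torch
import torch.nn.functional as F


def _to_float_tensor(values, device=None):
    # Sketch of the recipe's helper; its real body is not shown on this page.
    # Assumes weights arrive as a list, a NumPy array, a torch.Tensor, or an
    # object exposing to_torch(). Adjust this to whatever MinT actually returns.
    if hasattr(values, "to_torch"):
        values = values.to_torch()
    return torch.as_tensor(values, dtype=torch.float32, device=device)


def sequence_logprob(logprobs, weights):
    # Important: keep logprobs as a Tensor so that gradients are not cut off.
    logprob_tensor = logprobs.flatten().float()
    weight_tensor = _to_float_tensor(weights, device=logprob_tensor.device)
    # Weighted sum = total logprob of the completion tokens (prompt positions weigh 0).
    return torch.dot(logprob_tensor, weight_tensor)


def pairwise_preference_loss(data, logprobs_list):
    chosen_scores = []
    rejected_scores = []
    # Even indices hold chosen examples; odd indices hold rejected ones.
    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )

    margins = torch.stack(chosen_scores) - torch.stack(rejected_scores)
    # Bradley-Terry / logistic loss: -log sigmoid(chosen - rejected), averaged over pairs.
    loss = -F.logsigmoid(margins).mean()

    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics
```

Do not convert logprobs to a Python list before computing the loss. Doing so detaches the values from autograd, and `forward_backward_custom()` can no longer backpropagate through the loss.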
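You can run a quick local check, independent of MinT, to confirm the two properties this warning is about. The sketch below fakes per-token logprobs with leaf tensors. The values and the name `bt_loss` are made up for illustration. The sketch then reimplements the same logistic loss in a few lines. The tensor path yields gradients, and the zero-weight prompt position gets zero gradient. A round trip through `.tolist()`, by contrast, produces a tensor with no autograd history.

```python
import torch
import torch.nn.functional as F


def bt_loss(chosen_lp, rejected_lp, chosen_w, rejected_w):
    """Logistic (Bradley-Terry) loss on weighted sums of per-token logprobs."""
    margin = torch.dot(chosen_lp, chosen_w) - torch.dot(rejected_lp, rejected_w)
    return -F.logsigmoid(margin)


# Fake per-token logprobs for one pair; requires_grad stands in for the model graph.
chosen_lp = torch.tensor([-0.5, -1.0, -2.0, -0.3], requires_grad=True)
rejected_lp = torch.tensor([-0.5, -1.0, -0.1, -0.2], requires_grad=True)
# The first position predicts a prompt token (weight 0); the rest are completion targets.
w = torch.tensor([0.0, 1.0, 1.0, 1.0])

loss = bt_loss(chosen_lp, rejected_lp, w, w)
loss.backward()
print(chosen_lp.grad)  # zero gradient at the prompt position

# Anti-pattern: converting to a Python list and back severs the graph.
detached = torch.tensor(chosen_lp.tolist())
print(detached.requires_grad)  # False
```

Here the chosen score is -3.3 and the rejected score is -1.3, so the margin is -2.0. Each weighted position therefore receives a gradient of -σ(2).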
Training Loop
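```python
import mint

# `types` is MinT's types module. PREFERENCE_PAIRS, DPO_STEPS, and
# flatten_preference_pairs are defined in recipes/dpo_native.py.
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)
tokenizer = training_client.get_tokenizer()

# Interleaved [chosen₀, rejected₀, chosen₁, rejected₁, ...] Datums.
data = flatten_preference_pairs(PREFERENCE_PAIRS, tokenizer)

for step in range(1, DPO_STEPS + 1):
    # MinT runs the forward pass; pairwise_preference_loss runs locally on the client.
    result = training_client.forward_backward_custom(
        data,
        pairwise_preference_loss,
    ).result()
    metrics = result.metrics or {}
    training_client.optim_step(types.AdamParams(learning_rate=1e-5)).result()
    print(
        f"Step {step}: loss={metrics['loss']:.6f}, "
        f"pair_accuracy={metrics['pair_accuracy']:.2f}, "
        f"mean_margin={metrics['mean_margin']:.6f}"
    )
```

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/dpo_native.py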
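The loop sends the whole flattened list in every call. If you split a larger dataset into mini-batches, each batch must still start on a chosen example and contain only whole pairs; otherwise the even/odd convention breaks. The following sketch shows one way to enforce this. `iter_pair_batches` is a hypothetical helper, not part of the recipe.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_pair_batches(data: Sequence[T], pairs_per_batch: int) -> Iterator[List[T]]:
    """Yield slices of a [chosen, rejected, ...] list without splitting any pair."""
    if len(data) % 2 != 0:
        raise ValueError("flattened preference data must have even length")
    if pairs_per_batch < 1:
        raise ValueError("pairs_per_batch must be at least 1")
    step = 2 * pairs_per_batch
    # Every batch starts at an even index, so chosen examples stay on even positions.
    for start in range(0, len(data), step):
        yield list(data[start:start + step])
```

Each yielded batch could then be passed to `forward_backward_custom(batch, pairwise_preference_loss)` in place of the full `data` list.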
Verified Run
Verified on MinT with Qwen/Qwen3-0.6B, 4 preference pairs, and 3 DPO steps:
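```
Step 1: loss=34.563499, pair_accuracy=0.00, mean_margin=-34.563488
Step 2: loss=34.331955, pair_accuracy=0.00, mean_margin=-34.331944
Step 3: loss=33.277603, pair_accuracy=0.00, mean_margin=-33.277576
```

Final checkpoint:

```
tinker://06770ead-184f-4638-824a-21138820dc4f_0/sampler_weights/dpo-native-final
```

This tiny dataset only exercises the API. `pair_accuracy=0.00` is a legitimate result, because the base model can start out scoring the rejected completions higher. The run verifies four things: the custom loss is finite, gradients flow back, the optimizer step completes, and the metrics are returned as expected. The loss also falls from 34.56 to 33.28 over the three steps.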
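You can sanity-check the logged numbers by hand. Per pair, the loss is -log σ(m) = log(1 + e^(-m)), which is very close to -m when the margin m is strongly negative. That is why loss ≈ -mean_margin in the log. The small pure-Python sketch below uses an illustrative function name.

```python
import math


def bt_pair_loss(margin: float) -> float:
    """-log(sigmoid(margin)), computed stably as softplus(-margin)."""
    # softplus(x) = log(1 + exp(x)); branch on sign so exp never overflows.
    x = -margin
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


for m in (0.0, -5.0, -34.563488):
    print(f"margin={m:>11.6f}  loss={bt_pair_loss(m):.6f}")
```

The reported `loss` is the mean of the per-pair losses, while `mean_margin` is the mean of the margins. Because the per-pair loss is convex in the margin, Jensen's inequality makes the mean loss at least the loss of the mean margin. That is consistent with the small gap at step 1 (34.563499 vs 34.563488).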
Parameters Used by This Recipe
| Parameter | Default | Meaning |
|---|---|---|
| MINT_BASE_MODEL | Qwen/Qwen3-0.6B | Base model to train. |
| MINT_LORA_RANK | 16 | LoRA rank. |
| MINT_DPO_STEPS | 3 | Number of custom-loss training steps. |
| MINT_DPO_LR | 1e-5 | Adam learning rate. |
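The `MINT_*` names in the table read like environment variables. The sketch below assumes the recipe reads them from the environment; the exact mechanism is in `recipes/dpo_native.py` and may differ. It is a minimal loader that applies the table's defaults. `DPOConfig` and `load_config` are hypothetical names.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DPOConfig:
    base_model: str
    lora_rank: int
    dpo_steps: int
    learning_rate: float


def load_config(env: Optional[Mapping[str, str]] = None) -> DPOConfig:
    """Read MINT_* settings, falling back to the defaults listed in the table."""
    env = os.environ if env is None else env
    return DPOConfig(
        base_model=env.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B"),
        lora_rank=int(env.get("MINT_LORA_RANK", "16")),
        dpo_steps=int(env.get("MINT_DPO_STEPS", "3")),
        learning_rate=float(env.get("MINT_DPO_LR", "1e-5")),
    )
```

Passing an explicit mapping, as in `load_config({"MINT_DPO_STEPS": "10"})`, lets you test overrides without touching the real environment.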