DPO Overview
Direct Preference Optimization (DPO) trains a model to prefer chosen responses over rejected ones. Unlike SFT, which requires labeled correct answers, DPO works with pairwise preference signals — useful for tasks where multiple valid answers exist or when human judges score relative quality.
In MinT, DPO uses forward_backward_custom with a Bradley-Terry-style pairwise loss defined in client-side Python. There is no built-in loss_fn="dpo"; instead, you pass the loss closure directly.
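Written out from the pairwise_preference_loss code under Output Format, the loss works as follows. Each response gets a score $s$ equal to the weighted sum of its per-token log-probabilities $\ell_t$, with weights $w_t$ of 0 on prompt tokens and 1 on response tokens. The Bradley-Terry model turns the score difference between the chosen ($c$) and rejected ($r$) response into a preference probability, and the loss averages its negative log over the $N$ pairs in a batch ($\sigma$ is the logistic sigmoid):

$$
s = \sum_{t} w_t\, \ell_t, \qquad
P(\text{chosen} \succ \text{rejected}) = \sigma(s_c - s_r), \qquad
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log \sigma\big(s_c^{(i)} - s_r^{(i)}\big)
$$

Minimizing $\mathcal{L}$ pushes each chosen score above its rejected partner.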
Configuration
DPO uses the same ServiceClient and LoRA training setup as SFT, but calls forward_backward_custom instead of the standard forward_backward:
```python
import mint
from mint import types

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)
tokenizer = training_client.get_tokenizer()
adam_params = types.AdamParams(learning_rate=1e-5)  # Lower than SFT
```

The canonical DPO script is quickstart/custom_loss.py. It demonstrates:
- Building PreferencePair objects: (prompt, chosen, rejected).
- Flattening pairs into Datum objects via flatten_preference_pairs(...).
- Defining a pairwise_preference_loss(...) closure.
- Calling forward_backward_custom(data, loss_closure).
```python
from quickstart.custom_loss import (
    PreferencePair,
    flatten_preference_pairs,
    pairwise_preference_loss,
)

pairs = [
    PreferencePair(
        prompt="Explain backups.",
        chosen="Backups reduce recovery time after failures...",
        rejected="Backups are good.",
    ),
    # ... more pairs
]

# Flat, interleaved list: [chosen_0, rejected_0, chosen_1, rejected_1, ...]
data = flatten_preference_pairs(pairs, tokenizer)
```

Prompting Guide
DPO requires (chosen, rejected) pairs. For each pair, render the prompt through your chat template, append each completion, and construct Datum objects with zero loss weight on the prompt tokens and full weight on the response tokens (including the appended EOS token):
```python
from mint import types

# PreferencePair and build_prompt_tokens are defined elsewhere (not shown here).


def build_preference_datum(
    prompt_tokens: list[int], completion_text: str, tokenizer
) -> types.Datum:
    """Build one Datum: prompt targets get weight 0, completion targets (plus EOS) weight 1."""
    completion_tokens = tokenizer.encode(f" {completion_text}", add_special_tokens=False)
    completion_tokens.append(tokenizer.eos_token_id)
    all_tokens = prompt_tokens + completion_tokens
    # Next-token prediction: the model reads all_tokens[:-1] and predicts all_tokens[1:].
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    # One weight per target: the first len(prompt_tokens) - 1 targets are still
    # prompt tokens (weight 0); every completion target has weight 1.
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )


def flatten_preference_pairs(pairs: list[PreferencePair], tokenizer) -> list[types.Datum]:
    """Convert (prompt, chosen, rejected) pairs into a flat, interleaved list
    [chosen_0, rejected_0, chosen_1, rejected_1, ...]."""
    data: list[types.Datum] = []
    for pair in pairs:
        # Render the prompt through the chat template once and reuse it for both completions.
        prompt_tokens = build_prompt_tokens(pair.prompt, tokenizer)
        data.append(build_preference_datum(prompt_tokens, pair.chosen, tokenizer))
        data.append(build_preference_datum(prompt_tokens, pair.rejected, tokenizer))
    return data
```

The result is a flat list in which each pair occupies two consecutive datums: [chosen₀, rejected₀, chosen₁, rejected₁, ...]. This ordering is critical because pairwise_preference_loss treats even-indexed datums as chosen and odd-indexed datums as rejected.
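The shift-by-one and the weight offset are easy to get wrong, so here is a stdlib-only sketch of the same arithmetic as build_preference_datum, using made-up token IDs. shift_and_mask is a hypothetical helper for illustration only; it skips the tokenizer and the MinT types entirely and assumes a non-empty prompt.

```python
def shift_and_mask(
    prompt_tokens: list[int], completion_tokens: list[int]
) -> tuple[list[int], list[int], list[float]]:
    """Mirror the shift and weighting of build_preference_datum with plain lists."""
    if not prompt_tokens:
        raise ValueError("this sketch assumes a non-empty prompt")
    all_tokens = prompt_tokens + completion_tokens
    input_tokens = all_tokens[:-1]   # the model reads tokens 0..n-2
    target_tokens = all_tokens[1:]   # and predicts tokens 1..n-1
    # Targets that are still prompt tokens get weight 0; completion targets get 1.
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)
    assert len(input_tokens) == len(target_tokens) == len(weights)
    return input_tokens, target_tokens, weights


# Made-up IDs: a 3-token prompt and a 2-token completion ending in EOS (id 2).
prompt = [101, 102, 103]
completion = [201, 2]
inputs, targets, weights = shift_and_mask(prompt, completion)
for inp, tgt, w in zip(inputs, targets, weights):
    print(f"input={inp:>3} -> target={tgt:>3}  weight={w}")
```

Only the last two targets (completion token 201 and the EOS token) carry weight 1.0, so the prompt contributes nothing to the sequence score.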
Output Format
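The pairwise preference loss compares the sequence log-probability of the chosen response with that of the rejected response. Each sequence score is the weighted sum of per-token log-probabilities, so only response tokens count, and the two scores are combined with a Bradley-Terry sigmoid: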
```python
from typing import Any

import torch
import torch.nn.functional as F
from mint import types


def sequence_logprob(logprobs: torch.Tensor, weights: Any) -> torch.Tensor:
    """Weighted dot product of per-token logprobs."""
    logprob_tensor = logprobs.flatten().float()
    # _to_float_tensor is a helper not shown in this excerpt; it converts the
    # weights stored in loss_fn_inputs to a 1-D float tensor.
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)


def pairwise_preference_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    """Bradley-Terry loss: -log(sigmoid(chosen_score - rejected_score))."""
    chosen_scores = []
    rejected_scores = []
    # Even indices hold chosen datums; odd indices hold their rejected partners.
    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )
    chosen_scores_tensor = torch.stack(chosen_scores)
    rejected_scores_tensor = torch.stack(rejected_scores)
    margins = chosen_scores_tensor - rejected_scores_tensor
    loss = -F.logsigmoid(margins).mean()  # Bradley-Terry: -log(sigmoid(margin))
    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics
```

The metrics reported are:
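- loss: the Bradley-Terry loss averaged over pairs.
- pair_accuracy: the fraction of pairs where chosen_score > rejected_score.
- mean_margin: the average of chosen_score - rejected_score, i.e. the mean difference in weighted sequence log-probability.
On the training pairs, progress shows up as pair_accuracy rising toward 1.0 (every chosen score above its rejected score) and loss falling toward 0. Because -log(sigmoid(margin)) reaches 0 only as the margin grows without bound, pair_accuracy can hit 1.0 while loss is still well above 0.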
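To make the metrics concrete, here is a stdlib-only sketch that computes the same three numbers from per-pair scores. bt_metrics is a hypothetical helper; it assumes the scores are already weighted sequence log-probabilities (the output of sequence_logprob) and uses math instead of torch, with a numerically stable form of -log(sigmoid(m)).

```python
import math


def softplus_neg(m: float) -> float:
    """Stable -log(sigmoid(m)) == log(1 + exp(-m)) for any real m."""
    if m >= 0:
        return math.log1p(math.exp(-m))
    return -m + math.log1p(math.exp(m))


def bt_metrics(chosen_scores: list[float], rejected_scores: list[float]) -> dict[str, float]:
    """Pure-Python mirror of the metrics returned by pairwise_preference_loss."""
    if len(chosen_scores) != len(rejected_scores) or not chosen_scores:
        raise ValueError("need the same, non-zero number of chosen and rejected scores")
    margins = [c - r for c, r in zip(chosen_scores, rejected_scores)]
    n = len(margins)
    return {
        "loss": sum(softplus_neg(m) for m in margins) / n,
        "pair_accuracy": sum(m > 0 for m in margins) / n,  # ties count as wrong
        "mean_margin": sum(margins) / n,
    }


# Made-up sequence log-probabilities for three pairs (margins 3, -2, 0).
print(bt_metrics([-12.0, -20.0, -8.0], [-15.0, -18.0, -8.0]))
```

For example, bt_metrics([10.0], [0.0]) reports a loss of roughly 4.5e-5: small, but not zero.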
All Parameters
| Parameter | Type | Default | Meaning |
|---|---|---|---|
| dpo_beta | float | 0.1 | Bradley-Terry temperature β. Higher β penalizes wrong preferences harder. Range: 0.05–0.5. |
| learning_rate | float | 1e-5 | Adam learning rate. Lower than SFT because preference gradients are gentler. Typical: 5e-6 to 5e-5. |
| betas | tuple[float, float] | (0.9, 0.999) | Adam exponential decay rates. |
| eps | float | 1e-8 | Adam numerical stability term. |
| weight_decay | float | 0.0 | L2 regularization. |
| base_model | str | "Qwen/Qwen3-0.6B" | Base model ID. |
| rank | int | 16 | LoRA rank. |
| train_mlp | bool | True | Train MLP layers. |
| train_attn | bool | True | Train attention layers. |
| train_unembed | bool | True | Train output layer. |
| max_length | int | 96 | Tokenizer max length for safety. |
| batch_size | int | 8 | Pairs per batch. DPO is data-efficient; smaller batches (2–8) often work well. |
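The table lists dpo_beta, but pairwise_preference_loss as written scores raw policy log-probabilities and never reads a β. For comparison, the original DPO objective scales a margin built from policy-minus-reference log-probabilities by β. The sketch below computes that per-pair loss in plain Python. dpo_pair_loss is a hypothetical helper, not part of MinT or quickstart/custom_loss.py; it assumes you already have weighted sequence log-probabilities from the policy and from a frozen reference model, and it does not show how to obtain reference log-probabilities in MinT.

```python
import math


def dpo_pair_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,
) -> float:
    """-log(sigmoid(beta * margin)) for one pair, using a reference-model margin.

    All inputs are weighted sequence log-probabilities; the ref_* values are
    assumed to come from a frozen copy of the base model.
    """
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    z = beta * margin
    # Stable form of -log(sigmoid(z)) == log(1 + exp(-z)).
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


# Positive margin (+3): the policy moved toward the chosen response vs. the reference.
for beta in (0.1, 0.5):
    print("margin +3, beta", beta, dpo_pair_loss(-10.0, -14.0, -11.0, -12.0, beta))
# Negative margin (-3): the policy moved toward the rejected response.
for beta in (0.1, 0.5):
    print("margin -3, beta", beta, dpo_pair_loss(-14.0, -10.0, -12.0, -11.0, beta))
```

With a positive margin, raising β from 0.1 to 0.5 lowers the loss; with a negative margin it raises the loss, which matches the table's description of higher β penalizing wrong preferences harder.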
Usage:
```python
# num_steps: number of optimizer steps you want to run.
for step in range(num_steps):
    result = training_client.forward_backward_custom(
        data,  # flat, interleaved list: [chosen_0, rejected_0, chosen_1, rejected_1, ...]
        pairwise_preference_loss,
    ).result()
    metrics = result.metrics or {}
    training_client.optim_step(adam_params).result()
    # Default to NaN so a missing metric does not crash the format string.
    print(
        f"Step {step}: loss={metrics.get('loss', float('nan')):.4f}, "
        f"pair_accuracy={metrics.get('pair_accuracy', float('nan')):.1%}"
    )
```

Important: The data list must have an even length, with chosen and rejected datums interleaved: [chosen₀, rejected₀, chosen₁, rejected₁, ...]. If the order is wrong, the loss silently computes incorrect margins; if the length is odd, zip drops the last datum without raising an error.
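Because zip() stops at the shorter of data[::2] and data[1::2], an odd-length list loses its final datum with no error (on Python 3.10+, zip(..., strict=True) would raise instead). A cheap guard is to validate the length before training and to batch by whole pairs so a chosen datum is never separated from its rejected partner. check_interleaved and batch_pairs below are hypothetical helpers, with pairs_per_batch playing the role of batch_size from the table; they check length and parity only and cannot tell whether the chosen and rejected datums were swapped.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def check_interleaved(data: Sequence[T]) -> None:
    """Raise instead of letting zip() silently drop a trailing datum."""
    if len(data) % 2 != 0:
        raise ValueError(f"expected an even number of datums, got {len(data)}")


def batch_pairs(data: Sequence[T], pairs_per_batch: int) -> list[list[T]]:
    """Split [chosen_0, rejected_0, chosen_1, ...] into batches of whole pairs.

    Each batch starts at an even index and has even length, so the
    even=chosen / odd=rejected convention holds inside every batch.
    """
    check_interleaved(data)
    if pairs_per_batch < 1:
        raise ValueError("pairs_per_batch must be >= 1")
    step = 2 * pairs_per_batch  # each pair occupies two consecutive slots
    return [list(data[i : i + step]) for i in range(0, len(data), step)]


# Strings stand in for Datum objects.
odd = ["c0", "r0", "c1"]
print(list(zip(odd[::2], odd[1::2])))  # [('c0', 'r0')]: 'c1' is silently dropped
print(batch_pairs(["c0", "r0", "c1", "r1", "c2", "r2"], 2))
```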
What's next?
- RLHF 3-stage pipeline — combining DPO with other stages.
- Loss functions — other custom loss patterns.
- custom_loss.py — the canonical DPO reference script.