Mind Lab Toolkit (MinT)

DPO Overview

Direct Preference Optimization (DPO) trains a model to prefer chosen responses over rejected ones. Unlike SFT, which requires labeled correct answers, DPO works with pairwise preference signals — useful for tasks where multiple valid answers exist or when human judges score relative quality.

In MinT, DPO uses forward_backward_custom with a Bradley-Terry-style pairwise loss defined in client Python. There is no built-in loss_fn="dpo"; instead, you provide the loss closure directly.

Configuration

DPO uses the same ServiceClient and LoRA training setup as SFT, but calls forward_backward_custom instead of the standard forward_backward:

import mint
from mint import types

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)

tokenizer = training_client.get_tokenizer()
adam_params = types.AdamParams(learning_rate=1e-5)  # Lower than SFT

The canonical DPO script is quickstart/custom_loss.py. It demonstrates:

  1. Building PreferencePair objects: (prompt, chosen, rejected).
  2. Flattening pairs into Datum objects via flatten_preference_pairs(...).
  3. Defining a pairwise_preference_loss(...) closure.
  4. Calling forward_backward_custom(data, loss_closure).

The script's building blocks can be imported directly:
from quickstart.custom_loss import (
    PreferencePair,
    flatten_preference_pairs,
    pairwise_preference_loss,
)

pairs = [
    PreferencePair(
        prompt="Explain backups.",
        chosen="Backups reduce recovery time after failures...",
        rejected="Backups are good.",
    ),
    # ... more pairs
]

Prompting Guide

DPO requires (chosen, rejected) pairs. Render the prompt through your chat template, then build one Datum per completion, with zero loss weight on the prompt tokens and full weight on the completion tokens:

def build_preference_datum(
    prompt_tokens: list[int], completion_text: str, tokenizer
) -> types.Datum:
    # Encode the completion and terminate it with EOS so the model learns to stop.
    completion_tokens = tokenizer.encode(f" {completion_text}", add_special_tokens=False)
    completion_tokens.append(tokenizer.eos_token_id)

    # Next-token shift: target_tokens[i] is the token that follows input_tokens[: i + 1].
    all_tokens = prompt_tokens + completion_tokens
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    # Zero weight on the len(prompt_tokens) - 1 prompt targets, full weight on every completion target.
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )
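To see why the weights list starts with len(prompt_tokens) - 1 zeros, here is a minimal pure-Python sketch of the same shift-and-mask logic using made-up integer token IDs instead of a tokenizer. The function name shift_and_weight is hypothetical and exists only for illustration.

```python
def shift_and_weight(prompt_tokens: list[int], completion_tokens: list[int]):
    """Mirror build_preference_datum's next-token shift and loss mask.

    Returns (input_tokens, target_tokens, weights), where target_tokens[i]
    is the token the model should predict after seeing input_tokens[: i + 1].
    """
    all_tokens = prompt_tokens + completion_tokens
    input_tokens = all_tokens[:-1]   # the model sees every token except the last
    target_tokens = all_tokens[1:]   # and predicts every token except the first
    # The first len(prompt_tokens) - 1 targets are prompt tokens: no loss.
    # The remaining len(completion_tokens) targets are the completion: full loss.
    weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)
    assert len(input_tokens) == len(target_tokens) == len(weights)
    return input_tokens, target_tokens, weights


# Hypothetical IDs: prompt = [10, 11, 12], completion = [20, 21] followed by EOS (99).
inputs, targets, weights = shift_and_weight([10, 11, 12], [20, 21, 99])
for inp, tgt, w in zip(inputs, targets, weights):
    print(f"input={inp:>3} -> target={tgt:>3} weight={w}")
```

Only targets that are completion tokens (including EOS) carry weight 1.0, so the weighted sum of log-probabilities covers the response alone.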


def flatten_preference_pairs(pairs: list[PreferencePair], tokenizer) -> list[types.Datum]:
    """Convert (prompt, chosen, rejected) pairs into a flat, interleaved list:
    [chosen_0, rejected_0, chosen_1, rejected_1, ...].
    """
    data: list[types.Datum] = []
    for pair in pairs:
        # build_prompt_tokens is not shown on this page; it is assumed to render
        # the prompt through the chat template and return its token IDs.
        prompt_tokens = build_prompt_tokens(pair.prompt, tokenizer)
        data.append(build_preference_datum(prompt_tokens, pair.chosen, tokenizer))    # even index
        data.append(build_preference_datum(prompt_tokens, pair.rejected, tokenizer))  # odd index
    return data

The result is a flat list where each pair occupies two consecutive datums: [chosen₀, rejected₀, chosen₁, rejected₁, ...]. This structure is critical because pairwise_preference_loss expects even-indexed datums to be chosen and odd-indexed to be rejected.

Output Format

The pairwise preference loss compares the weighted sequence log-probability of each chosen response with that of its rejected counterpart, using a Bradley-Terry sigmoid:

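from typing import Any

import torch
import torch.nn.functional as F

from mint import types


def sequence_logprob(logprobs: torch.Tensor, weights: Any) -> torch.Tensor:
    """Weighted sum of per-token logprobs (a dot product with the loss weights)."""
    logprob_tensor = logprobs.flatten().float()
    # _to_float_tensor is not shown on this page; it is assumed to be a helper in
    # quickstart/custom_loss.py that converts the stored weights to a 1-D float tensor.
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)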


def pairwise_preference_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    """Bradley-Terry loss: -log(sigmoid(chosen_score - rejected_score))."""
    chosen_scores = []
    rejected_scores = []

    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )

    chosen_scores_tensor = torch.stack(chosen_scores)
    rejected_scores_tensor = torch.stack(rejected_scores)
    margins = chosen_scores_tensor - rejected_scores_tensor
    loss = -F.logsigmoid(margins).mean()  # Bradley-Terry: -log(sigmoid(margin))

    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics
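In equation form, with w_t the loss weight and \ell_t the per-token log-probability at target position t, x the prompt, y^{+} the chosen response, y^{-} the rejected response, and N the number of pairs, the closure above computes:

```latex
s(y) = \sum_{t} w_t \, \ell_t, \qquad \ell_t = \log p_\theta(\text{target}_t \mid \text{input}_{\le t})

m_i = s(y_i^{+}) - s(y_i^{-})

\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \sigma(m_i), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
```

Because s is a sum over completion tokens rather than a mean, a longer response scores lower than a shorter one with the same per-token probabilities, so pairs whose chosen response is much longer than the rejected one can start with negative margins. A margin of zero gives a loss of \log 2 \approx 0.693.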

The metrics reported are:

  • loss: The average Bradley-Terry loss across pairs.
  • pair_accuracy: Fraction of pairs where chosen_score > rejected_score.
  • mean_margin: Average of chosen_score - rejected_score across pairs.

Training is working when pair_accuracy rises toward 1.0 (every chosen response outscores its rejected counterpart) and loss falls toward 0.
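For debugging the metrics without a training client, the following sketch recomputes loss, pair_accuracy, and mean_margin in plain Python from precomputed sequence scores (one float per datum, in the interleaved order). The function names preference_metrics and neg_log_sigmoid are hypothetical; they mirror the arithmetic of pairwise_preference_loss but are not part of MinT.

```python
import math


def neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x)), i.e. softplus(-x)."""
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


def preference_metrics(scores: list[float]) -> dict[str, float]:
    """Compute loss, pair_accuracy and mean_margin from interleaved scores.

    scores must be [chosen_0, rejected_0, chosen_1, rejected_1, ...].
    """
    if not scores or len(scores) % 2 != 0:
        raise ValueError("expected a non-empty, even-length list of scores")
    margins = [c - r for c, r in zip(scores[::2], scores[1::2])]
    n = len(margins)
    return {
        "loss": sum(neg_log_sigmoid(m) for m in margins) / n,
        "pair_accuracy": sum(m > 0 for m in margins) / n,
        "mean_margin": sum(margins) / n,
    }


# Two pairs: the first is ranked correctly (margin +2), the second is not (margin -1).
print(preference_metrics([-4.0, -6.0, -9.0, -8.0]))
```

Feeding it sequence scores you have computed yourself is a quick way to check that a modified loss closure still reports the values you expect.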

All Parameters

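| Parameter | Type | Default | Meaning |
|---|---|---|---|
| dpo_beta | float | 0.1 | Bradley-Terry temperature β, which scales the margin inside the sigmoid. Higher β penalizes wrong preferences harder. Range: 0.05–0.5. The pairwise_preference_loss closure above does not apply β; to use it, compute F.logsigmoid(dpo_beta * margins). |
| learning_rate | float | 1e-5 | Adam learning rate, usually set lower than for SFT. Typical: 5e-6 to 5e-5. |
| betas | tuple[float, float] | (0.9, 0.999) | Adam exponential decay rates. |
| eps | float | 1e-8 | Adam numerical stability term. |
| weight_decay | float | 0.0 | L2 regularization. |
| base_model | str | "Qwen/Qwen3-0.6B" | Base model ID. |
| rank | int | 16 | LoRA rank. |
| train_mlp | bool | True | Train MLP layers. |
| train_attn | bool | True | Train attention layers. |
| train_unembed | bool | True | Train the output (unembedding) layer. |
| max_length | int | 96 | Maximum sequence length in tokens, used as a safety limit. |
| batch_size | int | 8 | Pairs per batch (each pair contributes two datums). DPO is data-efficient; smaller batches (2–8) often work well. |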
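Unlike the closure above, which compares raw policy scores, the standard DPO objective (Rafailov et al., 2023) measures each response's log-probability relative to a frozen reference model and scales the resulting margin by β. A minimal pure-Python sketch of that formulation for a single pair, assuming you have already computed reference sequence scores (for example, from the base model before training); dpo_pair_loss is a hypothetical name, not a MinT API:

```python
import math


def dpo_pair_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,
) -> float:
    """Standard DPO loss for one pair: -log(sigmoid(beta * margin)).

    All four inputs are summed sequence log-probabilities (same weighting as
    sequence_logprob); the ref_* values come from a frozen reference model.
    """
    # Implicit reward margin: how far the policy has moved toward the chosen
    # response, relative to the reference, compared with the rejected one.
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    z = beta * margin
    # Numerically stable -log(sigmoid(z)) = softplus(-z).
    return math.log1p(math.exp(-z)) if z >= 0 else -z + math.log1p(math.exp(z))


# Policy identical to the reference: the margin is 0.
print(dpo_pair_loss(-10.0, -12.0, -10.0, -12.0))
# Policy has shifted 5 nats (net) toward the chosen response.
print(dpo_pair_loss(-8.0, -15.0, -10.0, -12.0, beta=0.1))
```

With a reference model, the loss stays at log 2 until the policy diverges from the reference, and β controls how strongly a given divergence moves the loss.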

Usage:

data = flatten_preference_pairs(pairs, tokenizer)  # [chosen₀, rejected₀, chosen₁, rejected₁, ...]
num_steps = 20  # illustrative value; tune for your dataset

for step in range(num_steps):
    result = training_client.forward_backward_custom(
        data, pairwise_preference_loss
    ).result()
    metrics = result.metrics or {}

    # Reuse the AdamParams defined during configuration.
    training_client.optim_step(adam_params).result()

    print(
        f"Step {step}: loss={metrics.get('loss', float('nan')):.4f}, "
        f"pair_accuracy={metrics.get('pair_accuracy', float('nan')):.1%}"
    )
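The loop above sends the full data list on every step. To respect a batch_size measured in pairs, one option is to slice the interleaved list in steps of 2 * batch_size so that no pair is split across batches. A minimal sketch; iter_pair_batches is a hypothetical helper, not part of MinT or the quickstart:

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def iter_pair_batches(data: Sequence[T], batch_size: int) -> Iterator[list[T]]:
    """Yield slices of an interleaved [chosen, rejected, ...] list.

    batch_size counts pairs, so each slice holds 2 * batch_size items
    (the last slice may be smaller). Every slice starts on an even index,
    so chosen/rejected alignment is preserved within each batch.
    """
    if len(data) % 2 != 0:
        raise ValueError("data must contain an even number of items")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    step = 2 * batch_size
    for start in range(0, len(data), step):
        yield list(data[start:start + step])


# Strings stand in for Datum objects: 5 pairs with batch_size=2 give batches of 4, 4 and 2 items.
items = [f"{kind}{i}" for i in range(5) for kind in ("c", "r")]
for batch in iter_pair_batches(items, batch_size=2):
    print(batch)
```

Each yielded batch can then be passed to forward_backward_custom(batch, pairwise_preference_loss) followed by optim_step.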

Important: The data list must have an even length, with chosen and rejected datums interleaved: [chosen₀, rejected₀, chosen₁, rejected₁, ...]. If the order is wrong, the loss silently computes incorrect margins; if the length is odd, zip drops the last datum without raising an error.
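A cheap guard against misordered data is to check, before training, that the list has even length and that each consecutive (chosen, rejected) pair was built from the same prompt. The sketch below works on plain (input_tokens, weights) tuples so it runs without MinT; check_interleaved and leading_zeros are hypothetical names, and extracting those lists from real Datum objects is left to your code. It relies on the layout produced by build_preference_datum, where the prompt occupies the first (number of leading zero weights + 1) input tokens. It catches pairs that were split or shuffled apart, but not a chosen/rejected swap within a pair, which no structural check can detect.

```python
def leading_zeros(weights: list[float]) -> int:
    """Count weights equal to 0.0 before the first nonzero weight."""
    count = 0
    for w in weights:
        if w != 0.0:
            break
        count += 1
    return count


def check_interleaved(examples: list[tuple[list[int], list[float]]]) -> None:
    """Raise ValueError if examples are not valid [chosen, rejected, ...] pairs.

    Each example is (input_tokens, weights) laid out as in
    build_preference_datum, so the prompt is the first
    leading_zeros(weights) + 1 input tokens.
    """
    if len(examples) % 2 != 0:
        raise ValueError(f"expected an even number of datums, got {len(examples)}")
    for pair_idx in range(len(examples) // 2):
        (c_tokens, c_weights), (r_tokens, r_weights) = examples[2 * pair_idx : 2 * pair_idx + 2]
        c_prompt = c_tokens[: leading_zeros(c_weights) + 1]
        r_prompt = r_tokens[: leading_zeros(r_weights) + 1]
        if c_prompt != r_prompt:
            raise ValueError(f"pair {pair_idx}: chosen and rejected prompts differ")


# Hypothetical token IDs for one pair sharing the prompt [10, 11, 12].
good = [
    ([10, 11, 12, 20, 21], [0.0, 0.0, 1.0, 1.0, 1.0]),  # chosen
    ([10, 11, 12, 30], [0.0, 0.0, 1.0, 1.0]),           # rejected
]
check_interleaved(good)  # passes silently
```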
