Using the API
Loss Functions
MinT supports both built-in loss functions and client-side custom losses.
Built-in losses
| Loss | Best for | Required inputs |
|---|---|---|
| `cross_entropy` | SFT / next-token prediction | `target_tokens`, `weights` |
| `importance_sampling` | Simple RL policy-gradient loops | `target_tokens`, `weights`, `logprobs`, `advantages` |
| `ppo` | Clipped RL updates | Same as `importance_sampling` |
| `cispo` | Clipped IS variant | Same as `importance_sampling` |
| `dro` | Regularized off-policy RL | Same as `importance_sampling`, plus a `beta` config |
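As a quick sanity check of the `cross_entropy` inputs, the sketch below reproduces the weighted negative-log-likelihood sum by hand in plain PyTorch. The tensor values are illustrative only; this is not the MinT API.

```python
import torch

# Toy sketch (plain PyTorch, not the MinT API): computing the built-in
# cross_entropy loss by hand from its required inputs.
target_logprobs = torch.tensor([-0.5, -1.0, -2.0])  # per-token log-probs of the targets
weights = torch.tensor([0.0, 1.0, 1.0])             # prompt token masked out with 0

loss = (-target_logprobs * weights).sum()
print(float(loss))  # 3.0
```

Note how the zero weight removes the prompt token's contribution entirely, so only the two completion tokens add to the loss.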
Core formulas
```python
# cross_entropy
loss = (-target_logprobs * weights).sum()

# importance_sampling
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
loss = -(prob_ratio * advantages).sum()

# ppo
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
loss = -torch.min(prob_ratio * advantages, clipped_ratio * advantages).sum()

# cispo
loss = -(clipped_ratio.detach() * target_logprobs * advantages).sum()

# dro
quadratic_term = (target_logprobs - sampling_logprobs) ** 2
loss = -(target_logprobs * advantages - 0.5 * beta * quadratic_term).sum()
```

Custom loss via forward_backward_custom
Use `forward_backward_custom` when your objective cannot be expressed by the built-in losses.
Typical examples:
- Pairwise preference / Bradley-Terry objectives
- DPO-style sequence losses
- Multi-sequence comparison objectives
Important notes:
- The custom loss runs on the client, not on the server.
- The current implementation is built from a `forward` pass plus a `forward_backward` pass under the hood.
- Expect roughly 1.5x the FLOPs and up to 3x the wall time compared with a single built-in `forward_backward`.
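Conceptually, that client-side flow can be emulated in a few lines of plain PyTorch. This is only a sketch under assumed names (`run_custom_loss` is hypothetical, not part of MinT): logprobs that require grad stand in for the server's forward pass, and `loss.backward()` stands in for the backward half.

```python
import torch

# Minimal sketch (hypothetical helper, not the MinT implementation) of the
# custom-loss flow: the forward pass yields logprobs that require grad,
# the user callable reduces them to a scalar loss, and backward produces
# the gradients that would be sent back in the forward_backward pass.
def run_custom_loss(logprobs: torch.Tensor, loss_fn) -> tuple[float, torch.Tensor]:
    logprobs = logprobs.clone().requires_grad_(True)  # stand-in for the forward pass
    loss = loss_fn(logprobs)
    loss.backward()                                   # stand-in for forward_backward
    return float(loss.detach()), logprobs.grad

toy_logprobs = torch.tensor([-0.4, -1.2, -0.7])
loss_value, grad = run_custom_loss(toy_logprobs, lambda lp: -lp.sum())
```

The key point is that the loss callable sees ordinary differentiable tensors, so any objective expressible in PyTorch works.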
Runnable pairwise preference example
The companion runnable example lives in mint-quickstart:
```shell
python quickstart/custom_loss.py
```

It trains on (chosen, rejected) pairs ordered consecutively in the batch and computes a sequence-level Bradley-Terry-style loss.
```python
import torch
import torch.nn.functional as F
from typing import Any

# `types` and the `_to_float_tensor` helper come from the full
# quickstart/custom_loss.py script.

def sequence_logprob(logprobs: torch.Tensor, weights: Any) -> torch.Tensor:
    logprob_tensor = logprobs.flatten().float()
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)


def pairwise_preference_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    if len(data) % 2 != 0:
        raise ValueError("Expected (chosen, rejected) pairs.")
    chosen_scores = []
    rejected_scores = []
    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )
    chosen_scores_tensor = torch.stack(chosen_scores)
    rejected_scores_tensor = torch.stack(rejected_scores)
    margins = chosen_scores_tensor - rejected_scores_tensor
    loss = -F.logsigmoid(margins).mean()
    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics
```
```python
result = training_client.forward_backward_custom(data, pairwise_preference_loss).result()
print(result.metrics)
```

Data contract for pairwise losses
- Order the batch as `(chosen, rejected, chosen, rejected, ...)`.
- Keep prompt tokens masked out in `weights`; only completion tokens should contribute to the sequence score.
- If you want a reference-policy term for full DPO, compute those reference logprobs in client code and include them inside your custom loss function.
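As a sketch of that last point, the snippet below shows how a reference-policy term turns the Bradley-Terry margin into the full DPO margin. All tensor values are illustrative; `policy_*` and `ref_*` are hypothetical per-sequence summed logprobs, with the reference values computed client-side as described above.

```python
import torch
import torch.nn.functional as F

# Toy sketch of a full DPO objective with a reference-policy term,
# on per-sequence summed logprobs (illustrative values only).
policy_chosen = torch.tensor([-4.0, -3.5])
policy_rejected = torch.tensor([-6.0, -5.0])
ref_chosen = torch.tensor([-4.5, -4.0])      # reference logprobs, computed client-side
ref_rejected = torch.tensor([-5.5, -4.8])
beta = 0.1                                   # DPO temperature

# DPO margin: beta * (policy-vs-reference gap for chosen minus rejected).
margins = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
loss = -F.logsigmoid(margins).mean()
pair_accuracy = (margins > 0).float().mean()
```

Without the `ref_*` terms this reduces to the plain Bradley-Terry margin used in the runnable example above.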