Mind Lab Toolkit (MinT)

Loss Functions

MinT supports both built-in loss functions and client-side custom losses.

Built-in losses

| Loss | Best for | Required inputs |
|---|---|---|
| cross_entropy | SFT / next-token prediction | target_tokens, weights |
| importance_sampling | Simple RL policy-gradient loops | target_tokens, weights, logprobs, advantages |
| ppo | Clipped RL updates | Same as importance_sampling |
| cispo | Clipped IS variant | Same as importance_sampling |
| dro | Regularized off-policy RL | Same as importance_sampling, plus beta config |
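The required inputs are per-token arrays aligned with the model's predictions. The sketch below builds them for one prompt/completion sequence. It is a minimal sketch under stated assumptions: it uses the common next-token convention (target_tokens is the token sequence shifted by one, weights is 0 on prompt targets and 1 on completion targets), and for the RL losses it broadcasts a single sequence-level advantage onto every completion token. The helper name build_loss_inputs and its plain-dict output are hypothetical; in MinT these values would end up in each datum's loss_fn_inputs.

```python
from typing import Optional

import torch


def build_loss_inputs(
    prompt_tokens: list[int],
    completion_tokens: list[int],
    sampling_logprobs: Optional[list[float]] = None,
    advantage: Optional[float] = None,
) -> dict[str, torch.Tensor]:
    """Hypothetical helper: per-token loss inputs for one prompt + completion.

    Assumes next-token alignment: the model sees tokens[:-1] and position i is
    scored against target_tokens[i] == tokens[i + 1].
    """
    if not prompt_tokens:
        raise ValueError("prompt_tokens must contain at least one token")
    tokens = prompt_tokens + completion_tokens
    n_prompt_targets = len(prompt_tokens) - 1  # targets that are still prompt tokens
    n_completion = len(completion_tokens)

    # weights: 0 on prompt targets, 1 on completion targets.
    weights = torch.cat([torch.zeros(n_prompt_targets), torch.ones(n_completion)])
    inputs = {"target_tokens": torch.tensor(tokens[1:], dtype=torch.long), "weights": weights}

    if sampling_logprobs is not None and advantage is not None:
        if len(sampling_logprobs) != n_completion:
            raise ValueError("need one sampling logprob per completion token")
        # RL losses: sampler log-probs (0.0 placeholders on prompt positions) and
        # the sequence-level advantage broadcast onto completion tokens.
        inputs["logprobs"] = torch.cat(
            [torch.zeros(n_prompt_targets), torch.tensor(sampling_logprobs)]
        )
        inputs["advantages"] = weights * advantage
    return inputs
```

Calling it without sampling_logprobs and advantage yields only the two cross_entropy inputs; passing both adds the extra keys the RL losses in the table list.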

Core formulas

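import torch

# Notation: per-token tensors of equal shape.
# target_logprobs: current-policy log-probs of target_tokens
# sampling_logprobs: the `logprobs` input (log-probs under the sampling policy)

# cross_entropy: weighted negative log-likelihood of the target tokens
loss = (-target_logprobs * weights).sum()

# importance_sampling: advantages reweighted by the policy/sampler probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
loss = -(prob_ratio * advantages).sum()

# ppo: pessimistic minimum of the unclipped and clipped surrogate objectives
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
loss = -torch.min(prob_ratio * advantages, clipped_ratio * advantages).sum()

# cispo: clipped_ratio as in ppo, detached so it acts as a constant per-token weight
loss = -(clipped_ratio.detach() * target_logprobs * advantages).sum()

# dro: policy-gradient term minus a beta-weighted quadratic penalty on the log-ratio
quadratic_term = (target_logprobs - sampling_logprobs) ** 2
loss = -(target_logprobs * advantages - 0.5 * beta * quadratic_term).sum()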
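To build intuition for these objectives on the client, the formulas can be collected into one function and evaluated on toy tensors. This is a sketch for experimentation, not MinT's server code: the function name builtin_losses and the default clip bounds (0.8, 1.2) and beta (0.1) are illustrative assumptions, and sampling_logprobs plays the role of the logprobs input from the table.

```python
import torch


def builtin_losses(
    target_logprobs: torch.Tensor,
    sampling_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    weights: torch.Tensor,
    clip_low: float = 0.8,   # illustrative, not a MinT default
    clip_high: float = 1.2,  # illustrative, not a MinT default
    beta: float = 0.1,       # illustrative, not a MinT default
) -> dict[str, torch.Tensor]:
    """Client-side restatement of the five formulas, for experimentation only."""
    ratio = torch.exp(target_logprobs - sampling_logprobs)
    clipped = torch.clamp(ratio, clip_low, clip_high)
    quadratic = (target_logprobs - sampling_logprobs) ** 2
    return {
        "cross_entropy": (-target_logprobs * weights).sum(),
        "importance_sampling": -(ratio * advantages).sum(),
        "ppo": -torch.min(ratio * advantages, clipped * advantages).sum(),
        "cispo": -(clipped.detach() * target_logprobs * advantages).sum(),
        "dro": -(target_logprobs * advantages - 0.5 * beta * quadratic).sum(),
    }
```

A check that follows directly from the formulas: while the policy still matches the sampler (sampling_logprobs equal to target_logprobs), the ratio is exactly 1, so importance_sampling, ppo, cispo and dro all give the same gradient, -advantages, with respect to target_logprobs. They diverge only once the policy drifts from the sampler, which is where clipping (ppo, cispo) or the beta penalty (dro) takes effect.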

Custom loss via forward_backward_custom

Use forward_backward_custom when your objective cannot be expressed by the built-in losses.

Typical examples:

  • Pairwise preference / Bradley-Terry objectives
  • DPO-style sequence losses
  • Multi-sequence comparison objectives
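As an illustration of a multi-sequence comparison objective, here is a hypothetical listwise softmax loss (the top-1 Plackett-Luce form): each group of K sequences is scored, and the preferred sequence should win a softmax over its group. It follows the (data, logprobs_list) -> (loss, metrics) shape of the runnable pairwise example on this page. The factory name make_listwise_loss, the "preferred sequence first in each group" layout, and reading weights with torch.as_tensor are assumptions; SimpleNamespace stands in for real datums only in the usage lines.

```python
from types import SimpleNamespace
from typing import Any, Callable, Sequence

import torch
import torch.nn.functional as F

LossFn = Callable[[Sequence[Any], list[torch.Tensor]], tuple[torch.Tensor, dict[str, float]]]


def make_listwise_loss(group_size: int) -> LossFn:
    """Hypothetical factory: listwise softmax loss over groups of sequences.

    Assumed batch layout: consecutive groups of `group_size` datums, with the
    preferred sequence first in each group.
    """

    def listwise_loss(
        data: Sequence[Any], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        if len(data) % group_size != 0:
            raise ValueError(f"Expected groups of {group_size} sequences.")
        # Sequence score: completion-token logprobs summed through the weights mask.
        scores = torch.stack([
            torch.dot(
                lp.flatten().float(),
                torch.as_tensor(d.loss_fn_inputs["weights"], dtype=torch.float32).flatten(),
            )
            for d, lp in zip(data, logprobs_list)
        ])
        grouped = scores.view(-1, group_size)  # [num_groups, group_size]
        # Softmax cross-entropy with target index 0: -log softmax(scores)[0] per group.
        targets = torch.zeros(grouped.shape[0], dtype=torch.long)
        loss = F.cross_entropy(grouped, targets)
        top1 = (grouped.argmax(dim=1) == 0).float().mean()
        return loss, {"loss": float(loss.detach()), "top1_accuracy": float(top1)}

    return listwise_loss


# Usage with stand-in datums: SimpleNamespace mimics only .loss_fn_inputs.
demo_data = [SimpleNamespace(loss_fn_inputs={"weights": [0.0, 1.0]}) for _ in range(2)]
demo_logprobs = [torch.tensor([-5.0, -1.0]), torch.tensor([-5.0, -2.0])]
demo_loss, demo_metrics = make_listwise_loss(2)(demo_data, demo_logprobs)
```

With group_size=2 this reduces to the sequence-level Bradley-Terry loss used in the runnable pairwise example, since -log softmax([c, r])[0] equals -log sigmoid(c - r).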

Important notes:

  • The custom loss runs on the client, not on the server.
  • Under the hood, the current implementation runs a forward pass followed by a forward_backward pass.
  • Expect roughly 1.5x the FLOPs and up to 3x the wall time of a single built-in forward_backward call.
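One way a forward pass plus a forward_backward pass can support an arbitrary client-side loss is a linear surrogate: the forward pass returns per-token logprobs, the client evaluates the custom loss and its gradient g with respect to those logprobs, and the second pass backpropagates sum(g * logprobs). By the chain rule, that surrogate has the same parameter gradient as the custom loss. The sketch below checks this identity in plain PyTorch with a toy linear "model"; it illustrates the general idea and is an assumption, not a description of MinT's internals.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy "model": a linear layer mapping 3 positions x 4 features to 5-token logits.
x = torch.randn(3, 4)
targets = torch.tensor([1, 4, 2])
W = torch.randn(4, 5, requires_grad=True)


def token_logprobs(weight: torch.Tensor) -> torch.Tensor:
    # Log-probability of each position's target token.
    return F.log_softmax(x @ weight, dim=-1).gather(1, targets[:, None]).squeeze(1)


def custom_loss(logprobs: torch.Tensor) -> torch.Tensor:
    # Arbitrary nonlinear function of the logprobs, computed "on the client".
    return -F.logsigmoid(logprobs[0] + logprobs[1] - 2 * logprobs[2])


# Reference: backpropagate the custom loss straight through the model.
(direct_grad,) = torch.autograd.grad(custom_loss(token_logprobs(W)), W)

# Pass 1 (forward): logprobs only, no parameter gradients tracked.
with torch.no_grad():
    logprobs = token_logprobs(W)

# Client: gradient of the custom loss with respect to the logprobs alone.
logprobs_leaf = logprobs.clone().requires_grad_(True)
(g,) = torch.autograd.grad(custom_loss(logprobs_leaf), logprobs_leaf)

# Pass 2 (forward_backward): linear surrogate sum(g * logprobs) with g held fixed.
(surrogate_grad,) = torch.autograd.grad((g * token_logprobs(W)).sum(), W)

print(torch.allclose(direct_grad, surrogate_grad, atol=1e-6))  # True
```

Only the gradients match: the surrogate's value is not the custom loss, so any loss value worth reporting has to be computed on the client, as the metrics dict in the pairwise example below does.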

Runnable pairwise preference example

The companion runnable example lives in mint-quickstart:

python quickstart/custom_loss.py

It trains on (chosen, rejected) pairs ordered consecutively in the batch and computes a sequence-level Bradley-Terry-style loss.

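from typing import Any

import torch
import torch.nn.functional as F

# `types` (for types.Datum) and `training_client` come from the MinT SDK setup
# in quickstart/custom_loss.py, which is not reproduced here.


def _to_float_tensor(value: Any) -> torch.Tensor:
    # Hypothetical stand-in for the quickstart helper of the same name: convert
    # weights (tensor, list, or anything with .tolist()) to a flat float32 tensor.
    if isinstance(value, torch.Tensor):
        return value.flatten().float()
    if hasattr(value, "tolist"):
        value = value.tolist()
    return torch.tensor(value, dtype=torch.float32).flatten()


def sequence_logprob(logprobs: torch.Tensor, weights: Any) -> torch.Tensor:
    # Sequence score: sum of token logprobs, with weights masking out the prompt.
    logprob_tensor = logprobs.flatten().float()
    weight_tensor = _to_float_tensor(weights)
    return torch.dot(logprob_tensor, weight_tensor)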


def pairwise_preference_loss(
    data: list[types.Datum], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    if len(data) % 2 != 0:
        raise ValueError("Expected (chosen, rejected) pairs.")

    chosen_scores = []
    rejected_scores = []

    for chosen_datum, rejected_datum, chosen_logprobs, rejected_logprobs in zip(
        data[::2], data[1::2], logprobs_list[::2], logprobs_list[1::2]
    ):
        chosen_scores.append(
            sequence_logprob(chosen_logprobs, chosen_datum.loss_fn_inputs["weights"])
        )
        rejected_scores.append(
            sequence_logprob(rejected_logprobs, rejected_datum.loss_fn_inputs["weights"])
        )

    chosen_scores_tensor = torch.stack(chosen_scores)
    rejected_scores_tensor = torch.stack(rejected_scores)
    margins = chosen_scores_tensor - rejected_scores_tensor
    loss = -F.logsigmoid(margins).mean()
    metrics = {
        "loss": float(loss.detach().cpu()),
        "pair_accuracy": float((margins > 0).float().mean().detach().cpu()),
        "mean_margin": float(margins.mean().detach().cpu()),
    }
    return loss, metrics


result = training_client.forward_backward_custom(data, pairwise_preference_loss).result()
print(result.metrics)
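pairwise_preference_loss identifies chosen and rejected sequences purely by batch position. A hypothetical helper for producing that layout from (prompt, chosen, rejected) triples is sketched below. It returns plain dicts of target_tokens and weights rather than MinT types.Datum objects, whose constructor is not documented on this page, and it assumes next-token alignment (target_tokens = tokens[1:]) so that prompt targets get weight 0 and completion targets weight 1.

```python
import torch


def interleave_pairs(
    pairs: list[tuple[list[int], list[int], list[int]]],
) -> list[dict[str, torch.Tensor]]:
    """Hypothetical helper: (prompt, chosen, rejected) triples -> [chosen, rejected, ...].

    Assumes next-token alignment: target_tokens = tokens[1:], so the first
    len(prompt) - 1 targets are prompt tokens and get weight 0.
    """
    batch = []
    for prompt, chosen, rejected in pairs:
        if not prompt:
            raise ValueError("each prompt needs at least one token")
        for completion in (chosen, rejected):  # chosen first, then rejected
            tokens = prompt + completion
            weights = [0.0] * (len(prompt) - 1) + [1.0] * len(completion)
            batch.append(
                {
                    "target_tokens": torch.tensor(tokens[1:], dtype=torch.long),
                    "weights": torch.tensor(weights),
                }
            )
    return batch
```

Because the chosen entry is always appended before its rejected partner, the even/odd slicing in pairwise_preference_loss (data[::2] and data[1::2]) lines up with the pairs.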

Data contract for pairwise losses

  • Order the batch as (chosen, rejected, chosen, rejected, ...).
  • Keep prompt tokens masked out in weights; only completion tokens should contribute to the sequence score.
  • If you want a reference-policy term for full DPO, compute those reference logprobs in client code and include them inside your custom loss function.
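For full DPO, the reference term enters through the margin: the loss is -log sigmoid(beta * [(log pi(chosen) - log pi_ref(chosen)) - (log pi(rejected) - log pi_ref(rejected))]). The sketch below closes over reference scores computed beforehand on the client, as the last bullet describes. The factory make_dpo_loss, its ref_sequence_logprobs argument (masked, summed reference log-probs in the same interleaved batch order), and the default beta=0.1 are hypothetical; how you obtain the reference scores depends on your setup. The returned function has the same (data, logprobs_list) -> (loss, metrics) shape as pairwise_preference_loss, so it should be usable with forward_backward_custom in its place.

```python
from types import SimpleNamespace
from typing import Any, Callable, Sequence

import torch
import torch.nn.functional as F


def make_dpo_loss(
    ref_sequence_logprobs: Sequence[float], beta: float = 0.1
) -> Callable[[Sequence[Any], list[torch.Tensor]], tuple[torch.Tensor, dict[str, float]]]:
    """Hypothetical factory: DPO loss over an interleaved (chosen, rejected) batch."""
    ref = torch.tensor(ref_sequence_logprobs, dtype=torch.float32)

    def dpo_loss(
        data: Sequence[Any], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        if len(data) % 2 != 0 or len(data) != len(ref):
            raise ValueError("Expected (chosen, rejected) pairs matching the reference scores.")
        # Policy sequence scores: completion-token logprobs summed via the weights mask.
        policy = torch.stack([
            torch.dot(
                lp.flatten().float(),
                torch.as_tensor(d.loss_fn_inputs["weights"], dtype=torch.float32).flatten(),
            )
            for d, lp in zip(data, logprobs_list)
        ])
        # Implicit reward = log pi - log pi_ref; margin = beta * (chosen - rejected).
        log_ratios = policy - ref
        margins = beta * (log_ratios[0::2] - log_ratios[1::2])
        loss = -F.logsigmoid(margins).mean()
        metrics = {
            "loss": float(loss.detach()),
            "reward_accuracy": float((margins > 0).float().mean()),
        }
        return loss, metrics

    return dpo_loss


# Usage with stand-in datums; the reference equals the policy, so the margin is 0.
demo_data = [SimpleNamespace(loss_fn_inputs={"weights": [1.0, 1.0]}) for _ in range(2)]
chosen_lp = torch.tensor([-1.0, -2.0], requires_grad=True)
rejected_lp = torch.tensor([-0.5, -0.5], requires_grad=True)
dpo_loss_fn = make_dpo_loss(ref_sequence_logprobs=[-3.0, -1.0], beta=0.1)
demo_loss, demo_metrics = dpo_loss_fn(demo_data, [chosen_lp, rejected_lp])
demo_loss.backward()
```

At zero margin the loss is log 2, and each weighted chosen token receives gradient -beta/2 (each rejected token +beta/2), so the first update pushes the chosen sequence up relative to the reference and the rejected one down.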
