
Loss Functions

MinT provides several built-in loss functions optimized for different training scenarios.

Built-in Loss Functions

Cross-Entropy (Supervised Learning)

elementwise_loss = -target_logprobs * weights
loss = elementwise_loss.sum()

Inputs: target_tokens, weights
Outputs: logprobs
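As a worked example, the sketch below evaluates the cross-entropy formula on plain Python lists. It assumes the per-token log-probabilities of target_tokens have already been gathered, and the numbers are illustrative rather than taken from MinT. A weight of 0 masks a token (for example, a prompt token) out of the loss.

```python
import math

def cross_entropy_loss(target_logprobs: list[float], weights: list[float]) -> float:
    """Weighted negative log-likelihood, summed over tokens (no averaging)."""
    # elementwise_loss = -target_logprobs * weights
    elementwise_loss = [-lp * w for lp, w in zip(target_logprobs, weights)]
    # loss = elementwise_loss.sum()
    return sum(elementwise_loss)

# Hypothetical 4-token sequence: the first two tokens are prompt (weight 0),
# the last two are completion tokens that contribute to the loss.
target_logprobs = [math.log(0.9), math.log(0.2), math.log(0.5), math.log(0.25)]
weights = [0.0, 0.0, 1.0, 1.0]

loss = cross_entropy_loss(target_logprobs, weights)
print(round(loss, 4))  # 2.0794, i.e. -log(0.5) - log(0.25) = log(8)
```

Because the loss is a sum rather than a mean, setting each completion token's weight to 1/n (n being the number of completion tokens) turns it into a per-token average.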

Importance Sampling (Policy Gradient)

prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
loss = -(prob_ratio * advantages).sum()
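The ratio reweights each token's advantage by how much more (or less) likely the current policy makes that token than the sampling policy did. A minimal numeric sketch in plain Python, with made-up logprobs and advantages:

```python
import math

def importance_sampling_loss(target_logprobs, sampling_logprobs, advantages):
    # prob_ratio = exp(target_logprobs - sampling_logprobs), per token
    ratios = [math.exp(t - s) for t, s in zip(target_logprobs, sampling_logprobs)]
    # loss = -(prob_ratio * advantages).sum()
    loss = -sum(r * a for r, a in zip(ratios, advantages))
    return loss, ratios

# On-policy case: the current policy equals the sampler, so every ratio is 1
# and the loss reduces to -sum(advantages).
loss, ratios = importance_sampling_loss([-1.0, -2.0], [-1.0, -2.0], [0.5, -0.25])
print(ratios, loss)  # [1.0, 1.0] -0.25

# Off-policy case: the token became twice as likely, so its advantage counts double.
loss2, ratios2 = importance_sampling_loss([-1.0 + math.log(2.0)], [-1.0], [0.5])
print(round(ratios2[0], 6), round(loss2, 6))  # 2.0 -1.0
```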

PPO (Proximal Policy Optimization)

prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages
loss = -torch.min(unclipped_objective, clipped_objective).sum()

Configuration: loss_fn_config={"clip_low_threshold": 0.9, "clip_high_threshold": 1.1}
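To see what the clipping does, the sketch below evaluates the PPO formula in plain Python with the example thresholds 0.9 and 1.1 (a scalar min/max stands in for torch.clamp). With a positive advantage and a ratio already above clip_high_threshold, the min selects the clipped value, so raising that token's probability further earns nothing. With a negative advantage, the min keeps the more pessimistic unclipped value.

```python
import math

def ppo_loss(target_logprobs, sampling_logprobs, advantages,
             clip_low_threshold=0.9, clip_high_threshold=1.1):
    total = 0.0
    for t, s, a in zip(target_logprobs, sampling_logprobs, advantages):
        prob_ratio = math.exp(t - s)
        # Scalar equivalent of torch.clamp(prob_ratio, low, high)
        clipped_ratio = min(max(prob_ratio, clip_low_threshold), clip_high_threshold)
        unclipped_objective = prob_ratio * a
        clipped_objective = clipped_ratio * a
        # Keep the lower (pessimistic) of the two objectives
        total += min(unclipped_objective, clipped_objective)
    return -total

# Ratio 1.5 with advantage +1: the objective is capped at 1.1
print(round(ppo_loss([math.log(1.5)], [0.0], [1.0]), 6))   # -1.1
# Ratio 1.5 with advantage -1: the min keeps the unclipped -1.5
print(round(ppo_loss([math.log(1.5)], [0.0], [-1.0]), 6))  # 1.5
```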

CISPO (Clipped Importance Sampling Policy Optimization)

prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
loss = -(clipped_ratio.detach() * target_logprobs * advantages).sum()
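Because clipped_ratio is detached, CISPO's gradient with respect to each token's target_logprobs is simply -clipped_ratio * advantages. A token whose ratio lies outside the clip range therefore still receives a bounded gradient, whereas in PPO a token whose clipped term wins the min contributes no gradient. The plain-Python sketch below computes that analytic per-token gradient. It illustrates the formula above and is not MinT's implementation; the 0.9 and 1.1 thresholds are assumed for illustration.

```python
import math

def cispo_loss_and_grad(target_logprobs, sampling_logprobs, advantages,
                        clip_low_threshold=0.9, clip_high_threshold=1.1):
    loss = 0.0
    grads = []
    for t, s, a in zip(target_logprobs, sampling_logprobs, advantages):
        prob_ratio = math.exp(t - s)
        # clipped_ratio is treated as a constant (the .detach() in the formula)
        clipped_ratio = min(max(prob_ratio, clip_low_threshold), clip_high_threshold)
        loss += -(clipped_ratio * t * a)
        # d/dt of -(clipped_ratio * t * a) with clipped_ratio held fixed
        grads.append(-clipped_ratio * a)
    return loss, grads

# Token 0 is inside the clip range (ratio 1.0); token 1 is outside (ratio 1.5).
target = [math.log(0.5), math.log(0.3)]
sampling = [math.log(0.5), math.log(0.2)]
loss, grads = cispo_loss_and_grad(target, sampling, [1.0, 1.0])
print([round(g, 6) for g in grads])  # [-1.0, -1.1]
```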

DRO (Direct Reward Optimization)

quadratic_term = (target_logprobs - sampling_logprobs) ** 2
loss = -(target_logprobs * advantages - 0.5 * beta * quadratic_term).sum()

Configuration: loss_fn_config={"beta": 0.05}
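Differentiating the DRO loss with respect to a token's target_logprobs gives -(advantages - beta * (target_logprobs - sampling_logprobs)). The quadratic term thus acts as a restoring force that grows as the policy drifts from the sampler. The plain-Python sketch below evaluates the loss and that analytic gradient with beta = 0.05; the inputs are illustrative.

```python
def dro_loss_and_grad(target_logprobs, sampling_logprobs, advantages, beta=0.05):
    loss = 0.0
    grads = []
    for t, s, a in zip(target_logprobs, sampling_logprobs, advantages):
        drift = t - s
        quadratic_term = drift ** 2
        loss += -(t * a - 0.5 * beta * quadratic_term)
        # d/dt of the per-token loss: the penalty shrinks the push as drift grows
        grads.append(-(a - beta * drift))
    return loss, grads

# Same advantage (+1) at two levels of drift from the sampling policy.
_, grads = dro_loss_and_grad([-2.0, -2.0], [-2.0, -12.0], [1.0, 1.0])
print([round(g, 6) for g in grads])  # [-1.0, -0.5]
```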

Custom Loss Functions

For objectives the built-in losses do not cover, define a custom loss function and pass it to training_client.forward_backward_custom:

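import torch
# Datum must be imported from the MinT SDK; data and training_client are assumed to exist.

def custom_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict]:
    # compute_your_loss is a placeholder for your own differentiable objective
    loss = compute_your_loss(logprobs)  # must return a scalar tensor
    # Return the loss plus any scalar metrics you want reported
    return loss, {"metric_name": loss.item()}

training_client.forward_backward_custom(data, custom_loss)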

Trade-off: compared with the built-in losses, custom losses use about 1.5x the FLOPs and up to 3x the wall time, but they support arbitrary differentiable objectives such as Bradley-Terry or DPO.
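As an illustration of the kind of objective forward_backward_custom enables, the sketch below computes the standard DPO loss for one preference pair in plain Python. The -log sigmoid form is the Bradley-Terry pairwise loss applied to DPO's implicit reward. It assumes you already have summed sequence log-probabilities for the chosen and rejected completions under both the policy and a frozen reference model. Inside a real custom_loss these would be torch tensors derived from logprobs. How Datum entries are paired into chosen/rejected is a hypothetical convention you would define yourself, and beta = 0.1 is only an illustrative default.

```python
import math

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one preference pair: -log sigmoid(beta * margin).

    All arguments are summed sequence log-probabilities (plain floats here).
    """
    # Implicit reward difference: how much more the policy (relative to the
    # reference) prefers the chosen completion over the rejected one.
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    z = beta * margin
    # -log(sigmoid(z)) = log(1 + exp(-z)), rearranged to avoid overflow
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))

# When the policy equals the reference, the margin is 0 and the loss is log(2).
print(round(dpo_loss(-10.0, -12.0, -10.0, -12.0), 6))  # 0.693147
```

Training lowers this loss by raising the policy's log-probability of the chosen completion relative to the rejected one, measured against the reference.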