# Loss Functions
MinT provides several built-in loss functions optimized for different training scenarios.
## Built-in Loss Functions
### Cross-Entropy (Supervised Learning)
```python
elementwise_loss = -target_logprobs * weights
loss = elementwise_loss.sum()
```

Inputs: `target_tokens`, `weights`
Outputs: `logprobs`
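As a numerical sketch of the weighted cross-entropy sum, using plain Python floats and made-up log-probabilities (in MinT these arrive as torch tensors produced by the model):

```python
# Toy per-token log-probabilities of the target tokens, plus per-token
# weights. Values are illustrative only, not real model output.
target_logprobs = [-0.5, -1.2, -0.1]
weights = [1.0, 1.0, 0.0]  # a weight of 0.0 masks that token out of the loss

elementwise_loss = [-lp * w for lp, w in zip(target_logprobs, weights)]
loss = sum(elementwise_loss)
print(loss)  # ≈ 1.7 (0.5 + 1.2 + 0.0)
```

Setting a weight to zero is how padding or prompt tokens are excluded from the objective while still passing through the model.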
### Importance Sampling (Policy Gradient)
```python
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
loss = -(prob_ratio * advantages).sum()
```

### PPO (Proximal Policy Optimization)
```python
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages
loss = -torch.min(unclipped_objective, clipped_objective).sum()
```

Configuration: `loss_fn_config={"clip_low_threshold": 0.9, "clip_high_threshold": 1.1}`
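To make the clipping behavior concrete, here is a dependency-free sketch of the PPO objective on two tokens with made-up log-probabilities and advantages (MinT computes the same quantities on torch tensors):

```python
import math

# Hypothetical per-token values for one sequence; illustrative only.
target_logprobs = [-0.9, -1.1]
sampling_logprobs = [-1.0, -1.0]
advantages = [2.0, -1.0]
clip_low, clip_high = 0.9, 1.1  # the default thresholds shown above

loss = 0.0
for t_lp, s_lp, adv in zip(target_logprobs, sampling_logprobs, advantages):
    prob_ratio = math.exp(t_lp - s_lp)
    clipped_ratio = min(max(prob_ratio, clip_low), clip_high)
    # Take the pessimistic (min) of the two objectives, negated for a loss.
    loss -= min(prob_ratio * adv, clipped_ratio * adv)
print(round(loss, 4))  # -1.2952
```

The first token's ratio (exp(0.1) ≈ 1.105) exceeds the high threshold, so its objective is clipped to 1.1 × 2.0; the second token's ratio stays inside the band and passes through unchanged.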
### CISPO (Clipped Importance Sampling Policy Optimization)
```python
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
loss = -(clipped_ratio.detach() * target_logprobs * advantages).sum()
```

### DRO (Direct Reward Optimization)
```python
quadratic_term = (target_logprobs - sampling_logprobs) ** 2
loss = -(target_logprobs * advantages - 0.5 * beta * quadratic_term).sum()
```

Configuration: `loss_fn_config={"beta": 0.05}`
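A single-token numerical sketch of the DRO objective, with made-up scalar values (the real loss is computed elementwise over tensors):

```python
# Illustrative scalars for one token; not real model output.
target_logprob = -0.8
sampling_logprob = -1.0
advantage = 1.5
beta = 0.05  # the default shown in loss_fn_config above

quadratic_term = (target_logprob - sampling_logprob) ** 2  # (0.2)^2 = 0.04
loss = -(target_logprob * advantage - 0.5 * beta * quadratic_term)
print(round(loss, 4))  # 1.201
```

The quadratic term penalizes the policy for drifting far from the sampling distribution, scaled by `beta`; with `beta=0.05` it contributes only 0.001 here, so the advantage-weighted log-probability dominates.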
## Custom Loss Functions
For advanced use cases, you can define custom loss functions:
```python
def custom_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict]:
    loss = compute_your_loss(logprobs)
    return loss, {"metric_name": loss.item()}

training_client.forward_backward_custom(data, custom_loss)
```

Trade-off: custom losses use 1.5x the FLOPs and up to 3x the wall time, but enable arbitrary differentiable objectives such as Bradley-Terry or DPO.
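To illustrate the kind of pairwise objective a custom loss enables, here is a plain-Python sketch of a Bradley-Terry style preference loss over a (chosen, rejected) pair of sequence log-probabilities. The function name and pairing scheme are hypothetical, not part of the MinT API, and `forward_backward_custom` passes torch tensors rather than floats:

```python
import math

def bradley_terry_loss(chosen_logprob: float, rejected_logprob: float) -> float:
    # -log sigmoid(chosen - rejected): small when the chosen response
    # is much more likely than the rejected one under the model.
    margin = chosen_logprob - rejected_logprob
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

loss = bradley_terry_loss(-10.0, -12.0)
print(round(loss, 4))  # margin of 2.0 gives -log σ(2.0) ≈ 0.1269
```

Because the custom-loss hook receives the full list of per-sequence log-probabilities, such pairwise objectives can be computed across sequences in the batch, which is exactly what per-token built-in losses cannot express.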