# Loss Functions
MinT provides several built-in loss functions for the `forward_backward()` call, covering supervised learning and policy-gradient training. Each loss computes gradients on the log-probabilities of your model's predicted tokens, weighted by importance, advantage, or custom reward signals.
## Concept
Loss functions in MinT fall into two categories:
### Supervised Learning (SFT & DPO)
loss_fn="cross_entropy"— Standard token prediction loss. Trains the model to match target tokens. Used in SFT and multi-turn datasets.
### Policy Gradient (RL)
loss_fn="importance_sampling"— Off-policy RL. Reweights trajectories byimportance_weight = target_logprob / behavior_logprob. Sensitive to distribution shift; best with small learning rates.loss_fn="ppo"— On-policy RL. Clips the probability ratio to[1 - clip_ratio, 1 + clip_ratio]to prevent large updates. Stable, widely used.loss_fn="cispo"— Conservative importance sampling. Soft clipping variant of IS; clamps logprob differences instead of probability ratios. Good middle ground.
Each RL loss accepts an `advantages` array (one advantage value per trajectory) and `logprobs`, the log-probabilities of the sampled tokens under the behavior policy that generated them. The gradients optimize the expected return under the new policy; a minimal sketch of how the three objectives differ follows below.
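To make the differences concrete, here is a small, self-contained sketch of the three objectives written as per-token surrogate losses. It is illustrative only: the function names are not part of MinT's API, and the exact formulas the library uses (in particular CISPO's soft-clipping form) may differ from this sketch.

```python
import math

def is_loss(new_lp, old_lp, adv):
    # Importance sampling: weight = pi_new / pi_old = exp(new_lp - old_lp)
    return -math.exp(new_lp - old_lp) * adv

def ppo_loss(new_lp, old_lp, adv, clip_ratio=0.2):
    # PPO: clip the probability ratio and take the pessimistic branch
    ratio = math.exp(new_lp - old_lp)
    clipped = max(1 - clip_ratio, min(ratio, 1 + clip_ratio))
    return -min(ratio * adv, clipped * adv)

def cispo_loss(new_lp, old_lp, adv, beta=1.0):
    # CISPO (sketch): clamp the logprob difference itself, so the weight
    # stays bounded by exp(+/- beta) even under large policy drift
    diff = max(-beta, min(new_lp - old_lp, beta))
    return -math.exp(diff) * adv

# A token whose probability the new policy has already pushed up (logprob -0.1 vs -0.8)
print(is_loss(-0.1, -0.8, adv=0.5))     # ratio ~2.01 -> loss ~-1.01
print(ppo_loss(-0.1, -0.8, adv=0.5))    # ratio clipped to 1.2 -> loss -0.60
print(cispo_loss(-0.1, -0.8, adv=0.5))  # diff 0.7 is within beta=1.0 -> same as IS here
```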
## Pattern
```python
import mint
from mint import types
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
base_model="Qwen/Qwen3-0.6B",
rank=16,
)
tokenizer = training_client.get_tokenizer()
# Example 1: SFT (supervised learning)
prompt_text = "The capital of France is"
target_text = " Paris."
prompt_ids = tokenizer.encode(prompt_text)
target_ids = tokenizer.encode(target_text)
all_ids = prompt_ids + target_ids
model_input_sft = types.ModelInput.from_ints(all_ids[:-1])
sft_target_tokens = all_ids[1:]
# Weights must align with sft_target_tokens: mask the prompt, train on the completion
sft_weights = [0] * (len(prompt_ids) - 1) + [1] * len(target_ids)
datum_sft = types.Datum(
model_input=model_input_sft,
loss_fn_inputs={
"target_tokens": sft_target_tokens,
"weights": sft_weights,
},
)
result = training_client.forward_backward([datum_sft], loss_fn="cross_entropy").result()
print(f"SFT loss: {result.loss:.4f}")
# Example 2: RL (policy gradient)
# Simulate tokens generated by the model and their log-probabilities
model_tokens = [7741, 34651, 31410]
model_logprobs = [-0.5, -0.3, -0.8]
advantages = [0.15] # Single trajectory, positive advantage
model_input_rl = types.ModelInput.from_ints(model_tokens)
datum_rl = types.Datum(
model_input=model_input_rl,
loss_fn_inputs={
"target_tokens": model_tokens,
"logprobs": model_logprobs,
"advantages": advantages,
},
)
result_rl = training_client.forward_backward([datum_rl], loss_fn="ppo").result()
print(f"PPO loss: {result_rl.loss:.4f}")
# Apply optimizer step
adam_params = types.AdamParams(learning_rate=1e-4)
training_client.optim_step(adam_params).result()
```

View full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/quickstart/custom_loss.py
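Extending the pattern above to a small batch: the sketch below builds one `Datum` per trajectory and zero-centers the rewards into advantages before calling `forward_backward()`. The trajectory tokens, log-probabilities, and rewards are made up for illustration; `types` and `training_client` are the objects created in the pattern above.

```python
# Three simulated trajectories: sampled tokens, their sampling-time logprobs, and a scalar reward
trajectories = [
    {"tokens": [7741, 34651, 31410], "logprobs": [-0.5, -0.3, -0.8], "reward": 1.0},
    {"tokens": [1135, 2203], "logprobs": [-1.2, -0.4], "reward": 0.0},
    {"tokens": [991, 1764, 552, 89], "logprobs": [-0.7, -0.9, -0.2, -1.1], "reward": 0.5},
]

# Zero-center rewards within the batch to get advantages (baseline = batch mean)
mean_reward = sum(t["reward"] for t in trajectories) / len(trajectories)
advantages = [t["reward"] - mean_reward for t in trajectories]

batch = []
for traj, adv in zip(trajectories, advantages):
    batch.append(
        types.Datum(
            model_input=types.ModelInput.from_ints(traj["tokens"]),
            loss_fn_inputs={
                "target_tokens": traj["tokens"],
                "logprobs": traj["logprobs"],
                "advantages": [adv],  # one advantage value per trajectory
            },
        )
    )

result = training_client.forward_backward(batch, loss_fn="ppo").result()
training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
print(f"Batch PPO loss: {result.loss:.4f}")
```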
## API Surface
| Loss | `loss_fn_inputs` | Use case |
|---|---|---|
| `"cross_entropy"` | `target_tokens`, `weights` | SFT, multi-turn supervised |
| `"importance_sampling"` | `target_tokens`, `logprobs`, `advantages` | Off-policy RL, offline RL |
| `"ppo"` | `target_tokens`, `logprobs`, `advantages` | On-policy RL, stable gradient updates |
| `"cispo"` | `target_tokens`, `logprobs`, `advantages` | Conservative RL, smooth clipping |
Optional parameters for RL losses:
- `clip_ratio` (PPO) — Clipping range `[1 - clip_ratio, 1 + clip_ratio]`. Default: `0.2`. Range: `0.0`–`0.5`. See the numeric sketch below.
- `beta` (CISPO) — Soft-clipping sharpness. Default: `1.0`. Higher = sharper clipping, closer to hard PPO.
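As a quick numeric check of what `clip_ratio` does at its default of `0.2`, the sketch below evaluates the standard clipped PPO surrogate at a few probability ratios. Once the ratio leaves `[0.8, 1.2]` in the direction a positive advantage is pushing, the objective stops growing, so the gradient with respect to that token vanishes. The helper is illustrative, not MinT's internal implementation.

```python
def clipped_surrogate(ratio, advantage, clip_ratio=0.2):
    # Pessimistic minimum of the unclipped and clipped terms (standard PPO form)
    clipped = max(1 - clip_ratio, min(ratio, 1 + clip_ratio))
    return min(ratio * advantage, clipped * advantage)

for ratio in (0.5, 0.9, 1.0, 1.2, 2.0):
    print(ratio, clipped_surrogate(ratio, advantage=1.0))
# Ratios above 1.2 all yield 1.2: no extra incentive to push the policy further
```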
## Caveats & Pitfalls
- Data format: `loss_fn_inputs` keys must exactly match the loss function's requirements. Passing `target_logprobs` without `logprobs` will raise a `KeyError`.
- Advantage range: RL advantages should be zero-centered within each batch. Use a baseline or value function to estimate advantages; do not use raw rewards directly.
- Importance sampling drift: `loss_fn="importance_sampling"` is sensitive to distribution shift. If the policy diverges too far from the behavior policy, the importance weights explode (see the sketch after this list). Mitigate with `loss_fn="ppo"` (clipping) or small learning rates.
- Weight masking in SFT: Set `weights=0` on prompt tokens and `weights=1` on completion tokens. The renderer's `build_supervised_example()` does this automatically.
- Off-policy vs. on-policy: `importance_sampling` works with data from any policy (truly off-policy). `ppo` and `cispo` assume the model generated the trajectories (on-policy or near on-policy). Choose based on your data source.
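To see why importance weights explode under distribution shift, the short sketch below converts a few per-token logprob gaps into weights: since the weight is `exp(target_logprob - behavior_logprob)`, a gap of a couple of nats already scales the gradient by an order of magnitude.

```python
import math

# Importance weight = exp(target_logprob - behavior_logprob)
for gap in (0.1, 0.5, 1.0, 2.0, 5.0):
    print(f"logprob gap {gap:>4}: weight = {math.exp(gap):8.1f}")
# gap 2.0 -> ~7.4x, gap 5.0 -> ~148x; PPO-style clipping caps this at 1 + clip_ratio
```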