Mind Lab Toolkit (MinT)

Loss Functions

MinT provides several built-in loss functions for the forward_backward() call, covering supervised learning and policy gradient training. Each loss is computed from the log-probabilities your model assigns to the target tokens, weighted by per-token weights, importance ratios, or advantage signals.

Concept

Loss functions in MinT fall into two categories:

Supervised Learning (SFT & DPO)

  • loss_fn="cross_entropy" — Standard token prediction loss. Trains the model to match target tokens. Used in SFT and multi-turn datasets.

Policy Gradient (RL)

  • loss_fn="importance_sampling" — Off-policy RL. Reweights trajectories by importance_weight = target_logprob / behavior_logprob. Sensitive to distribution shift; best with small learning rates.
  • loss_fn="ppo" — On-policy RL. Clips the probability ratio to [1 - clip_ratio, 1 + clip_ratio] to prevent large updates. Stable, widely used.
  • loss_fn="cispo" — Conservative importance sampling. Soft clipping variant of IS; clamps logprob differences instead of probability ratios. Good middle ground.

Each RL loss accepts an advantages array (one advantage value per trajectory) and logprobs (the log-probabilities of the sampled tokens under the behavior policy that generated them). The gradients optimize the expected return under the new policy.

Pattern

import mint
from mint import types

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
)
tokenizer = training_client.get_tokenizer()

# Example 1: SFT (supervised learning)
prompt_text = "The capital of France is"
target_text = " Paris."
prompt_ids = tokenizer.encode(prompt_text)
target_ids = tokenizer.encode(target_text)
all_ids = prompt_ids + target_ids

# Shift by one: position i of the input predicts token i + 1.
model_input_sft = types.ModelInput.from_ints(all_ids[:-1])
sft_target_tokens = all_ids[1:]
# One weight per target token: the first len(prompt_ids) - 1 targets are
# still prompt tokens, so only the completion tokens are supervised.
sft_weights = [0] * (len(prompt_ids) - 1) + [1] * len(target_ids)

datum_sft = types.Datum(
    model_input=model_input_sft,
    loss_fn_inputs={
        "target_tokens": sft_target_tokens,
        "weights": sft_weights,
    },
)

result = training_client.forward_backward([datum_sft], loss_fn="cross_entropy").result()
print(f"SFT loss: {result.loss:.4f}")

# Example 2: RL (policy gradient)
# Simulate tokens generated by the model and their log-probabilities
model_tokens = [7741, 34651, 31410]
model_logprobs = [-0.5, -0.3, -0.8]
advantages = [0.15]  # Single trajectory, positive advantage

model_input_rl = types.ModelInput.from_ints(model_tokens)

datum_rl = types.Datum(
    model_input=model_input_rl,
    loss_fn_inputs={
        "target_tokens": model_tokens,
        "logprobs": model_logprobs,
        "advantages": advantages,
    },
)

result_rl = training_client.forward_backward([datum_rl], loss_fn="ppo").result()
print(f"PPO loss: {result_rl.loss:.4f}")

# Apply optimizer step
adam_params = types.AdamParams(learning_rate=1e-4)
training_client.optim_step(adam_params).result()

View full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/quickstart/custom_loss.py

API Surface

Loss                     loss_fn_inputs                         Use case
"cross_entropy"          target_tokens, weights                 SFT, multi-turn supervised
"importance_sampling"    target_tokens, logprobs, advantages    Off-policy RL, offline RL
"ppo"                    target_tokens, logprobs, advantages    On-policy RL, stable gradient updates
"cispo"                  target_tokens, logprobs, advantages    Conservative RL, smooth clipping

Optional parameters for RL losses:

  • clip_ratio (PPO) — Clipping range [1 - clip_ratio, 1 + clip_ratio]. Default: 0.2. Range: 0.0–0.5.
  • beta (CISPO) — Soft clipping sharpness. Default: 1.0. Higher values give sharper clipping, closer to PPO's hard clip.
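
As a quick numeric illustration of clip_ratio in isolation (plain PyTorch, not a MinT call):

import torch

# A logprob difference of 0.5 gives a ratio of e^0.5 ≈ 1.65; with the
# default clip_ratio=0.2 the effective ratio is capped at 1.2.
ratio = torch.exp(torch.tensor([0.5]))
clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)
print(ratio.item(), clipped.item())  # ≈ 1.6487, 1.2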

Caveats & Pitfalls

  • Data format: loss_fn_inputs keys must exactly match the loss function's requirements. Passing target_logprobs instead of logprobs, for example, will raise a KeyError.
  • Advantage range: RL advantages should be zero-centered within each batch. Use a baseline or value function to estimate advantages; do not use raw rewards directly (see the sketch after this list).
  • Importance sampling drift: loss_fn="importance_sampling" is sensitive to distribution shift. If the policy diverges too far from the behavior policy, the importance weights explode. Mitigate with loss_fn="ppo" (clipping) or small learning rates.
  • Weight masking in SFT: For SFT, set weights=0 on prompt tokens and weights=1 on completion tokens. The renderer's build_supervised_example() does this automatically.
  • Off-policy vs on-policy: importance_sampling works with data from any policy (truly off-policy). ppo and cispo assume the model generated the trajectories (on-policy or near on-policy). Choose based on your data source.
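
A minimal sketch of turning raw rewards into zero-centered advantages, as referenced in the Advantage range item above (plain PyTorch; the mean baseline and optional std normalization are common choices, not MinT requirements):

import torch

# Zero-center raw rewards within the batch; dividing by the standard
# deviation is a common optional normalization for scale stability.
rewards = torch.tensor([1.0, 0.2, -0.4, 0.6])
advantages = rewards - rewards.mean()
advantages = advantages / (advantages.std() + 1e-8)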
