Mini RL Trip (Minimal RL Loop)
This page walks through a minimal end-to-end RL training loop using MinT:
- Task/env: multiplication (Question → Answer, range 10-199)
- Reward: `1.0` if correct, else `0.0`
- Rollouts: sample multiple answers per question (`group_size`)
- Update: build RL `Datum`s and call `forward_backward(..., loss_fn="importance_sampling")` + `optim_step()`
Prerequisites
- Python >= 3.11
- A MinT API key in `MINT_API_KEY`
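As a quick sanity check before running anything, you can confirm the key is visible from Python. This is plain `os`/`python-dotenv` usage, not a MinT call, and the only assumption is the `MINT_API_KEY` variable name listed above:

```python
import os

from dotenv import load_dotenv

load_dotenv()  # picks up a local .env file, if you keep the key there
assert os.environ.get("MINT_API_KEY"), "MINT_API_KEY is not set"
print("MINT_API_KEY found")
```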
Install (choose one way to install MinT):
pip install -q python-dotenv
# Recommended (GitHub):
pip install -q git+https://github.com/MindLab-Research/mindlab-toolkit.git
# Or if you already have it on pip:
pip install -q mindlab-toolkit

Parameters (env vars)
| Env Var | Default | API Type | Description |
|---|---|---|---|
| `MINT_BASE_MODEL` | Qwen/Qwen3-0.6B | - | Base model name |
| `MINT_LORA_RANK` | 8 | LoraConfig | LoRA rank (max: 64) |
| `MINT_RL_STEPS` | 3 | - | Number of outer training steps |
| `MINT_RL_BATCH` | 2 | - | Questions per step |
| `MINT_RL_GROUP` | 4 | - | Sampled answers per question |
| `MINT_RL_MAX_TOKENS` | 64 | SamplingParams | Max new tokens per answer |
| `MINT_RL_TEMPERATURE` | 0.8 | SamplingParams | Sampling randomness |
| `MINT_RL_LR` | 5e-5 | AdamParams | Learning rate |
| `MINT_SYSTEM_PROMPT_PATH` | - | - | Path to system prompt file |
| `MINT_PRINT_TOKEN_COUNTS` | false | - | Print token counts |
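The script reads these values once at startup, so set them in your shell or `.env` file before launching it. If you prefer, you can also override them in-process before the config block runs; the specific values below are purely illustrative:

```python
import os

# Illustrative overrides; any variable from the table above works the same way.
os.environ["MINT_RL_GROUP"] = "8"          # more sampled answers per question
os.environ["MINT_RL_TEMPERATURE"] = "1.0"  # more diverse answers
os.environ["MINT_RL_STEPS"] = "5"
```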
Full script
from __future__ import annotations
import os
import random
import re
import time
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
import mint
from mint import types
# Config via env vars
BASE_MODEL = os.environ.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B")
LORA_RANK = int(os.environ.get("MINT_LORA_RANK", "8"))
RL_STEPS = int(os.environ.get("MINT_RL_STEPS", "3"))
BATCH_SIZE = int(os.environ.get("MINT_RL_BATCH", "2"))
GROUP_SIZE = int(os.environ.get("MINT_RL_GROUP", "4"))
MAX_TOKENS = int(os.environ.get("MINT_RL_MAX_TOKENS", "64"))
TEMPERATURE = float(os.environ.get("MINT_RL_TEMPERATURE", "0.8"))
LEARNING_RATE = float(os.environ.get("MINT_RL_LR", "5e-5"))
SYSTEM_PROMPT_PATH = os.environ.get("MINT_SYSTEM_PROMPT_PATH", "")
PRINT_TOKEN_COUNTS = os.environ.get("MINT_PRINT_TOKEN_COUNTS", "").lower() in ("1", "true")
# Load system prompt if provided
SYSTEM_PROMPT = ""
if SYSTEM_PROMPT_PATH and Path(SYSTEM_PROMPT_PATH).exists():
    SYSTEM_PROMPT = Path(SYSTEM_PROMPT_PATH).read_text().strip()

def extract_answer(response: str) -> str | None:
    """Pull the first integer out of the model's response, if any."""
    match = re.search(r"-?\d+", response)
    return match.group(0) if match else None

def generate_problem() -> tuple[str, str]:
    """Return a (question, correct answer) pair for a random multiplication."""
    a, b = random.randint(10, 199), random.randint(10, 199)
    return f"What is {a} * {b}?", str(a * b)

def compute_reward(response: str, correct: str) -> float:
    return 1.0 if extract_answer(response) == correct else 0.0

def build_prompt(question: str) -> str:
    if SYSTEM_PROMPT:
        return f"{SYSTEM_PROMPT}\n\nQuestion: {question}\nAnswer:"
    return f"Question: {question}\nAnswer:"

def validate_loss_inputs(inputs: dict) -> bool:
    """Check that all loss inputs are present, non-empty, and the same length."""
    required = ["target_tokens", "weights", "logprobs", "advantages"]
    if not all(k in inputs and len(inputs[k]) > 0 for k in required):
        return False
    lengths = [len(inputs[k]) for k in required]
    return len(set(lengths)) == 1

def save_weights_with_retry(training_client, service_client, base_model: str, name: str, retries: int = 3):
    """Publish current weights for sampling, retrying with exponential backoff."""
    for attempt in range(retries):
        try:
            path = training_client.save_weights_for_sampler(name=name).result().path
            return service_client.create_sampling_client(model_path=path, base_model=base_model)
        except Exception as e:
            if attempt < retries - 1:
                time.sleep(min(1.0 * (2 ** attempt), 10.0))
            else:
                raise RuntimeError(f"Failed after {retries} attempts: {e}") from e
def main() -> None:
    service_client = mint.ServiceClient()
    training_client = service_client.create_lora_training_client(
        base_model=BASE_MODEL, rank=LORA_RANK, train_mlp=True, train_attn=True, train_unembed=True
    )
    tokenizer = training_client.get_tokenizer()
    print(f"Config: model={BASE_MODEL}, steps={RL_STEPS}, batch={BATCH_SIZE}, group={GROUP_SIZE}, lr={LEARNING_RATE}")

    for step in range(RL_STEPS):
        # Publish the latest weights so sampling reflects the most recent optimizer step.
        sampling_client = save_weights_with_retry(training_client, service_client, BASE_MODEL, f"rl-step-{step}")
        training_datums: list[types.Datum] = []
        all_rewards: list[float] = []

        for _ in range(BATCH_SIZE):
            question, answer = generate_problem()
            prompt = build_prompt(question)
            prompt_tokens = tokenizer.encode(prompt)
            if PRINT_TOKEN_COUNTS:
                print(f"[step {step+1}] prompt_tokens={len(prompt_tokens)}")

            # Sample a group of answers for the same question.
            result = sampling_client.sample(
                prompt=types.ModelInput.from_ints(tokens=prompt_tokens),
                num_samples=GROUP_SIZE,
                sampling_params=types.SamplingParams(
                    max_tokens=MAX_TOKENS, temperature=TEMPERATURE, stop_token_ids=[tokenizer.eos_token_id]
                ),
            ).result()

            group_rewards, group_responses, group_logprobs = [], [], []
            for seq in result.sequences:
                response_text = tokenizer.decode(seq.tokens)
                reward = compute_reward(response_text, answer)
                group_rewards.append(reward)
                group_responses.append(list(seq.tokens))
                group_logprobs.append(list(seq.logprobs or [0.0] * len(seq.tokens)))
            all_rewards.extend(group_rewards)

            # Group-relative advantages: each reward minus the group's mean reward.
            mean_reward = sum(group_rewards) / len(group_rewards)
            advantages = [r - mean_reward for r in group_rewards]
            if all(adv == 0.0 for adv in advantages):
                continue  # every answer got the same reward -> no learning signal from this group

            for response_tokens, logprobs, adv in zip(group_responses, group_logprobs, advantages):
                if not response_tokens:
                    continue
                # Next-token prediction: shift by one and mask out the prompt prefix.
                full_tokens = prompt_tokens + response_tokens
                input_tokens, target_tokens = full_tokens[:-1], full_tokens[1:]
                prefix_len = len(prompt_tokens) - 1
                loss_inputs = {
                    "target_tokens": [int(t) for t in target_tokens],
                    "weights": [0.0] * prefix_len + [1.0] * len(response_tokens),
                    "logprobs": [0.0] * prefix_len + [float(lp) for lp in logprobs],
                    "advantages": [0.0] * prefix_len + [adv] * len(response_tokens),
                }
                if validate_loss_inputs(loss_inputs):
                    training_datums.append(
                        types.Datum(model_input=types.ModelInput.from_ints(tokens=input_tokens), loss_fn_inputs=loss_inputs)
                    )

        if training_datums:
            training_client.forward_backward(training_datums, loss_fn="importance_sampling").result()
            training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()

        avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0.0
        print(f"Step {step+1}/{RL_STEPS}: accuracy={accuracy:.1%}, avg_reward={avg_reward:.3f}, datums={len(training_datums)}")

        # Save checkpoint after each step (recommended for long training runs)
        if training_datums:
            step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
            print(f" Checkpoint: {step_ckpt.path}")

    checkpoint = training_client.save_state(name="mini-rl-final").result()
    print(f"Saved checkpoint: {checkpoint.path}")
if __name__ == "__main__":
    main()
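One detail in the loop worth spelling out: each `Datum` is a shifted next-token-prediction pair, with the prompt prefix masked so that only the sampled answer tokens carry weight and advantage. A tiny worked example with made-up token ids (not real tokenizer output):

```python
prompt_tokens = [11, 12, 13]    # "Question: ... Answer:" (hypothetical ids)
response_tokens = [201, 202]    # sampled answer (2 tokens)
adv = 0.75                      # this sample's group-relative advantage

full_tokens = prompt_tokens + response_tokens
input_tokens, target_tokens = full_tokens[:-1], full_tokens[1:]  # shift by one
prefix_len = len(prompt_tokens) - 1                              # prompt positions among the targets

weights = [0.0] * prefix_len + [1.0] * len(response_tokens)      # [0.0, 0.0, 1.0, 1.0]
advantages = [0.0] * prefix_len + [adv] * len(response_tokens)   # [0.0, 0.0, 0.75, 0.75]
assert len(target_tokens) == len(weights) == len(advantages) == 4
```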
Checkpoints and resuming
For long training runs (especially with large models like 235B), save checkpoints after each step to recover from failures:
# Save checkpoint after each step
step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
print(f"Checkpoint: {step_ckpt.path}")
# To resume from a checkpoint:
training_client = service_client.create_training_client_from_state(checkpoint_path)
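Below is a minimal sketch of how resuming could be wired into the script above, assuming `checkpoint_path` is a path you recorded yourself after an earlier `save_state()` call (the script does not persist it for you):

```python
# Hypothetical resume logic: reuse a saved state if you have one, otherwise start fresh.
if checkpoint_path:
    training_client = service_client.create_training_client_from_state(checkpoint_path)
else:
    training_client = service_client.create_lora_training_client(
        base_model=BASE_MODEL, rank=LORA_RANK, train_mlp=True, train_attn=True, train_unembed=True
    )
tokenizer = training_client.get_tokenizer()
```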
Common gotchas

`datums=0`
If a step reports `datums=0`, all answers in each group got the same reward, so the mean-centered advantages are all zero and the update is skipped (see the sketch after this list). To get a learning signal, increase diversity:
- Increase `MINT_RL_GROUP`
- Increase `MINT_RL_TEMPERATURE`
- Use a less-binary reward (e.g., partial credit)
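To see why identical rewards produce no datums, here is the advantage computation from the script in isolation:

```python
group_rewards = [0.0, 0.0, 0.0, 0.0]  # every sample wrong (all-correct groups behave the same way)
mean_reward = sum(group_rewards) / len(group_rewards)
advantages = [r - mean_reward for r in group_rewards]
print(advantages)  # [0.0, 0.0, 0.0, 0.0] -> the group is skipped entirely
```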
`save_weights_for_sampler` fails
The script includes retry logic with exponential backoff. If it still fails after 3 attempts, check:
- Server connectivity
- Disk space on the MinT server
Next: From toy loop to real tasks
When you’re ready to move beyond the minimal script, see: RL Best Practices