Demo / Mini RL Trip

Mini RL Trip (Minimal RL Loop)

This page is a minimal end-to-end RL training loop using MinT:

  • Task/env: two-number multiplication (Question → Answer; both operands drawn from 10–199)
  • Reward: 1.0 if correct, else 0.0
  • Rollouts: sample multiple answers per question (group_size)
  • Update: build RL Datums and call forward_backward(..., loss_fn="importance_sampling") + optim_step()

Prerequisites

  • Python >= 3.11
  • A MinT API key in MINT_API_KEY

Install (choose one way to install MinT):

pip install -q python-dotenv
# Recommended (GitHub):
pip install -q git+https://github.com/MindLab-Research/mindlab-toolkit.git
# Or if you already have it on pip:
pip install -q mindlab-toolkit

Parameters (env vars)

| Env Var | Default | API Type | Description |
|---|---|---|---|
| MINT_BASE_MODEL | Qwen/Qwen3-0.6B | - | Base model name |
| MINT_LORA_RANK | 8 | LoraConfig | LoRA rank (max: 64) |
| MINT_RL_STEPS | 3 | - | Number of outer training steps |
| MINT_RL_BATCH | 2 | - | Questions per step |
| MINT_RL_GROUP | 4 | - | Sampled answers per question |
| MINT_RL_MAX_TOKENS | 64 | SamplingParams | Max new tokens per answer |
| MINT_RL_TEMPERATURE | 0.8 | SamplingParams | Sampling randomness |
| MINT_RL_LR | 5e-5 | AdamParams | Learning rate |
| MINT_SYSTEM_PROMPT_PATH | - | - | Path to system prompt file |
| MINT_PRINT_TOKEN_COUNTS | false | - | Print prompt token counts (set to 1 or true) |

Full script

from __future__ import annotations
 
import os
import random
import re
import time
from pathlib import Path
 
from dotenv import load_dotenv
 
load_dotenv()
 
import mint
from mint import types
 
# Config via env vars. Numeric values are parsed eagerly, so a malformed value
# fails at import time rather than mid-run.
BASE_MODEL = os.environ.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B")
LORA_RANK = int(os.environ.get("MINT_LORA_RANK", "8"))
RL_STEPS = int(os.environ.get("MINT_RL_STEPS", "3"))
BATCH_SIZE = int(os.environ.get("MINT_RL_BATCH", "2"))
GROUP_SIZE = int(os.environ.get("MINT_RL_GROUP", "4"))
MAX_TOKENS = int(os.environ.get("MINT_RL_MAX_TOKENS", "64"))
TEMPERATURE = float(os.environ.get("MINT_RL_TEMPERATURE", "0.8"))
LEARNING_RATE = float(os.environ.get("MINT_RL_LR", "5e-5"))
SYSTEM_PROMPT_PATH = os.environ.get("MINT_SYSTEM_PROMPT_PATH", "")
# Only "1" or "true" (case-insensitive) enable token-count printing.
PRINT_TOKEN_COUNTS = os.environ.get("MINT_PRINT_TOKEN_COUNTS", "").lower() in ("1", "true")

# Load the system prompt if a path was provided. A path that is set but does not
# point at a regular file is reported instead of being silently ignored (a typo
# would otherwise quietly change the prompt), and a directory path no longer
# crashes in read_text().
SYSTEM_PROMPT = ""
if SYSTEM_PROMPT_PATH:
    _system_prompt_file = Path(SYSTEM_PROMPT_PATH)
    if _system_prompt_file.is_file():
        SYSTEM_PROMPT = _system_prompt_file.read_text(encoding="utf-8").strip()
    else:
        print(f"Warning: MINT_SYSTEM_PROMPT_PATH={SYSTEM_PROMPT_PATH!r} is not a file; using no system prompt")
 
 
def extract_answer(response: str) -> str | None:
    match = re.search(r"-?\d+", response)
    return match.group(0) if match else None
 
 
def generate_problem() -> tuple[str, str]:
    """Draw one random multiplication question together with its exact answer.

    Both operands come from ``random.randint(10, 199)`` (inclusive bounds),
    drawn left operand first. The answer is returned as a decimal string so it
    can be compared directly with text extracted from a model response.
    """
    left = random.randint(10, 199)
    right = random.randint(10, 199)
    product = left * right
    question = f"What is {left} * {right}?"
    return question, str(product)
 
 
def compute_reward(response: str, correct: str) -> float:
    """Binary reward: 1.0 if the answer parsed from ``response`` equals ``correct``, else 0.0.

    Parsing is delegated to ``extract_answer``. A response with no parsable
    answer (None) never equals the string ``correct`` and therefore scores 0.0.
    """
    predicted = extract_answer(response)
    return float(predicted == correct)
 
 
def build_prompt(question: str) -> str:
    """Format ``question`` as a completion-style prompt ending in "Answer:".

    When a non-empty system prompt was loaded, it is prepended and separated
    from the question by a blank line.
    """
    prompt = f"Question: {question}\nAnswer:"
    if not SYSTEM_PROMPT:
        return prompt
    return f"{SYSTEM_PROMPT}\n\n{prompt}"
 
 
def validate_loss_inputs(inputs: dict) -> bool:
    """Check that a loss-input dict is well-formed for the importance-sampling loss.

    Returns True only when each of "target_tokens", "weights", "logprobs" and
    "advantages" is present and non-empty, and all four have the same length.
    Extra keys are ignored.
    """
    seen_lengths: set[int] = set()
    for key in ("target_tokens", "weights", "logprobs", "advantages"):
        if key not in inputs:
            return False
        size = len(inputs[key])
        if size <= 0:
            return False
        seen_lengths.add(size)
    return len(seen_lengths) == 1
 
 
def save_weights_with_retry(
    training_client,
    service_client,
    base_model: str,
    name: str,
    retries: int = 3,
    *,
    base_delay: float = 1.0,
    max_delay: float = 10.0,
):
    """Export the trainer's current weights for sampling and open a sampling client on them.

    Each attempt saves the weights under ``name`` via
    ``training_client.save_weights_for_sampler`` and then calls
    ``service_client.create_sampling_client`` on the returned path. A failed
    attempt is retried after ``min(base_delay * 2**attempt, max_delay)``
    seconds of exponential backoff.

    Args:
        training_client: Client whose weights are exported.
        service_client: Client used to create the sampling client.
        base_model: Base model name passed to ``create_sampling_client``.
        name: Name under which the sampler weights are saved.
        retries: Total number of attempts; must be at least 1.
        base_delay: Initial backoff delay in seconds.
        max_delay: Upper bound on any single backoff delay in seconds.

    Returns:
        The sampling client returned by ``service_client.create_sampling_client``.

    Raises:
        ValueError: If ``retries`` is less than 1. Without this check the loop
            would never run and the function would return None.
        RuntimeError: If every attempt fails; chained to the last error.
    """
    if retries < 1:
        raise ValueError(f"retries must be >= 1, got {retries}")
    for attempt in range(retries):
        try:
            path = training_client.save_weights_for_sampler(name=name).result().path
            return service_client.create_sampling_client(model_path=path, base_model=base_model)
        except Exception as e:
            # Broad catch is deliberate: any transport/server error is retried.
            if attempt == retries - 1:
                raise RuntimeError(f"Failed after {retries} attempts: {e}") from e
            time.sleep(min(base_delay * (2 ** attempt), max_delay))
 
 
def _build_datum(
    prompt_tokens: list[int], response_tokens: list[int], logprobs: list[float], advantage: float
) -> types.Datum | None:
    """Turn one sampled response into an importance-sampling training Datum.

    Returns None when the assembled loss inputs fail ``validate_loss_inputs``
    (for example when the sampler's logprobs do not line up one-to-one with the
    response tokens), so the caller can skip and report the sample.
    """
    full_tokens = prompt_tokens + response_tokens
    # Next-token setup: the model reads tokens[:-1] and is scored on tokens[1:].
    input_tokens, target_tokens = full_tokens[:-1], full_tokens[1:]
    # The first len(prompt)-1 targets are prompt tokens; weight 0 keeps them
    # out of the loss so only the sampled response is trained on.
    prefix_len = len(prompt_tokens) - 1

    loss_inputs = {
        "target_tokens": [int(t) for t in target_tokens],
        "weights": [0.0] * prefix_len + [1.0] * len(response_tokens),
        "logprobs": [0.0] * prefix_len + [float(lp) for lp in logprobs],
        "advantages": [0.0] * prefix_len + [advantage] * len(response_tokens),
    }
    if not validate_loss_inputs(loss_inputs):
        return None
    return types.Datum(model_input=types.ModelInput.from_ints(tokens=input_tokens), loss_fn_inputs=loss_inputs)


def _rollout_question(sampling_client, tokenizer, step: int) -> tuple[list[float], list[types.Datum]]:
    """Sample GROUP_SIZE answers to one fresh problem, score them, and build datums.

    Returns ``(rewards, datums)``. ``rewards`` covers every sampled answer and
    is used for logging. ``datums`` is empty when the group carries no learning
    signal (all rewards equal) or no usable samples.
    """
    question, answer = generate_problem()
    prompt_tokens = tokenizer.encode(build_prompt(question))

    if PRINT_TOKEN_COUNTS:
        print(f"[step {step+1}] prompt_tokens={len(prompt_tokens)}")

    result = sampling_client.sample(
        prompt=types.ModelInput.from_ints(tokens=prompt_tokens),
        num_samples=GROUP_SIZE,
        sampling_params=types.SamplingParams(
            max_tokens=MAX_TOKENS, temperature=TEMPERATURE, stop_token_ids=[tokenizer.eos_token_id]
        ),
    ).result()

    rewards: list[float] = []
    responses: list[list[int]] = []
    sampled_logprobs: list[list[float]] = []
    missing_logprobs = 0
    for seq in result.sequences:
        tokens = list(seq.tokens)
        rewards.append(compute_reward(tokenizer.decode(seq.tokens), answer))
        responses.append(tokens)
        if seq.logprobs:
            sampled_logprobs.append(list(seq.logprobs))
        else:
            # NOTE(review): these are presumably the sampling-time logprobs the
            # importance_sampling loss divides by; a 0.0 placeholder keeps the
            # sample but skews that ratio, so make it visible.
            if tokens:
                missing_logprobs += 1
            sampled_logprobs.append([0.0] * len(tokens))
    if missing_logprobs:
        print(f"  Warning: {missing_logprobs} sample(s) returned no logprobs; using 0.0 placeholders")

    # An empty group would otherwise divide by zero below.
    if not rewards:
        return rewards, []

    # Group-relative advantage: each answer's reward minus the group mean.
    mean_reward = sum(rewards) / len(rewards)
    advantages = [r - mean_reward for r in rewards]
    # If every answer scored the same, this group has nothing to learn from.
    if all(adv == 0.0 for adv in advantages):
        return rewards, []

    datums: list[types.Datum] = []
    dropped = 0
    for response_tokens, logprobs, adv in zip(responses, sampled_logprobs, advantages):
        if not response_tokens:
            continue
        datum = _build_datum(prompt_tokens, response_tokens, logprobs, adv)
        if datum is None:
            dropped += 1
        else:
            datums.append(datum)
    if dropped:
        print(f"  Warning: dropped {dropped} sample(s) with inconsistent loss inputs")
    return rewards, datums


def main() -> None:
    """Run RL_STEPS rounds of sample -> score -> importance-sampling update, then save final state.

    Each step snapshots the current trained weights for sampling, rolls out
    BATCH_SIZE questions with GROUP_SIZE answers each, applies one
    forward_backward + optim_step if any datums were produced, logs
    accuracy/reward, and saves a per-step checkpoint when an update happened.
    """
    service_client = mint.ServiceClient()
    training_client = service_client.create_lora_training_client(
        base_model=BASE_MODEL, rank=LORA_RANK, train_mlp=True, train_attn=True, train_unembed=True
    )
    tokenizer = training_client.get_tokenizer()

    print(f"Config: model={BASE_MODEL}, steps={RL_STEPS}, batch={BATCH_SIZE}, group={GROUP_SIZE}, lr={LEARNING_RATE}")

    for step in range(RL_STEPS):
        # Sample from the weights as they are before this step's update.
        sampling_client = save_weights_with_retry(training_client, service_client, BASE_MODEL, f"rl-step-{step}")
        training_datums: list[types.Datum] = []
        all_rewards: list[float] = []

        for _ in range(BATCH_SIZE):
            rewards, datums = _rollout_question(sampling_client, tokenizer, step)
            all_rewards.extend(rewards)
            training_datums.extend(datums)

        if training_datums:
            training_client.forward_backward(training_datums, loss_fn="importance_sampling").result()
            training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()

        avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0.0
        print(f"Step {step+1}/{RL_STEPS}: accuracy={accuracy:.1%}, avg_reward={avg_reward:.3f}, datums={len(training_datums)}")

        # Save checkpoint after each step that actually updated the weights
        # (recommended for long training runs).
        if training_datums:
            step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
            print(f"  Checkpoint: {step_ckpt.path}")

    checkpoint = training_client.save_state(name="mini-rl-final").result()
    print(f"Saved checkpoint: {checkpoint.path}")
 
 
if __name__ == "__main__":
    # Run the training loop only when executed as a script, not when imported.
    main()

Checkpoints and resuming

For long training runs (especially with large models, e.g. 235B parameters), save a checkpoint after each step so you can resume after a failure:

# Save checkpoint after each step
step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
print(f"Checkpoint: {step_ckpt.path}")
 
# To resume from a checkpoint:
# checkpoint_path is a path previously returned by save_state(...).result().path
# (e.g. the value printed as "Checkpoint: ..." above).
training_client = service_client.create_training_client_from_state(checkpoint_path)

Common gotchas

datums=0

If a step reports datums=0, every group in that step produced identical rewards (all answers correct or all wrong), so every advantage was zero and the update was skipped. To get a learning signal, increase diversity:

  • Increase MINT_RL_GROUP
  • Increase MINT_RL_TEMPERATURE
  • Use a less-binary reward (e.g., partial credit)

save_weights_for_sampler fails

The script includes retry logic with exponential backoff. If it still fails after 3 attempts, check:

  • Server connectivity
  • Disk space on the MinT server

Next: From toy loop to real tasks

When you’re ready to move beyond the minimal script, see: RL Best Practices