Math RL

This page documents demos/rl/adapters/verifiable_math.py in mint-quickstart.

What this demo does

  • Runs RL-1 Verifiable Math via the shared demos/rl/rl_core.py GRPO loop and a task-specific adapter.
  • Trains a LoRA on 2-digit addition; reward is deterministic exact-match grading (1.0 correct, 0.0 wrong).
  • Uses group sampling + group-relative advantages and updates via importance_sampling.
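The group-relative advantage idea can be sketched in a few lines (a minimal illustration of the concept, not the rl_core implementation):

```python
def group_advantages(rewards: list[float]) -> list[float]:
    """Each sample's advantage is its reward minus the group's mean reward.

    If every sample in a group receives the same reward, all advantages
    are zero and the group contributes no training datums.
    """
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

print(group_advantages([1.0, 0.0, 0.0, 1.0]))  # [0.5, -0.5, -0.5, 0.5]
print(group_advantages([0.0, 0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0, 0.0]
```

This is why mixed-reward groups are what actually drive learning: only samples that beat (or trail) their own group's average carry a nonzero signal.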

Expected output

  • Prints Model: ..., then a per-step line Step N: accuracy=..., datums=..., and finishes with Saved: ....

Common gotchas

  • Seeing datums=0 for many steps means every sample in a group received the same reward, so all advantages are zero; try a larger MINT_RL_GROUP or MINT_RL_BATCH, a higher MINT_RL_TEMPERATURE, or more steps.
  • If accuracy=0.0% persists, increase MINT_RL_MAX_TOKENS (default is intentionally small for speed).
  • This adapter uses regex integer extraction; for stricter grading, replace compute_reward() with a task-specific verifier.
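As an illustration of a stricter verifier, the hypothetical function below (not part of the demo) only awards credit when the last non-empty line of the response is exactly the integer answer:

```python
import re

def strict_reward(response: str, answer: int) -> float:
    """Grade 1.0 only if the last non-empty line is exactly the answer."""
    lines = [ln.strip() for ln in response.splitlines() if ln.strip()]
    if not lines:
        return 0.0
    last = lines[-1]
    return 1.0 if re.fullmatch(r"-?\d+", last) and int(last) == answer else 0.0

# Lenient regex extraction would accept both; strict grading rejects stray text.
print(strict_reward("9", 9))             # 1.0
print(strict_reward("The sum is 9", 9))  # 0.0
```

Stricter graders trade recall for precision: they reject verbose-but-correct answers, which can be desirable when you want the model to learn a clean output format.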

Prerequisites

  • Python >= 3.11
  • MINT_API_KEY set (or configured via .env)

How to run

export MINT_API_KEY=sk-mint-...
python demos/rl/adapters/verifiable_math.py

Parameters (env vars)

  • MINT_BASE_MODEL: default Qwen/Qwen3-0.6B
  • MINT_LORA_RANK: default 16
  • MINT_RL_STEPS: default 5
  • MINT_RL_BATCH: default 10
  • MINT_RL_GROUP: default 4
  • MINT_RL_LR: default 1e-4
  • MINT_RL_MAX_TOKENS: default 8
  • MINT_RL_TEMPERATURE: default 1.0

Full script

#!/usr/bin/env python3
"""RL-1 Verifiable Math — adapter for rl_core.
 
Reward: exact-match integer grading (1.0 correct, 0.0 wrong).
Run:  python demos/rl/adapters/verifiable_math.py
"""
 
from __future__ import annotations
 
import os
import random
import re
import sys
from pathlib import Path
from typing import Any
 
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402
 
 
class VerifiableMathAdapter(RLAdapter):
    FEWSHOT = "Q: What is 4 + 5?\nA: 9\n\n"
 
    def build_dataset(self) -> list[tuple[str, int]]:
        pairs = [(random.randint(0, 99), random.randint(0, 99)) for _ in range(50)]
        return [(f"Q: What is {a} + {b}?\nA:", a + b) for a, b in pairs]
 
    def make_prompt(self, sample: tuple[str, int], tokenizer: Any) -> list[int]:
        question, _ = sample
        return tokenizer.encode(self.FEWSHOT + question)
 
    def compute_reward(self, response: str, sample: tuple[str, int]) -> float:
        _, answer = sample
        match = re.search(r"-?\d+", response)
        return 1.0 if match and int(match.group()) == answer else 0.0
 
    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        accuracy = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0
        print(f"Step {step}: accuracy={accuracy:.1%}, datums={num_datums}")
 
 
if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "5"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "10"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "8"))
    run_grpo(VerifiableMathAdapter(), cfg)
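The grading behavior can be sanity-checked offline, without any API calls; the snippet below duplicates the regex logic from compute_reward() as a standalone function:

```python
import re

def extract_and_grade(response: str, answer: int) -> float:
    # Mirrors compute_reward(): the first signed integer in the response wins.
    match = re.search(r"-?\d+", response)
    return 1.0 if match and int(match.group()) == answer else 0.0

print(extract_and_grade(" 42", 42))          # 1.0
print(extract_and_grade("42 + 0 = 43", 42))  # 1.0 (first integer matches)
print(extract_and_grade("forty-two", 42))    # 0.0 (no digits to extract)
```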

Next steps

  • The final line prints Saved: <path>. You can load it in a new process:
import mint
 
service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste Saved path>")
  • For sampling + tokenization details, see /using-the-api/saving-and-loading and /api-reference/sampling-client.