Math Reinforcement Learning

This tutorial demonstrates reinforcement learning (RL) training for math problems using MinT.

What You’ll Learn

  1. Configure different math environments (arithmetic to competition-level)
  2. Implement grading logic with sympy equivalence checking
  3. Run the RL training loop with group-based advantages
  4. Evaluate and visualize model improvement

Supported Environments

| Environment | Dataset | Problems | Difficulty | Use Case |
|---|---|---|---|---|
| arithmetic | Generated | X + Y | Easy | Debug pipeline |
| gsm8k | OpenAI GSM8K | Grade school word problems | Medium | Baseline training |
| math | Hendrycks MATH | Competition math | Hard | Advanced reasoning |
| polaris | POLARIS-53K | Difficulty-calibrated | Mixed | Optimal training |
| deepmath | DeepMath-103K | Large-scale verifiable | Mixed | Production training |

RL vs SFT

| Method | Training Signal | Advantage |
|---|---|---|
| SFT | Memorize exact answer format | Simple, fast |
| RL | Learn from correctness signal | Handles equivalent answers |

RL excels at math because equivalent answers (1/2, 0.5, \frac{1}{2}) all receive positive reward.


Dataset Details

GSM8K (Grade School Math 8K)

Source: OpenAI (2021) | Size: 7.5K train + 1K test

Multi-step grade school word problems requiring 2-8 calculation steps. Only basic arithmetic needed.

Q: Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes 4 into muffins.
   She sells the rest at $2 each. How much does she make daily?
A: 16 - 3 - 4 = 9 eggs sold. 9 * 2 = $18. #### 18
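
The final numeric answer always follows the #### delimiter, which is exactly what the dataset loader in Step 4 extracts. A minimal illustration:

import re

answer_text = "16 - 3 - 4 = 9 eggs sold. 9 * 2 = $18. #### 18"
final_answer = re.search(r"####\s*(.+)", answer_text).group(1).strip()
print(final_answer)  # 18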

Hendrycks MATH

Source: UC Berkeley (NeurIPS 2021) | Size: 12.5K problems (the loader below samples the 500-problem MATH-500 subset)

High school competition problems (AMC 10/12, AIME) across 7 subjects with LaTeX solutions.

Q: Find the sum of all positive integers n such that n^2 + 12n - 2007 is a perfect square.
A: \boxed{80}

DeepMath-103K

Source: Tencent + SJTU (2025) | Size: 103K

Large-scale dataset for RL training with verifiable answers. Each problem has 3 different solutions.

POLARIS-53K

Source: HKU NLP + ByteDance (2025) | Size: 53K

Difficulty-calibrated filtering based on target model capability. Removes problems the model already solves perfectly to create optimal training distribution.

| Stage | Dataset | Purpose |
|---|---|---|
| 1. Debug | arithmetic | Verify RL pipeline works |
| 2. Baseline | gsm8k | Simple math, quick results |
| 3. Production | math / deepmath / polaris | Train advanced reasoning |

Step 0: Setup

Install required packages:

pip install -q datasets sympy pylatexenc matplotlib torch mint

Load your API key:

import os
from dotenv import load_dotenv
 
load_dotenv()
 
if os.environ.get('MINT_API_KEY'):
    print("API key loaded")
else:
    print("WARNING: MINT_API_KEY not found!")

Connect to MinT:

import mint
from mint import types
 
service_client = mint.ServiceClient()
print("Connected to MinT")

Step 1: Configuration

Configure your training run:

# ========== CONFIGURATION ==========
 
# Environment: "arithmetic", "gsm8k", "math", "polaris", "deepmath"
ENV = "arithmetic"
 
# Model
BASE_MODEL = "Qwen/Qwen3-0.6B"
LORA_RANK = 16
 
# Training
NUM_STEPS = 100
BATCH_SIZE = 100 if ENV == "arithmetic" else 32
GROUP_SIZE = 4
LEARNING_RATE = 1e-4 if ENV == "arithmetic" else 1e-5
 
# Generation
MAX_TOKENS = 8 if ENV == "arithmetic" else 512
TEMPERATURE = 1.0
 
print(f"Environment: {ENV}")
print(f"Model: {BASE_MODEL}")
print(f"Steps: {NUM_STEPS}, Batch: {BATCH_SIZE}, Group: {GROUP_SIZE}")
print(f"Max tokens: {MAX_TOKENS}")

Parameter choices:

  • GROUP_SIZE: Number of responses sampled per problem. More samples = better advantage estimates, slower training
  • TEMPERATURE=1.0: High temperature encourages exploration during RL
  • MAX_TOKENS: Short for arithmetic (just a number), long for math (reasoning chains)
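
A practical consequence of these settings: each step samples BATCH_SIZE × GROUP_SIZE completions, so generation cost scales with both knobs.

# Per-step sampling budget implied by the configuration above
samples_per_step = BATCH_SIZE * GROUP_SIZE
print(samples_per_step)  # 400 for arithmetic (100 * 4), 128 for the other environments (32 * 4)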

Step 2: Grading Logic

Why Grading Matters

Math has multiple correct representations:

  • 1/2 = 0.5 = \frac{1}{2} = 0.50
  • x^2 + 2x + 1 = (x+1)^2

We need robust grading that recognizes equivalence.
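
Exact string comparison would mark many of these correct answers wrong:

# Naive exact-match grading fails on mathematically equal answers
print("1/2" == "0.5")               # False
print("0.50" == "0.5")              # False
print("(x+1)^2" == "x^2 + 2x + 1")  # False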

Extract Answers

import re
 
def extract_number(text: str) -> int | None:
    """Extract first integer from text."""
    match = re.search(r'-?\d+', text)
    return int(match.group()) if match else None
 
 
def extract_boxed(text: str) -> str:
    """Extract content from last \\boxed{...} in text."""
    boxed_strs = []
    stack = []
    for i, c in enumerate(text):
        if c == "{":
            stack.append(i)
        elif c == "}" and stack:
            start = stack.pop()
            if text[:start].endswith("\\boxed"):
                boxed_strs.append(text[start + 1:i])
    if boxed_strs:
        return boxed_strs[-1]
    # Try \boxed X without braces
    match = re.search(r"\\boxed\s+([a-zA-Z0-9]+)", text)
    if match:
        return match.group(1)
    raise ValueError("No \\boxed{} found")
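
A couple of quick sanity checks (the strings here are illustrative, not dataset rows):

print(extract_number("The model says the sum is 42."))                # 42
print(extract_boxed("The r's are at positions 3, 8, 9. \\boxed{3}"))  # 3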

Sympy Equivalence Checking

For mathematical expressions, we use sympy to check algebraic equivalence:

import sympy
from sympy.parsing import sympy_parser
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
 
 
def normalize_answer(answer: str) -> str:
    """Normalize math answer for comparison."""
    s = answer.strip()
    s = s.replace("\\tfrac", "\\frac").replace("\\dfrac", "\\frac")
    s = s.replace("\\left", "").replace("\\right", "")
    s = s.replace("^{\\circ}", "").replace("^\\circ", "")
    s = s.replace("\\%", "").replace("\\$", "")
    s = s.replace(" ", "").replace("\n", "")
    s = s.lower()
    return s
 
 
def sympy_equal(a: str, b: str) -> bool:
    """Check if two expressions are mathematically equal using sympy."""
    try:
        a_py = a.replace("^", "**")
        b_py = b.replace("^", "**")
        expr_a = sympy_parser.parse_expr(
            a_py,
            transformations=sympy_parser.standard_transformations +
            (sympy_parser.implicit_multiplication_application,)
        )
        expr_b = sympy_parser.parse_expr(
            b_py,
            transformations=sympy_parser.standard_transformations +
            (sympy_parser.implicit_multiplication_application,)
        )
        return sympy.simplify(expr_a - expr_b) == 0
    except Exception:
        return False
 
 
def grade_with_timeout(given: str, truth: str, timeout: float = 1.0) -> bool:
    """Grade answer with timeout protection for sympy."""
    # Fast path: string match
    if normalize_answer(given) == normalize_answer(truth):
        return True
    # Slow path: run sympy in a worker thread so a hung simplify() can't stall training.
    # Shut the executor down without waiting: a `with ThreadPoolExecutor(...)` block
    # would join the hung worker on exit and defeat the timeout.
    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(sympy_equal, given, truth)
    try:
        return future.result(timeout=timeout)
    except FuturesTimeoutError:
        return False
    except Exception:
        return False
    finally:
        executor.shutdown(wait=False)

Why timeout? Some expressions cause sympy to hang. The timeout ensures training doesn’t stall.
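
Both paths can be exercised directly (illustrative values):

print(grade_with_timeout("18", "18"))    # True via the string fast path; sympy is never called
print(grade_with_timeout("9*2", "18"))   # True via sympy equivalence
print(grade_with_timeout("19", "18"))    # False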

Grading Functions

def grade_math_answer(response: str, truth: str) -> float:
    """Grade a math response with \\boxed{} format."""
    try:
        given = extract_boxed(response)
        return 1.0 if grade_with_timeout(given, truth) else 0.0
    except ValueError:
        return 0.0
 
 
def grade_arithmetic(response: str, correct: int) -> float:
    """Grade an arithmetic response."""
    extracted = extract_number(response)
    return 1.0 if extracted == correct else 0.0

Test the grading:

print("Grading demos:")
print(f"  grade_with_timeout('1/2', '0.5') = {grade_with_timeout('1/2', '0.5')}")
print(f"  grade_with_timeout('\\frac{2}{4}', '0.5') = {grade_with_timeout('\\frac{2}{4}', '0.5')}")

Step 3: Problem Environments

Problem Definition

import random
from dataclasses import dataclass
from typing import Callable
 
 
@dataclass
class Problem:
    question: str
    answer: str  # Ground truth for grading
    grader: Callable[[str, str], float]  # (response, answer) -> reward
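
For example, a hand-written Problem that reuses the arithmetic grader from Step 2 (the values are illustrative):

demo = Problem(
    question="Q: What is 2 + 3?\nA:",
    answer="5",
    grader=lambda resp, ans: grade_arithmetic(resp, int(ans)),
)
print(demo.grader("The sum is 5.", demo.answer))  # 1.0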

Few-Shot Prefixes

Few-shot examples help the model understand the expected format:

ARITHMETIC_FEWSHOT = """Q: What is 4 + 5?
A: 9
 
"""
 
MATH_FEWSHOT = """Q: How many r's are in strawberry? Write your answer in \\boxed{} format.
A: Let's count: s-t-r-a-w-b-e-r-r-y. The r's are at positions 3, 8, 9. \\boxed{3}
 
"""
 
 
def get_fewshot(env: str) -> str:
    return ARITHMETIC_FEWSHOT if env == "arithmetic" else MATH_FEWSHOT

Generate Arithmetic Problems

def generate_arithmetic_problem() -> Problem:
    """Generate a simple addition problem."""
    x = random.randint(0, 99)
    y = random.randint(0, 99)
    return Problem(
        question=f"Q: What is {x} + {y}?\nA:",
        answer=str(x + y),
        grader=lambda resp, ans: grade_arithmetic(resp, int(ans))
    )
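
Quick check that the generator and grader agree:

p = generate_arithmetic_problem()
print(p.question)                    # e.g. "Q: What is 37 + 58?\nA:"
print(p.answer)                      # e.g. "95"
print(p.grader(p.answer, p.answer))  # 1.0: the ground-truth answer grades as correct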

Step 4: Load Dataset

The MathDatasetLoader handles multiple dataset formats:

from datasets import load_dataset
 
 
class MathDatasetLoader:
    """Load and iterate through math datasets."""
 
    def __init__(self, env: str, seed: int = 42):
        self.env = env
        self.seed = seed
        self.data = None
        self.index = 0
 
        if env == "arithmetic":
            self.data = None  # Generated on the fly
        elif env == "gsm8k":
            ds = load_dataset("openai/gsm8k", "main", split="train")
            self.data = ds.shuffle(seed=seed)
        elif env == "math":
            ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
            self.data = ds.shuffle(seed=seed)
        elif env == "polaris":
            ds = load_dataset("POLARIS-Project/Polaris-Dataset-53K", split="train")
            self.data = ds.shuffle(seed=seed)
        elif env == "deepmath":
            ds = load_dataset("zwhe99/DeepMath-103K", split="train")
            self.data = ds.shuffle(seed=seed)
        else:
            raise ValueError(f"Unknown env: {env}")
 
    def get_batch(self, batch_size: int) -> list[Problem]:
        """Get a batch of problems."""
        if self.env == "arithmetic":
            return [generate_arithmetic_problem() for _ in range(batch_size)]
 
        problems = []
        for _ in range(batch_size):
            if self.index >= len(self.data):
                self.index = 0
 
            row = self.data[self.index]
            self.index += 1
 
            problem = self._parse_row(row)
            if problem:
                problems.append(problem)
 
        return problems
 
    def _parse_row(self, row: dict) -> Problem | None:
        """Parse a dataset row into a Problem."""
        try:
            if self.env == "gsm8k":
                question = row["question"]
                answer_text = row["answer"]
                match = re.search(r"####\s*(.+)", answer_text)
                answer = match.group(1).strip().replace(",", "") if match else ""
            elif self.env == "math":
                question = row["problem"]
                answer = extract_boxed(row["solution"])
            elif self.env == "polaris":
                question = row.get("problem", "")
                answer = row.get("answer", "")
            elif self.env == "deepmath":
                question = row.get("question", "")
                answer = row.get("final_answer", "")
            else:
                return None
 
            if not question or not answer:
                return None
 
            suffix = " Write your answer in \\boxed{} format."
            return Problem(
                question=f"Q: {question}{suffix}\nA:",
                answer=answer,
                grader=lambda resp, ans: grade_math_answer(resp, ans)
            )
        except Exception:
            return None

Initialize and test:

dataset = MathDatasetLoader(ENV)
print(f"Dataset loaded for {ENV}")
 
sample = dataset.get_batch(1)[0]
print(f"\nSample problem:")
print(f"  Q: {sample.question[:100]}...")
print(f"  A: {sample.answer}")

Step 5: Create Training Client

training_client = service_client.create_lora_training_client(
    base_model=BASE_MODEL,
    rank=LORA_RANK,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)
print(f"Training client created: {BASE_MODEL}")
 
tokenizer = training_client.get_tokenizer()
print(f"Vocab size: {tokenizer.vocab_size:,}")

Step 6: RL Training Loop

The RL training loop:

for each step:
    1. Get batch of problems
    2. Sample GROUP_SIZE responses per problem
    3. Compute rewards (1=correct, 0=wrong)
    4. Compute advantages (reward - group_mean)
    5. Train with importance_sampling loss

Understanding Advantages

Why advantages instead of raw rewards?

If all samples in a group are correct (reward=1), there’s nothing to learn. Advantages normalize within each group:

  • advantage = reward - mean_reward
  • Positive advantage: better than average, reinforce
  • Negative advantage: worse than average, discourage
  • Zero advantage: skip (no gradient signal)
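
As a toy example with GROUP_SIZE = 4 (hypothetical rewards):

group_rewards = [1.0, 0.0, 1.0, 0.0]
mean_reward = sum(group_rewards) / len(group_rewards)    # 0.5
advantages = [r - mean_reward for r in group_rewards]    # [0.5, -0.5, 0.5, -0.5]
print(advantages)

# A uniformly correct group ([1, 1, 1, 1]) gives all-zero advantages,
# which is why the training loop below skips such groups entirely.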

Training Code

import torch
from mint import TensorData
 
fewshot = get_fewshot(ENV)
metrics_history = []
 
print(f"Starting RL training: {NUM_STEPS} steps")
print(f"Batch: {BATCH_SIZE}, Group: {GROUP_SIZE}, LR: {LEARNING_RATE}")
print()
 
for step in range(NUM_STEPS):
    # Save weights for sampling
    sampler_path = training_client.save_weights_for_sampler(
        name=f"{ENV}-step-{step}"
    ).result().path
 
    sampling_client = service_client.create_sampling_client(
        model_path=sampler_path,
        base_model=BASE_MODEL
    )
 
    # Get problems
    problems = dataset.get_batch(BATCH_SIZE)
 
    training_datums = []
    all_rewards = []
 
    for problem in problems:
        prompt_text = fewshot + problem.question
        prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=True)
        prompt_input = types.ModelInput.from_ints(prompt_tokens)
 
        # Sample responses
        sample_result = sampling_client.sample(
            prompt=prompt_input,
            num_samples=GROUP_SIZE,
            sampling_params=types.SamplingParams(
                max_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
                stop_token_ids=[tokenizer.eos_token_id]
            )
        ).result()
 
        # Grade responses
        group_rewards = []
        group_responses = []
        group_logprobs = []
 
        for seq in sample_result.sequences:
            response_text = tokenizer.decode(seq.tokens)
            reward = problem.grader(response_text, problem.answer)
            group_rewards.append(reward)
            group_responses.append(list(seq.tokens))
            group_logprobs.append(list(seq.logprobs) if seq.logprobs else [0.0] * len(seq.tokens))
 
        all_rewards.extend(group_rewards)
 
        # Compute advantages
        mean_reward = sum(group_rewards) / len(group_rewards)
        advantages = [r - mean_reward for r in group_rewards]
 
        # Skip if no variance
        if all(a == 0 for a in advantages):
            continue
 
        # Create training datums
        for response_tokens, logprobs, adv in zip(group_responses, group_logprobs, advantages):
            if len(response_tokens) == 0:
                continue
 
            # Shift by one for next-token prediction: position i of input_tokens
            # predicts target_tokens[i].
            full_tokens = prompt_tokens + response_tokens
            input_tokens = full_tokens[:-1]
            target_tokens = full_tokens[1:]
 
            # Mask out prompt positions so only response tokens contribute to the loss;
            # each aligned list below has len(full_tokens) - 1 entries.
            weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(response_tokens)
            full_logprobs = [0.0] * (len(prompt_tokens) - 1) + logprobs
            full_advantages = [0.0] * (len(prompt_tokens) - 1) + [adv] * len(response_tokens)
 
            datum = types.Datum(
                model_input=types.ModelInput.from_ints(tokens=input_tokens),
                loss_fn_inputs={
                    "target_tokens": TensorData.from_torch(torch.tensor(target_tokens, dtype=torch.int64)),
                    "weights": TensorData.from_torch(torch.tensor(weights, dtype=torch.float32)),
                    "logprobs": TensorData.from_torch(torch.tensor(full_logprobs, dtype=torch.float32)),
                    "advantages": TensorData.from_torch(torch.tensor(full_advantages, dtype=torch.float32)),
                },
            )
            training_datums.append(datum)
 
    # Metrics
    accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0.0
    avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
 
    # Train
    if training_datums:
        training_client.forward_backward(
            training_datums,
            loss_fn="importance_sampling"
        ).result()
 
        training_client.optim_step(
            types.AdamParams(learning_rate=LEARNING_RATE)
        ).result()
 
    metrics_history.append({
        'step': step,
        'accuracy': accuracy,
        'reward': avg_reward,
        'datums': len(training_datums)
    })
 
    if step % 10 == 0 or step == NUM_STEPS - 1:
        print(f"Step {step:3d}: acc={accuracy:.1%}, reward={avg_reward:.3f}, datums={len(training_datums)}")
 
print("\nTraining complete!")
print(f"Initial: {metrics_history[0]['accuracy']:.1%}")
print(f"Final: {metrics_history[-1]['accuracy']:.1%}")

Step 7: Evaluate

Test the trained model:

final_path = training_client.save_weights_for_sampler(name=f"{ENV}-final").result().path
final_client = service_client.create_sampling_client(
    model_path=final_path,
    base_model=BASE_MODEL
)
 
test_problems = dataset.get_batch(5)
 
print("Evaluation:")
print("=" * 60)
correct = 0
 
for problem in test_problems:
    prompt = fewshot + problem.question
    prompt_input = types.ModelInput.from_ints(tokenizer.encode(prompt))
 
    result = final_client.sample(
        prompt=prompt_input,
        num_samples=1,
        sampling_params=types.SamplingParams(
            max_tokens=MAX_TOKENS,
            temperature=0.0,  # Greedy for evaluation
            stop_token_ids=[tokenizer.eos_token_id]
        )
    ).result()
 
    response = tokenizer.decode(result.sequences[0].tokens)
    reward = problem.grader(response, problem.answer)
    if reward > 0:
        correct += 1
 
    status = "PASS" if reward > 0 else "FAIL"
    print(f"Q: {problem.question[:50]}...")
    print(f"A: {response.strip()[:50]}... (expected: {problem.answer}) [{status}]")
    print()
 
print(f"Test accuracy: {correct}/{len(test_problems)}")

Step 8: Visualize

Plot the accuracy curve:

import matplotlib.pyplot as plt
 
steps = [m['step'] for m in metrics_history]
accuracies = [m['accuracy'] for m in metrics_history]
 
plt.figure(figsize=(10, 5))
plt.plot(steps, accuracies, 'b-', linewidth=2)
plt.xlabel('Step')
plt.ylabel('Accuracy')
plt.title(f'{ENV.upper()} RL Training')
plt.ylim(0, 1.05)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f'{ENV}_training.png', dpi=150)
plt.show()

Step 9: Save Checkpoint

checkpoint = training_client.save_state(name=f"{ENV}-rl-final").result()
print(f"Checkpoint: {checkpoint.path}")

Summary

| Component | Implementation |
|---|---|
| Environment | arithmetic, gsm8k, math, polaris, deepmath |
| Grading | String normalization + sympy equivalence |
| Training | importance_sampling with group advantages |
| Checkpointing | save_state() for weights + optimizer |

Key API Methods

# Setup
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(base_model=...)
tokenizer = training_client.get_tokenizer()
 
# RL Training
sampling_client = service_client.create_sampling_client(model_path, base_model)
sampling_client.sample(prompt, num_samples, sampling_params)  # Sample responses
training_client.forward_backward(datums, loss_fn="importance_sampling")  # RL loss
training_client.optim_step(types.AdamParams(learning_rate=...))
 
# Checkpointing
checkpoint = training_client.save_state(name=...)