Math Reinforcement Learning
This tutorial demonstrates reinforcement learning (RL) training for math problems using MinT.
What You’ll Learn
- Configure different math environments (arithmetic to competition-level)
- Implement grading logic with sympy equivalence checking
- Run the RL training loop with group-based advantages
- Evaluate and visualize model improvement
Supported Environments
| Environment | Dataset | Problems | Difficulty | Use Case |
|---|---|---|---|---|
| arithmetic | Generated | X + Y | Easy | Debug pipeline |
| gsm8k | OpenAI GSM8K | Grade school word problems | Medium | Baseline training |
| math | Hendrycks MATH | Competition math | Hard | Advanced reasoning |
| polaris | POLARIS-53K | Difficulty-calibrated | Mixed | Optimal training |
| deepmath | DeepMath-103K | Large-scale verifiable | Mixed | Production training |
RL vs SFT
| Method | Training Signal | Advantage |
|---|---|---|
| SFT | Memorize exact answer format | Simple, fast |
| RL | Learn from correctness signal | Handles equivalent answers |
RL excels at math because equivalent answers (1/2, 0.5, \frac{1}{2}) all receive positive reward.
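As a minimal illustration (assuming sympy is installed; the full grading logic is built in Step 2), an equivalence check rewards any algebraically equal form, while exact string matching rewards only one:
import sympy
from sympy.parsing import sympy_parser

truth = "1/2"
for candidate in ["1/2", "0.5", "2/4"]:
    exact_match = 1.0 if candidate == truth else 0.0        # SFT-style exact match
    equivalent = sympy.simplify(
        sympy_parser.parse_expr(candidate) - sympy_parser.parse_expr(truth)
    ) == 0
    reward = 1.0 if equivalent else 0.0                      # RL correctness reward
    print(f"{candidate}: exact={exact_match}, reward={reward}")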
Dataset Details
GSM8K (Grade School Math 8K)
Source: OpenAI (2021) | Size: 7.5K train + 1.3K test
Multi-step grade school word problems requiring 2-8 calculation steps. Only basic arithmetic needed.
Q: Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes 4 into muffins.
She sells the rest at $2 each. How much does she make daily?
A: 16 - 3 - 4 = 9 eggs sold. 9 * 2 = $18. #### 18
Hendrycks MATH
Source: UC Berkeley (NeurIPS 2021) | Size: 12.5K problems (this tutorial uses the 500-problem MATH-500 subset)
High school competition problems (AMC 10/12, AIME) across 7 subjects with LaTeX solutions.
Q: Find the sum of all positive integers n such that n^2 + 12n - 2007 is a perfect square.
A: \boxed{80}
DeepMath-103K
Source: Tencent + SJTU (2025) | Size: 103K
Large-scale dataset for RL training with verifiable answers. Each problem has 3 different solutions.
POLARIS-53K
Source: HKU NLP + ByteDance (2025) | Size: 53K
Difficulty-calibrated filtering based on target model capability. Removes problems the model already solves perfectly to create optimal training distribution.
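The filtering idea behind this calibration can be sketched roughly as follows; estimate_solve_rate is a hypothetical helper that samples the target model on a problem and returns the fraction of responses graded correct (POLARIS's actual pipeline is more involved):
def filter_by_difficulty(problems, estimate_solve_rate, n_samples=8,
                         min_rate=0.0, max_rate=0.9):
    """Keep problems the target model sometimes, but not always, solves.

    estimate_solve_rate is a hypothetical helper: it samples the target model
    n_samples times on a problem and returns the fraction graded correct.
    Problems with solve rate ~1.0 carry no learning signal (every sample gets
    the same reward); by default, never-solved problems are dropped too.
    """
    kept = []
    for problem in problems:
        rate = estimate_solve_rate(problem, n_samples)
        if min_rate < rate <= max_rate:
            kept.append(problem)
    return kept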
Recommended Progression
| Stage | Dataset | Purpose |
|---|---|---|
| 1. Debug | arithmetic | Verify RL pipeline works |
| 2. Baseline | gsm8k | Simple math, quick results |
| 3. Production | math / deepmath / polaris | Train advanced reasoning |
Step 0: Setup
Install required packages:
pip install -q datasets sympy pylatexenc mint
Load your API key:
import os
from dotenv import load_dotenv
load_dotenv()
if os.environ.get('MINT_API_KEY'):
print("API key loaded")
else:
print("WARNING: MINT_API_KEY not found!")
Connect to MinT:
import mint
from mint import types
service_client = mint.ServiceClient()
print("Connected to MinT")
Step 1: Configuration
Configure your training run:
# ========== CONFIGURATION ==========
# Environment: "arithmetic", "gsm8k", "math", "polaris", "deepmath"
ENV = "arithmetic"
# Model
BASE_MODEL = "Qwen/Qwen3-0.6B"
LORA_RANK = 16
# Training
NUM_STEPS = 100
BATCH_SIZE = 100 if ENV == "arithmetic" else 32
GROUP_SIZE = 4
LEARNING_RATE = 1e-4 if ENV == "arithmetic" else 1e-5
# Generation
MAX_TOKENS = 8 if ENV == "arithmetic" else 512
TEMPERATURE = 1.0
print(f"Environment: {ENV}")
print(f"Model: {BASE_MODEL}")
print(f"Steps: {NUM_STEPS}, Batch: {BATCH_SIZE}, Group: {GROUP_SIZE}")
print(f"Max tokens: {MAX_TOKENS}")
Parameter choices:
- GROUP_SIZE: Number of responses sampled per problem. More samples = better advantage estimates, but slower training
- TEMPERATURE=1.0: High temperature encourages exploration during RL
- MAX_TOKENS: Short for arithmetic (just a number), long for math (reasoning chains)
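One practical consequence of these settings: each step samples and grades BATCH_SIZE × GROUP_SIZE responses, so both knobs multiply into per-step cost. A quick sanity check (assuming the configuration cell above has been run):
# Rough per-step cost, derived from the configuration above
samples_per_step = BATCH_SIZE * GROUP_SIZE
max_generated_tokens = samples_per_step * MAX_TOKENS
print(f"Responses graded per step: {samples_per_step}")
print(f"Generated tokens per step (upper bound): {max_generated_tokens:,}")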
Step 2: Grading Logic
Why Grading Matters
Math has multiple correct representations:
- 1/2 = 0.5 = \frac{1}{2} = 0.50
- x^2 + 2x + 1 = (x+1)^2
We need robust grading that recognizes equivalence.
Extract Answers
import re
def extract_number(text: str) -> int | None:
"""Extract first integer from text."""
match = re.search(r'-?\d+', text)
return int(match.group()) if match else None
def extract_boxed(text: str) -> str:
"""Extract content from last \\boxed{...} in text."""
boxed_strs = []
stack = []
for i, c in enumerate(text):
if c == "{":
stack.append(i)
elif c == "}" and stack:
start = stack.pop()
if text[:start].endswith("\\boxed"):
boxed_strs.append(text[start + 1:i])
if boxed_strs:
return boxed_strs[-1]
# Try \boxed X without braces
match = re.search(r"\\boxed\s+([a-zA-Z0-9]+)", text)
if match:
return match.group(1)
raise ValueError("No \\boxed{} found")
Sympy Equivalence Checking
For mathematical expressions, we use sympy to check algebraic equivalence:
import sympy
from sympy.parsing import sympy_parser
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
def normalize_answer(answer: str) -> str:
"""Normalize math answer for comparison."""
s = answer.strip()
s = s.replace("\\tfrac", "\\frac").replace("\\dfrac", "\\frac")
s = s.replace("\\left", "").replace("\\right", "")
s = s.replace("^{\\circ}", "").replace("^\\circ", "")
s = s.replace("\\%", "").replace("\\$", "")
s = s.replace(" ", "").replace("\n", "")
s = s.lower()
return s
def sympy_equal(a: str, b: str) -> bool:
"""Check if two expressions are mathematically equal using sympy."""
try:
a_py = a.replace("^", "**")
b_py = b.replace("^", "**")
expr_a = sympy_parser.parse_expr(
a_py,
transformations=sympy_parser.standard_transformations +
(sympy_parser.implicit_multiplication_application,)
)
expr_b = sympy_parser.parse_expr(
b_py,
transformations=sympy_parser.standard_transformations +
(sympy_parser.implicit_multiplication_application,)
)
return sympy.simplify(expr_a - expr_b) == 0
except Exception:
return False
# A single shared worker avoids creating a new executor per call and, unlike a
# `with` block (whose shutdown waits for running tasks), does not block on a hung sympy call.
_grading_executor = ThreadPoolExecutor(max_workers=1)

def grade_with_timeout(given: str, truth: str, timeout: float = 1.0) -> bool:
    """Grade answer with timeout protection for sympy."""
    # Fast path: string match
    if normalize_answer(given) == normalize_answer(truth):
        return True
    # Slow path: sympy equivalence, bounded by the timeout
    future = _grading_executor.submit(sympy_equal, given, truth)
    try:
        return future.result(timeout=timeout)
    except Exception:  # includes FuturesTimeoutError
        return False
Why timeout? Some expressions cause sympy to hang. The timeout ensures training doesn't stall: a hung check is abandoned after `timeout` seconds and graded as incorrect.
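A quick sanity check, assuming the functions above are defined; the first example goes through the sympy path, the second through the fast string path:
print(grade_with_timeout("x^2 + 2x + 1", "(x+1)^2"))  # True via sympy equivalence
print(grade_with_timeout("42", "42"))                  # True via string normalization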
Grading Functions
def grade_math_answer(response: str, truth: str) -> float:
"""Grade a math response with \\boxed{} format."""
try:
given = extract_boxed(response)
return 1.0 if grade_with_timeout(given, truth) else 0.0
except ValueError:
return 0.0
def grade_arithmetic(response: str, correct: int) -> float:
"""Grade an arithmetic response."""
extracted = extract_number(response)
return 1.0 if extracted == correct else 0.0
Test the grading:
print("Grading demos:")
print(" grade_with_timeout('1/2', '0.5') =", grade_with_timeout('1/2', '0.5'))
print(" grade_with_timeout('2/4', '0.5') =", grade_with_timeout('2/4', '0.5'))
Step 3: Problem Environments
Problem Definition
import random
from dataclasses import dataclass
from typing import Callable
@dataclass
class Problem:
question: str
answer: str # Ground truth for grading
grader: Callable[[str, str], float]  # (response, answer) -> reward
Few-Shot Prefixes
Few-shot examples help the model understand the expected format:
ARITHMETIC_FEWSHOT = """Q: What is 4 + 5?
A: 9
"""
MATH_FEWSHOT = """Q: How many r's are in strawberry? Write your answer in \\boxed{} format.
A: Let's count: s-t-r-a-w-b-e-r-r-y. The r's are at positions 3, 8, 9. \\boxed{3}
"""
def get_fewshot(env: str) -> str:
return ARITHMETIC_FEWSHOT if env == "arithmetic" else MATH_FEWSHOT
Generate Arithmetic Problems
def generate_arithmetic_problem() -> Problem:
"""Generate a simple addition problem."""
x = random.randint(0, 99)
y = random.randint(0, 99)
return Problem(
question=f"Q: What is {x} + {y}?\nA:",
answer=str(x + y),
grader=lambda resp, ans: grade_arithmetic(resp, int(ans))
)
Step 4: Load Dataset
The MathDatasetLoader handles multiple dataset formats:
from datasets import load_dataset
class MathDatasetLoader:
"""Load and iterate through math datasets."""
def __init__(self, env: str, seed: int = 42):
self.env = env
self.seed = seed
self.data = None
self.index = 0
if env == "arithmetic":
self.data = None # Generated on the fly
elif env == "gsm8k":
ds = load_dataset("openai/gsm8k", "main", split="train")
self.data = ds.shuffle(seed=seed)
elif env == "math":
ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
self.data = ds.shuffle(seed=seed)
elif env == "polaris":
ds = load_dataset("POLARIS-Project/Polaris-Dataset-53K", split="train")
self.data = ds.shuffle(seed=seed)
elif env == "deepmath":
ds = load_dataset("zwhe99/DeepMath-103K", split="train")
self.data = ds.shuffle(seed=seed)
else:
raise ValueError(f"Unknown env: {env}")
def get_batch(self, batch_size: int) -> list[Problem]:
"""Get a batch of problems."""
if self.env == "arithmetic":
return [generate_arithmetic_problem() for _ in range(batch_size)]
problems = []
for _ in range(batch_size):
if self.index >= len(self.data):
self.index = 0
row = self.data[self.index]
self.index += 1
problem = self._parse_row(row)
if problem:
problems.append(problem)
return problems
def _parse_row(self, row: dict) -> Problem | None:
"""Parse a dataset row into a Problem."""
try:
if self.env == "gsm8k":
question = row["question"]
answer_text = row["answer"]
match = re.search(r"####\s*(.+)", answer_text)
answer = match.group(1).strip().replace(",", "") if match else ""
elif self.env == "math":
question = row["problem"]
answer = extract_boxed(row["solution"])
elif self.env == "polaris":
question = row.get("problem", "")
answer = row.get("answer", "")
elif self.env == "deepmath":
question = row.get("question", "")
answer = row.get("final_answer", "")
else:
return None
if not question or not answer:
return None
suffix = " Write your answer in \\boxed{} format."
return Problem(
question=f"Q: {question}{suffix}\nA:",
answer=answer,
grader=lambda resp, ans: grade_math_answer(resp, ans)
)
except Exception:
return None
Initialize and test:
dataset = MathDatasetLoader(ENV)
print(f"Dataset loaded for {ENV}")
sample = dataset.get_batch(1)[0]
print(f"\nSample problem:")
print(f" Q: {sample.question[:100]}...")
print(f" A: {sample.answer}")
Step 5: Create Training Client
training_client = service_client.create_lora_training_client(
base_model=BASE_MODEL,
rank=LORA_RANK,
train_mlp=True,
train_attn=True,
train_unembed=True,
)
print(f"Training client created: {BASE_MODEL}")
tokenizer = training_client.get_tokenizer()
print(f"Vocab size: {tokenizer.vocab_size:,}")
Step 6: RL Training Loop
The RL training loop:
for each step:
1. Get batch of problems
2. Sample GROUP_SIZE responses per problem
3. Compute rewards (1=correct, 0=wrong)
4. Compute advantages (reward - group_mean)
5. Train with importance_sampling loss
Understanding Advantages
Why advantages instead of raw rewards?
If all samples in a group are correct (reward=1), there’s nothing to learn. Advantages normalize within each group:
advantage = reward - mean_reward
- Positive advantage: better than average, reinforce
- Negative advantage: worse than average, discourage
- Zero advantage: skip (no gradient signal)
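A worked example of the group-advantage computation (the same arithmetic the loop below performs):
group_rewards = [1.0, 1.0, 1.0, 0.0]             # 3 correct, 1 wrong
mean_reward = sum(group_rewards) / len(group_rewards)
print([r - mean_reward for r in group_rewards])  # [0.25, 0.25, 0.25, -0.75]

group_rewards = [1.0, 1.0, 1.0, 1.0]             # all correct -> no variance
mean_reward = sum(group_rewards) / len(group_rewards)
print([r - mean_reward for r in group_rewards])  # [0.0, 0.0, 0.0, 0.0], group skipped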
Training Code
import torch
from mint import TensorData
fewshot = get_fewshot(ENV)
metrics_history = []
print(f"Starting RL training: {NUM_STEPS} steps")
print(f"Batch: {BATCH_SIZE}, Group: {GROUP_SIZE}, LR: {LEARNING_RATE}")
print()
for step in range(NUM_STEPS):
# Save weights for sampling
sampler_path = training_client.save_weights_for_sampler(
name=f"{ENV}-step-{step}"
).result().path
sampling_client = service_client.create_sampling_client(
model_path=sampler_path,
base_model=BASE_MODEL
)
# Get problems
problems = dataset.get_batch(BATCH_SIZE)
training_datums = []
all_rewards = []
for problem in problems:
prompt_text = fewshot + problem.question
prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=True)
prompt_input = types.ModelInput.from_ints(prompt_tokens)
# Sample responses
sample_result = sampling_client.sample(
prompt=prompt_input,
num_samples=GROUP_SIZE,
sampling_params=types.SamplingParams(
max_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
stop_token_ids=[tokenizer.eos_token_id]
)
).result()
# Grade responses
group_rewards = []
group_responses = []
group_logprobs = []
for seq in sample_result.sequences:
response_text = tokenizer.decode(seq.tokens)
reward = problem.grader(response_text, problem.answer)
group_rewards.append(reward)
group_responses.append(list(seq.tokens))
group_logprobs.append(list(seq.logprobs) if seq.logprobs else [0.0] * len(seq.tokens))
all_rewards.extend(group_rewards)
# Compute advantages
mean_reward = sum(group_rewards) / len(group_rewards)
advantages = [r - mean_reward for r in group_rewards]
# Skip if no variance
if all(a == 0 for a in advantages):
continue
# Create training datums
for response_tokens, logprobs, adv in zip(group_responses, group_logprobs, advantages):
if len(response_tokens) == 0:
continue
full_tokens = prompt_tokens + response_tokens
input_tokens = full_tokens[:-1]
target_tokens = full_tokens[1:]
weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(response_tokens)
full_logprobs = [0.0] * (len(prompt_tokens) - 1) + logprobs
full_advantages = [0.0] * (len(prompt_tokens) - 1) + [adv] * len(response_tokens)
datum = types.Datum(
model_input=types.ModelInput.from_ints(tokens=input_tokens),
loss_fn_inputs={
"target_tokens": TensorData.from_torch(torch.tensor(target_tokens, dtype=torch.int64)),
"weights": TensorData.from_torch(torch.tensor(weights, dtype=torch.float32)),
"logprobs": TensorData.from_torch(torch.tensor(full_logprobs, dtype=torch.float32)),
"advantages": TensorData.from_torch(torch.tensor(full_advantages, dtype=torch.float32)),
},
)
training_datums.append(datum)
# Metrics
accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0.0
avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
# Train
if training_datums:
training_client.forward_backward(
training_datums,
loss_fn="importance_sampling"
).result()
training_client.optim_step(
types.AdamParams(learning_rate=LEARNING_RATE)
).result()
metrics_history.append({
'step': step,
'accuracy': accuracy,
'reward': avg_reward,
'datums': len(training_datums)
})
if step % 10 == 0 or step == NUM_STEPS - 1:
print(f"Step {step:3d}: acc={accuracy:.1%}, reward={avg_reward:.3f}, datums={len(training_datums)}")
print("\nTraining complete!")
print(f"Initial: {metrics_history[0]['accuracy']:.1%}")
print(f"Final: {metrics_history[-1]['accuracy']:.1%}")
Step 7: Evaluate
Test the trained model:
final_path = training_client.save_weights_for_sampler(name=f"{ENV}-final").result().path
final_client = service_client.create_sampling_client(
model_path=final_path,
base_model=BASE_MODEL
)
test_problems = dataset.get_batch(5)
print("Evaluation:")
print("=" * 60)
correct = 0
for problem in test_problems:
prompt = fewshot + problem.question
prompt_input = types.ModelInput.from_ints(tokenizer.encode(prompt))
result = final_client.sample(
prompt=prompt_input,
num_samples=1,
sampling_params=types.SamplingParams(
max_tokens=MAX_TOKENS,
temperature=0.0, # Greedy for evaluation
stop_token_ids=[tokenizer.eos_token_id]
)
).result()
response = tokenizer.decode(result.sequences[0].tokens)
reward = problem.grader(response, problem.answer)
if reward > 0:
correct += 1
status = "PASS" if reward > 0 else "FAIL"
print(f"Q: {problem.question[:50]}...")
print(f"A: {response.strip()[:50]}... (expected: {problem.answer}) [{status}]")
print()
print(f"Test accuracy: {correct}/{len(test_problems)}")
Step 8: Visualize
Plot the accuracy curve:
import matplotlib.pyplot as plt
steps = [m['step'] for m in metrics_history]
accuracies = [m['accuracy'] for m in metrics_history]
plt.figure(figsize=(10, 5))
plt.plot(steps, accuracies, 'b-', linewidth=2)
plt.xlabel('Step')
plt.ylabel('Accuracy')
plt.title(f'{ENV.upper()} RL Training')
plt.ylim(0, 1.05)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f'{ENV}_training.png', dpi=150)
plt.show()
Step 9: Save Checkpoint
checkpoint = training_client.save_state(name=f"{ENV}-rl-final").result()
print(f"Checkpoint: {checkpoint.path}")
Summary
| Component | Implementation |
|---|---|
| Environment | arithmetic, gsm8k, math, polaris, deepmath |
| Grading | String normalization + sympy equivalence |
| Training | importance_sampling with group advantages |
| Checkpointing | save_state() for weights + optimizer |
Key API Methods
# Setup
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(base_model=...)
tokenizer = training_client.get_tokenizer()
# RL Training
sampling_client = service_client.create_sampling_client(model_path, base_model)
sampling_client.sample(prompt, num_samples, sampling_params) # Sample responses
training_client.forward_backward(datums, loss_fn="importance_sampling") # RL loss
training_client.optim_step(types.AdamParams(learning_rate=...))
# Checkpointing
checkpoint = training_client.save_state(name=...)