Math RL
This page corresponds to demos/rl/adapters/verifiable_math.py in mint-quickstart.
What this demo does
- This is RL-1 Verifiable Math: training runs through the shared `demos/rl/rl_core.py` loop plus a task adapter.
- It trains a LoRA on 2-digit addition problems; the reward is verifiable grading: 1.0 for a correct answer, 0.0 for a wrong one.
- Group-relative advantages are computed over each group of sampled completions, and the update uses `importance_sampling`.
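The group-relative advantage step can be sketched as follows (a minimal illustration with a hypothetical function name, not the actual `rl_core.py` implementation):

```python
def group_relative_advantages(rewards: list[float]) -> list[float]:
    """Baseline each reward against its group's mean reward.

    When every reward in a group is identical, all advantages are 0,
    so the group contributes nothing to the policy update.
    """
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

# A mixed group yields nonzero advantages; a uniform group yields all zeros.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))  # [0.5, -0.5, -0.5, 0.5]
print(group_relative_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]
```

This is why uniform-reward groups show up as `datums=0` in the logs: their advantages are all zero, so they are skipped.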
Expected output
- It first prints `Model: ...`, each step prints `Step N: accuracy=..., datums=...`, and the run ends with `Saved: ...`.
Common pitfalls
- Repeated `datums=0` usually means every sample in a group received the same reward (so all advantages are 0) and the group was skipped; try raising `MINT_RL_GROUP`/`MINT_RL_BATCH`, increasing `MINT_RL_TEMPERATURE`, or running more steps.
- If `accuracy=0.0%` persists, first increase `MINT_RL_MAX_TOKENS` (the default is deliberately small for a quick demo).
- For stricter evaluation, swap `compute_reward()` for a verifier that matches your task more closely.
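As a sketch of that last point, a stricter verifier could require the first output line to be exactly one integer, instead of accepting any integer anywhere in the response (`strict_reward` is a hypothetical drop-in for `compute_reward`):

```python
import re

def strict_reward(response: str, answer: int) -> float:
    """Grade only the first line, and require it to be exactly one integer."""
    first_line = response.split("\n", 1)[0]
    match = re.fullmatch(r"\s*(-?\d+)\s*", first_line)
    return 1.0 if match and int(match.group(1)) == answer else 0.0

print(strict_reward(" 9", 9))               # 1.0 — a bare integer passes
print(strict_reward("9 because 4+5=9", 9))  # 0.0 — extra text fails
```

Stricter grading reduces reward noise from lucky substring matches, at the cost of more zero-reward groups early in training.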
Prerequisites
- Python >= 3.11
- Set `MINT_API_KEY` (or configure it via `.env`)
How to run
```shell
export MINT_API_KEY=sk-...
python demos/rl/adapters/verifiable_math.py
```

Parameters (environment variables)
- `MINT_BASE_MODEL`: default `Qwen/Qwen3-0.6B`
- `MINT_LORA_RANK`: default `16`
- `MINT_RL_STEPS`: default `5`
- `MINT_RL_BATCH`: default `10`
- `MINT_RL_GROUP`: default `4`
- `MINT_RL_LR`: default `1e-4`
- `MINT_RL_MAX_TOKENS`: default `8`
- `MINT_RL_TEMPERATURE`: default `1.0`
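For reference, the table above can be modeled as a small config reader (a sketch only; the real `RLConfig.from_env` lives in `demos/rl/rl_core.py` and its field names may differ):

```python
import os
from dataclasses import dataclass

@dataclass
class DemoConfig:
    """Hypothetical mirror of the env-var table above."""
    base_model: str
    lora_rank: int
    steps: int
    batch: int
    group: int
    lr: float
    max_tokens: int
    temperature: float

    @classmethod
    def from_env(cls) -> "DemoConfig":
        env = os.environ.get
        return cls(
            base_model=env("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B"),
            lora_rank=int(env("MINT_LORA_RANK", "16")),
            steps=int(env("MINT_RL_STEPS", "5")),
            batch=int(env("MINT_RL_BATCH", "10")),
            group=int(env("MINT_RL_GROUP", "4")),
            lr=float(env("MINT_RL_LR", "1e-4")),
            max_tokens=int(env("MINT_RL_MAX_TOKENS", "8")),
            temperature=float(env("MINT_RL_TEMPERATURE", "1.0")),
        )

cfg = DemoConfig.from_env()  # picks up any MINT_* overrides from the environment
```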
Full script
```python
#!/usr/bin/env python3
"""RL-1 Verifiable Math — adapter for rl_core.

Reward: exact-match integer grading (1.0 correct, 0.0 wrong).
Run: python demos/rl/adapters/verifiable_math.py
"""
from __future__ import annotations

import os
import random
import re
import sys
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402


class VerifiableMathAdapter(RLAdapter):
    FEWSHOT = "Q: What is 4 + 5?\nA: 9\n\n"

    def build_dataset(self) -> list[tuple[str, int]]:
        # Reuse the same (a, b) pair for the question text and the answer,
        # so the graded answer always matches the question asked.
        return [
            (f"Q: What is {a} + {b}?\nA:", a + b)
            for a, b in [(random.randint(0, 99), random.randint(0, 99)) for _ in range(50)]
        ]

    def make_prompt(self, sample: tuple[str, int], tokenizer: Any) -> list[int]:
        question, _ = sample
        return tokenizer.encode(self.FEWSHOT + question)

    def compute_reward(self, response: str, sample: tuple[str, int]) -> float:
        _, answer = sample
        match = re.search(r"-?\d+", response)
        return 1.0 if match and int(match.group()) == answer else 0.0

    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        accuracy = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0
        print(f"Step {step}: accuracy={accuracy:.1%}, datums={num_datums}")


if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "5"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "10"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "8"))
    run_grpo(VerifiableMathAdapter(), cfg)
```

Next steps
- The run ends by printing `Saved: <path>`; use that path as `model_path` to create a sampling client:
```python
import mint

service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste the Saved path>")
```

- For notes on sampling and tokenizers, see /using-the-api/saving-and-loading and /api-reference/sampling-client.