Mind Lab Toolkit (MinT)

Math RL

This page corresponds to demos/rl/adapters/verifiable_math.py in mint-quickstart.

What this demo does

  • This is RL-1 Verifiable Math: training runs through the shared demos/rl/rl_core.py loop plus a task adapter.
  • It trains a LoRA on 2-digit addition problems, with a verifiable reward: 1.0 for a correct answer, 0.0 otherwise.
  • Group-relative advantages are computed over each group of samples, and updates use importance_sampling.
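
The group-relative advantage step can be sketched in a few lines (a simplified illustration, not the actual rl_core.py implementation; the helper name is ours):

```python
def group_relative_advantages(rewards: list[float]) -> list[float]:
    """Center each reward on its group's mean reward.

    If every sample in the group earned the same reward, all
    advantages are 0 and the group carries no gradient signal.
    """
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

# A mixed group yields nonzero advantages; an all-same group yields zeros.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))  # [0.5, -0.5, -0.5, 0.5]
print(group_relative_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]
```

The all-zeros case is why groups with identical rewards are skipped during training (see Common pitfalls below).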

Expected output

  • It first prints Model: ..., then each step prints Step N: accuracy=..., datums=..., and finally it prints Saved: ...

Common pitfalls

  • Repeated datums=0 usually means every sample in a group got the same reward (so all advantages are 0) and the group was skipped; increase MINT_RL_GROUP/MINT_RL_BATCH, raise MINT_RL_TEMPERATURE, or run more steps.
  • If accuracy stays at 0.0%, first increase MINT_RL_MAX_TOKENS (the default is deliberately small for a quick demo).
  • For stricter evaluation, replace compute_reward() with a verifier closer to your task.
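
As one example of a stricter verifier, the hypothetical strict_reward below only accepts a response whose first line is exactly the integer answer, rather than grabbing the first integer anywhere in the text (a sketch; adapt it to your task):

```python
import re


def strict_reward(response: str, answer: int) -> float:
    """Stricter grading sketch: the first line of the response must be
    exactly the expected integer, with no surrounding prose."""
    stripped = response.strip()
    first_line = stripped.splitlines()[0].strip() if stripped else ""
    if re.fullmatch(r"-?\d+", first_line) and int(first_line) == answer:
        return 1.0
    return 0.0


print(strict_reward("9", 9))                # 1.0
print(strict_reward("The answer is 9", 9))  # 0.0 (prose around the number)
```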

Prerequisites

  • Python >= 3.11
  • Set MINT_API_KEY (or configure it via .env)

How to run

export MINT_API_KEY=sk-...
python demos/rl/adapters/verifiable_math.py

Parameters (environment variables)

  • MINT_BASE_MODEL: default Qwen/Qwen3-0.6B
  • MINT_LORA_RANK: default 16
  • MINT_RL_STEPS: default 5
  • MINT_RL_BATCH: default 10
  • MINT_RL_GROUP: default 4
  • MINT_RL_LR: default 1e-4
  • MINT_RL_MAX_TOKENS: default 8
  • MINT_RL_TEMPERATURE: default 1.0
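
For example, to work around the datums=0 pitfall described above, a longer and more exploratory run could look like this (illustrative values, not tuned recommendations):

```shell
export MINT_API_KEY=sk-...
export MINT_RL_STEPS=20        # more steps than the quick default of 5
export MINT_RL_GROUP=8         # larger groups make all-same-reward groups rarer
export MINT_RL_TEMPERATURE=1.2 # more exploration, more reward variance
export MINT_RL_MAX_TOKENS=16   # more room for the model's answer
python demos/rl/adapters/verifiable_math.py
```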

Full script

#!/usr/bin/env python3
"""RL-1 Verifiable Math — adapter for rl_core.

Reward: exact-match integer grading (1.0 correct, 0.0 wrong).
Run:  python demos/rl/adapters/verifiable_math.py
"""

from __future__ import annotations

import os
import random
import re
import sys
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402


class VerifiableMathAdapter(RLAdapter):
    FEWSHOT = "Q: What is 4 + 5?\nA: 9\n\n"

    def build_dataset(self) -> list[tuple[str, int]]:
        # Draw each operand pair once so the question text and the
        # expected answer stay consistent.
        pairs = [(random.randint(0, 99), random.randint(0, 99)) for _ in range(50)]
        return [(f"Q: What is {a} + {b}?\nA:", a + b) for a, b in pairs]

    def make_prompt(self, sample: tuple[str, int], tokenizer: Any) -> list[int]:
        question, _ = sample
        return tokenizer.encode(self.FEWSHOT + question)

    def compute_reward(self, response: str, sample: tuple[str, int]) -> float:
        _, answer = sample
        match = re.search(r"-?\d+", response)
        return 1.0 if match and int(match.group()) == answer else 0.0

    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        accuracy = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0
        print(f"Step {step}: accuracy={accuracy:.1%}, datums={num_datums}")


if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "5"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "10"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "8"))
    run_grpo(VerifiableMathAdapter(), cfg)
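
The grading in compute_reward() can be sanity-checked in isolation; this standalone function mirrors the regex grading used in the adapter (first integer found in the response, exact match against the expected sum):

```python
import re


def compute_reward(response: str, answer: int) -> float:
    # Same grading as the adapter: find the first (possibly negative)
    # integer in the response and compare it to the expected answer.
    match = re.search(r"-?\d+", response)
    return 1.0 if match and int(match.group()) == answer else 0.0


print(compute_reward("9", 9))                # 1.0
print(compute_reward("The sum is 14.", 14))  # 1.0 (integer found mid-text)
print(compute_reward("nine", 9))             # 0.0 (no integer to grade)
```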

Next steps

  • At the end it prints Saved: <path>; use that path as model_path when creating a sampling client:
import mint

service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste the Saved path>")
  • For notes on sampling and the tokenizer, see: /using-the-api/saving-and-loading/api-reference/sampling-client
