Mind Lab Toolkit (MinT)

Code RL

This page corresponds to demos/rl/adapters/environment_tooluse.py in mint-quickstart.

What this demo does

  • This is RL-3 Environment Tool Use: training runs through the shared demos/rl/rl_core.py loop plus a task adapter.
  • It trains a LoRA to generate small Python functions; the reward is execution-based: 1.0 if all test cases pass, 0.0 otherwise.
  • It uses a few-shot prompt with group sampling to compute group-relative advantages, then updates the policy via importance sampling.
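
The group-relative step can be sketched as follows. This is a minimal illustration of the idea only, with group_advantages as a hypothetical helper name; the actual computation lives in demos/rl/rl_core.py and may differ.

```python
# Minimal sketch of the group-relative (GRPO-style) advantage:
# center each completion's reward on its sampling group's mean.
# `group_advantages` is a hypothetical name; the real computation
# lives in demos/rl/rl_core.py.
def group_advantages(rewards: list[float]) -> list[float]:
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

# One prompt, a group of 4 samples: two passed the tests, two failed.
print(group_advantages([1.0, 0.0, 1.0, 0.0]))  # -> [0.5, -0.5, 0.5, -0.5]
# All-equal rewards carry no learning signal (this is the datums=0 case).
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # -> [0.0, 0.0, 0.0, 0.0]
```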

Expected output

  • It first prints Model: ..., then Step N: accuracy=..., datums=... for each step, and finally Saved: ...

Common pitfalls

  • Grading execs/evals the model's output (this is a toy demo): never run it against untrusted models or inputs, and move to a sandbox if you extend it.
  • datums=0 means every sample in a group received an identical reward (all fail or all pass), so the group was skipped; increase MINT_RL_GROUP / MINT_RL_BATCH, raise MINT_RL_TEMPERATURE, or make the problems easier.
  • If the model doesn't answer with a fenced code block (or gets truncated), _extract_code() finds no code and the reward is 0; adjust the few-shot examples/prompt, or increase MINT_RL_MAX_TOKENS.
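
The truncation failure is easy to reproduce in isolation. A minimal sketch using the same fenced-block regex as the script's _extract_code():

```python
import re

# Same fenced-block pattern the script's _extract_code() uses:
# grab the body of the last ```...``` block.
FENCE = r"```(?:\w+)?\n(.*?)```"

ok = "Sure:\n```python\ndef add(a, b):\n    return a + b\n```\n"
truncated = "```python\ndef add(a, b):\n    return a +"  # cut off by max_tokens

print(re.findall(FENCE, ok, re.DOTALL))         # code recovered
print(re.findall(FENCE, truncated, re.DOTALL))  # empty -> reward would be 0.0
```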

Prerequisites

  • Python >= 3.11
  • Set MINT_API_KEY (or configure it via .env)

How to run

export MINT_API_KEY=sk-...
python demos/rl/adapters/environment_tooluse.py

Parameters (environment variables)

  • MINT_BASE_MODEL: default Qwen/Qwen3-0.6B
  • MINT_LORA_RANK: default 16
  • MINT_RL_STEPS: default 10
  • MINT_RL_BATCH: default 8
  • MINT_RL_GROUP: default 4
  • MINT_RL_LR: default 1e-4
  • MINT_RL_MAX_TOKENS: default 256
  • MINT_RL_TEMPERATURE: default 1.0
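
The variables above follow a get-with-default pattern. A minimal sketch of how such a config could be resolved, where resolve is a hypothetical helper (the demo builds its real config via RLConfig.from_env() in rl_core):

```python
# Minimal sketch: apply the documented defaults, then cast numeric fields.
# `resolve` is a hypothetical helper; the demo's real config comes from
# RLConfig.from_env() in rl_core.
def resolve(env: dict[str, str]) -> dict[str, object]:
    return {
        "base_model": env.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B"),
        "lora_rank": int(env.get("MINT_LORA_RANK", "16")),
        "steps": int(env.get("MINT_RL_STEPS", "10")),
        "batch": int(env.get("MINT_RL_BATCH", "8")),
        "group": int(env.get("MINT_RL_GROUP", "4")),
        "lr": float(env.get("MINT_RL_LR", "1e-4")),
        "max_tokens": int(env.get("MINT_RL_MAX_TOKENS", "256")),
        "temperature": float(env.get("MINT_RL_TEMPERATURE", "1.0")),
    }

print(resolve({}))                       # all defaults
print(resolve({"MINT_RL_STEPS": "50"}))  # one override, e.g. from os.environ
```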

Full script

#!/usr/bin/env python3
"""RL-3 Environment Tool Use — adapter for rl_core.

Reward: execution-based grading (generated code passes test cases = 1.0).
Run:  python demos/rl/adapters/environment_tooluse.py
"""

from __future__ import annotations

import os
import random
import re
import sys
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402

FEWSHOT = """Q: Write a function `double(x)` that returns x * 2.
A: ```python
def double(x):
    return x * 2
```

"""

PROBLEMS = [
    {"q": "Write `add(a, b)` that returns a + b.", "tests": [("add(1,2)", 3), ("add(-1,1)", 0)]},
    {"q": "Write `square(x)` that returns x squared.", "tests": [("square(3)", 9), ("square(-2)", 4)]},
    {"q": "Write `max2(a, b)` that returns the larger.", "tests": [("max2(3,5)", 5), ("max2(7,2)", 7)]},
    {"q": "Write `is_even(n)` returning True if even.", "tests": [("is_even(4)", True), ("is_even(7)", False)]},
    {"q": "Write `abs_val(x)` returning absolute value.", "tests": [("abs_val(-5)", 5), ("abs_val(3)", 3)]},
]


def _extract_code(response: str) -> str | None:
    match = re.findall(r"```(?:\w+)?\n(.*?)```", response, re.DOTALL)
    if match:
        return match[-1].strip()
    if "def " in response:
        return response[response.find("def "):].strip()
    return None


class EnvironmentToolUseAdapter(RLAdapter):

    def build_dataset(self) -> list[dict]:
        return [random.choice(PROBLEMS) for _ in range(50)]

    def make_prompt(self, sample: dict, tokenizer: Any) -> list[int]:
        return tokenizer.encode(FEWSHOT + f"Q: {sample['q']}\nA:")

    def compute_reward(self, response: str, sample: dict) -> float:
        code = _extract_code(response)
        if not code:
            return 0.0
        try:
            ns: dict[str, Any] = {}
            exec(code, ns)  # noqa: S102
            for expr, expected in sample["tests"]:
                if eval(expr, ns) != expected:  # noqa: S307
                    return 0.0
            return 1.0
        except Exception:
            return 0.0

    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        accuracy = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0
        print(f"Step {step}: accuracy={accuracy:.1%}, datums={num_datums}")


if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "10"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "8"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "256"))
    run_grpo(EnvironmentToolUseAdapter(), cfg)
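
The grading path in compute_reward can be exercised on its own. A minimal standalone sketch, with grade as a hypothetical helper mirroring the script's exec/eval logic (same toy-demo caveat: never feed it untrusted code):

```python
from typing import Any

def grade(code: str, tests: list[tuple[str, Any]]) -> float:
    """Exec the candidate code, then eval each test expression (toy demo only)."""
    try:
        ns: dict[str, Any] = {}
        exec(code, ns)  # runs model-generated code: sandbox this in real use
        return 1.0 if all(eval(expr, ns) == want for expr, want in tests) else 0.0
    except Exception:
        return 0.0

tests = [("add(1,2)", 3), ("add(-1,1)", 0)]
print(grade("def add(a, b):\n    return a + b", tests))  # -> 1.0
print(grade("def add(a, b):\n    return a - b", tests))  # -> 0.0 (add(1,2) != 3)
```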

Next steps

  • The run ends by printing Saved: <path>; use that path as model_path when creating a sampling client:
import mint

service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste the Saved path>")
  • For sampling and tokenizer details, see: /using-the-api/saving-and-loading/api-reference/sampling-client
