Code RL
This page corresponds to demos/rl/adapters/environment_tooluse.py in mint-quickstart.
What this demo does
- This is RL-3 Environment Tool Use: training runs through the shared demos/rl/rl_core.py loop plus a task adapter.
- It trains a LoRA to generate small Python functions; the reward is execution-based: 1.0 if all test cases pass, 0.0 otherwise.
- It uses few-shot prompting plus group sampling to compute a group-relative advantage (sketched right after this list), then updates the policy via `importance_sampling`.
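
For intuition, here is a minimal sketch of what "group-relative advantage" means. The real computation lives in demos/rl/rl_core.py and may differ in details (some GRPO variants also divide by the group's standard deviation); this only illustrates the centering idea and why all-identical rewards produce no training signal:

```python
def group_advantages(rewards: list[float]) -> list[float]:
    """Illustrative only: center each sample's reward on its group mean."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

print(group_advantages([1.0, 0.0, 0.0, 1.0]))  # [0.5, -0.5, -0.5, 0.5]
print(group_advantages([0.0, 0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0, 0.0] -> no signal;
# such a group is skipped, which is what datums=0 reports (see "Common pitfalls")
```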
Expected output
- It first prints `Model: ...`, then `Step N: accuracy=..., datums=...` at each step, and finally `Saved: ...`.
Common pitfalls
- Grading runs `exec`/`eval` on model output (this is a toy demo): do not run it against untrusted models or inputs; if you extend it, do so in a sandbox (see the sketch after this list).
- `datums=0` means every sample in a group received the same reward (all fail or all pass), so the group was skipped; increase `MINT_RL_GROUP`/`MINT_RL_BATCH`, raise `MINT_RL_TEMPERATURE`, or make the problems easier.
- If the model does not answer in a code block (or the output is truncated), `_extract_code()` finds no code and the reward is 0; adjust the few-shot examples/prompt, or increase `MINT_RL_MAX_TOKENS`.
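
If you do extend the grading, one cheap hardening step is to run the generated code in a child interpreter with a timeout instead of in-process `exec`/`eval`. The helper below is a hypothetical sketch, not part of the demo, and a subprocess is still not a real security boundary (use containers or seccomp for genuinely untrusted code):

```python
import subprocess
import sys

def grade_in_subprocess(code: str, tests: list[tuple[str, object]], timeout: float = 2.0) -> float:
    """Hypothetical variant of compute_reward(): same 1.0/0.0 contract, but the
    untrusted code runs in a separate Python process with a wall-clock timeout."""
    checks = "\n".join(f"assert ({expr}) == {expected!r}" for expr, expected in tests)
    try:
        proc = subprocess.run(
            [sys.executable, "-c", code + "\n" + checks],
            capture_output=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return 0.0
    return 1.0 if proc.returncode == 0 else 0.0
```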
Prerequisites
- Python >= 3.11
- Set `MINT_API_KEY` (or configure it via `.env`)
How to run
```bash
export MINT_API_KEY=sk-...
python demos/rl/adapters/environment_tooluse.py
```
Parameters (environment variables)
- `MINT_BASE_MODEL`: default `Qwen/Qwen3-0.6B`
- `MINT_LORA_RANK`: default `16`
- `MINT_RL_STEPS`: default `10`
- `MINT_RL_BATCH`: default `8`
- `MINT_RL_GROUP`: default `4`
- `MINT_RL_LR`: default `1e-4`
- `MINT_RL_MAX_TOKENS`: default `256`
- `MINT_RL_TEMPERATURE`: default `1.0`
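
These are read from the process environment, so one way to try the `datums=0` tips above without touching your shell profile is a small launcher (hypothetical snippet, not part of the repo):

```python
import os
import subprocess

# Hypothetical launcher: override a few knobs for one run, per the
# "datums=0" tips above, then start the demo with that environment.
env = dict(os.environ, MINT_RL_GROUP="8", MINT_RL_TEMPERATURE="1.2")
subprocess.run(["python", "demos/rl/adapters/environment_tooluse.py"], env=env, check=True)
```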
Full script
````python
#!/usr/bin/env python3
"""RL-3 Environment Tool Use — adapter for rl_core.
Reward: execution-based grading (generated code passes test cases = 1.0).
Run: python demos/rl/adapters/environment_tooluse.py
"""
from __future__ import annotations

import os
import random
import re
import sys
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402

FEWSHOT = """Q: Write a function `double(x)` that returns x * 2.
A: ```python
def double(x):
    return x * 2
```
"""

PROBLEMS = [
    {"q": "Write `add(a, b)` that returns a + b.", "tests": [("add(1,2)", 3), ("add(-1,1)", 0)]},
    {"q": "Write `square(x)` that returns x squared.", "tests": [("square(3)", 9), ("square(-2)", 4)]},
    {"q": "Write `max2(a, b)` that returns the larger.", "tests": [("max2(3,5)", 5), ("max2(7,2)", 7)]},
    {"q": "Write `is_even(n)` returning True if even.", "tests": [("is_even(4)", True), ("is_even(7)", False)]},
    {"q": "Write `abs_val(x)` returning absolute value.", "tests": [("abs_val(-5)", 5), ("abs_val(3)", 3)]},
]


def _extract_code(response: str) -> str | None:
    match = re.findall(r"```(?:\w+)?\n(.*?)```", response, re.DOTALL)
    if match:
        return match[-1].strip()
    if "def " in response:
        return response[response.find("def "):].strip()
    return None


class EnvironmentToolUseAdapter(RLAdapter):
    def build_dataset(self) -> list[dict]:
        return [random.choice(PROBLEMS) for _ in range(50)]

    def make_prompt(self, sample: dict, tokenizer: Any) -> list[int]:
        return tokenizer.encode(FEWSHOT + f"Q: {sample['q']}\nA:")

    def compute_reward(self, response: str, sample: dict) -> float:
        code = _extract_code(response)
        if not code:
            return 0.0
        try:
            ns: dict[str, Any] = {}
            exec(code, ns)  # noqa: S102
            for expr, expected in sample["tests"]:
                if eval(expr, ns) != expected:  # noqa: S307
                    return 0.0
            return 1.0
        except Exception:
            return 0.0

    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        accuracy = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0
        print(f"Step {step}: accuracy={accuracy:.1%}, datums={num_datums}")


if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "10"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "8"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "256"))
    run_grpo(EnvironmentToolUseAdapter(), cfg)
````
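
A few quick checks of the extraction logic (these follow directly from the regex and the `def ` fallback in `_extract_code()` above):

```python
_extract_code("A: ```python\ndef add(a, b):\n    return a + b\n```")
# -> 'def add(a, b):\n    return a + b'  (last fenced block wins)

_extract_code("Sure! def square(x):\n    return x * x")
# -> 'def square(x):\n    return x * x'  (no fence: falls back to the first `def `)

_extract_code("I am not sure.")
# -> None, so compute_reward() returns 0.0 (the truncation pitfall above)
```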
Next steps
- At the end the run prints `Saved: <path>`; use that path as `model_path` to create a sampling client:
```python
import mint

service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste the Saved path>")
```
- For notes on sampling and tokenizers, see /using-the-api/saving-and-loading and /api-reference/sampling-client.
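
If you want to smoke-test the saved weights right away, the sketch below is one plausible shape for it; `get_tokenizer` and `sample` (and their parameters) are assumptions here, not confirmed API, so check the two pages above for the real signatures:

```python
# Hypothetical smoke test -- the method names `get_tokenizer` / `sample` and
# their parameters are assumptions; see /api-reference/sampling-client.
tokenizer = sampling_client.get_tokenizer()
prompt_tokens = tokenizer.encode("Q: Write `neg(x)` that returns -x.\nA:")
out = sampling_client.sample(prompt=prompt_tokens, max_tokens=128, temperature=0.0)
print(out)
```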