# Code RL

This page documents `demos/rl/adapters/environment_tooluse.py` in mint-quickstart.
## What this demo does

- Runs RL-3 Environment Tool Use via the shared `demos/rl/rl_core.py` GRPO loop and a task-specific adapter.
- Trains a LoRA to emit small Python functions; reward is execution-based grading (`1.0` if tests pass, else `0.0`).
- Uses few-shot prompts plus grouped sampling to compute group-relative advantages and optimize with importance sampling.
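As a rough sketch of what "group-relative advantages" means here (illustrative only; the real computation lives in `demos/rl/rl_core.py` and may normalize or clip differently): each prompt is sampled several times, and each sample is scored against the mean reward of its own group.

```python
# Minimal sketch of GRPO-style group-relative advantages (hypothetical helper,
# not the rl_core implementation).

def group_advantages(rewards: list[float]) -> list[float]:
    """Advantage of each sample = its reward minus the group's mean reward."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

# One prompt, a group of 4 completions, execution-graded 0/1:
print(group_advantages([1.0, 0.0, 0.0, 1.0]))  # [0.5, -0.5, -0.5, 0.5]

# If every sample in a group gets the same reward, all advantages are zero,
# so the group contributes no gradient signal:
print(group_advantages([0.0, 0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0, 0.0]
```

Samples that beat their group's average are reinforced; ones below it are pushed down, with no learned value baseline needed.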
## Expected output

- Prints `Model: ...` and per-step `Step N: accuracy=..., datums=...`; finishes with `Saved: ...`.
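The per-step line is produced by the adapter's `evaluate()` hook in the full script; a standalone rendition of the same format string, with illustrative reward values:

```python
# Reproduces the demo's per-step log line (format taken from evaluate()).
step = 3
rewards = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]  # example 0/1 execution grades
num_datums = 16  # hypothetical count of datums kept this step

accuracy = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0
print(f"Step {step}: accuracy={accuracy:.1%}, datums={num_datums}")
# -> Step 3: accuracy=50.0%, datums=16
```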
## Common gotchas

- Grading runs generated code via `exec`/`eval` (toy demo): do not run it on untrusted models/inputs; use a sandbox if you extend it.
- `datums=0` means every sample in a group got identical reward (all fail or all pass), so group-relative advantages are zero and the step yields no training data; increase `MINT_RL_GROUP`, `MINT_RL_TEMPERATURE`, or `MINT_RL_BATCH`, or simplify tasks.
- If the model doesn't output a code fence or a `def ...` block (or gets truncated), `_extract_code()` returns `None`; tweak the few-shot prompt or raise `MINT_RL_MAX_TOKENS`.
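If you extend the demo beyond toy tasks, one minimal step up from in-process `exec`/`eval` is grading each candidate in a fresh Python subprocess with a timeout. A sketch (hypothetical helper, and not a real sandbox: it stops crashes and infinite loops from taking down training, but does not restrict filesystem or network access; use a container or jail for untrusted models):

```python
import subprocess
import sys


def grade_in_subprocess(code: str, expr: str, expected: object, timeout: float = 5.0) -> float:
    """Run candidate code plus one test assertion in a separate interpreter.

    Returns 1.0 if the assertion passes, 0.0 on failure, crash, or timeout.
    """
    program = f"{code}\nassert {expr} == {expected!r}\n"
    try:
        result = subprocess.run(
            [sys.executable, "-c", program],
            capture_output=True,  # keep the candidate's stdout/stderr out of our logs
            timeout=timeout,
        )
        return 1.0 if result.returncode == 0 else 0.0
    except subprocess.TimeoutExpired:
        return 0.0


print(grade_in_subprocess("def add(a, b):\n    return a + b", "add(1,2)", 3))  # 1.0
print(grade_in_subprocess("def add(a, b):\n    return a - b", "add(1,2)", 3))  # 0.0
```

Spawning an interpreter per test is slow; for larger runs you would batch tests per candidate or pool workers.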
## Prerequisites

- Python >= 3.11
- `MINT_API_KEY` set (or configured via `.env`)
## How to run

```shell
export MINT_API_KEY=sk-mint-...
python demos/rl/adapters/environment_tooluse.py
```

## Parameters (env vars)

| Variable | Default |
| --- | --- |
| `MINT_BASE_MODEL` | `Qwen/Qwen3-0.6B` |
| `MINT_LORA_RANK` | `16` |
| `MINT_RL_STEPS` | `10` |
| `MINT_RL_BATCH` | `8` |
| `MINT_RL_GROUP` | `4` |
| `MINT_RL_LR` | `1e-4` |
| `MINT_RL_MAX_TOKENS` | `256` |
| `MINT_RL_TEMPERATURE` | `1.0` |
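Since every parameter is a plain environment variable, a quick experiment only needs overrides before the run (illustrative values, not recommendations):

```shell
# Shorter run with larger groups and hotter sampling; arbitrary example values.
export MINT_RL_STEPS=5
export MINT_RL_GROUP=8
export MINT_RL_TEMPERATURE=1.2
python demos/rl/adapters/environment_tooluse.py
```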
## Full script

````python
#!/usr/bin/env python3
"""RL-3 Environment Tool Use — adapter for rl_core.

Reward: execution-based grading (generated code passes test cases = 1.0).
Run: python demos/rl/adapters/environment_tooluse.py
"""
from __future__ import annotations

import os
import random
import re
import sys
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402

FEWSHOT = """Q: Write a function `double(x)` that returns x * 2.
A: ```python
def double(x):
    return x * 2
```
"""

PROBLEMS = [
    {"q": "Write `add(a, b)` that returns a + b.", "tests": [("add(1,2)", 3), ("add(-1,1)", 0)]},
    {"q": "Write `square(x)` that returns x squared.", "tests": [("square(3)", 9), ("square(-2)", 4)]},
    {"q": "Write `max2(a, b)` that returns the larger.", "tests": [("max2(3,5)", 5), ("max2(7,2)", 7)]},
    {"q": "Write `is_even(n)` returning True if even.", "tests": [("is_even(4)", True), ("is_even(7)", False)]},
    {"q": "Write `abs_val(x)` returning absolute value.", "tests": [("abs_val(-5)", 5), ("abs_val(3)", 3)]},
]


def _extract_code(response: str) -> str | None:
    match = re.findall(r"```(?:\w+)?\n(.*?)```", response, re.DOTALL)
    if match:
        return match[-1].strip()
    if "def " in response:
        return response[response.find("def "):].strip()
    return None


class EnvironmentToolUseAdapter(RLAdapter):
    def build_dataset(self) -> list[dict]:
        return [random.choice(PROBLEMS) for _ in range(50)]

    def make_prompt(self, sample: dict, tokenizer: Any) -> list[int]:
        return tokenizer.encode(FEWSHOT + f"Q: {sample['q']}\nA:")

    def compute_reward(self, response: str, sample: dict) -> float:
        code = _extract_code(response)
        if not code:
            return 0.0
        try:
            ns: dict[str, Any] = {}
            exec(code, ns)  # noqa: S102
            for expr, expected in sample["tests"]:
                if eval(expr, ns) != expected:  # noqa: S307
                    return 0.0
            return 1.0
        except Exception:
            return 0.0

    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        accuracy = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0
        print(f"Step {step}: accuracy={accuracy:.1%}, datums={num_datums}")


if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "10"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "8"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "256"))
    run_grpo(EnvironmentToolUseAdapter(), cfg)
````

## Next steps
- The final line prints `Saved: <path>`. You can load it in a new process:

```python
import mint

service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste Saved path>")
```

- For sampling + tokenization details, see `/using-the-api/saving-and-loading` and `/api-reference/sampling-client`.