Mini RL 训练(最小可运行版本)
这一页展示一个最小可运行的 RL 训练循环:
- 任务/环境:乘法(Question → Answer,范围 10-199)
- 奖励:答对为 1.0,否则 0.0
- Rollout:每题采样多条答案(group_size)
- 更新:构造 RL Datum,然后调用 forward_backward(..., loss_fn="importance_sampling") + optim_step()
前置条件
- Python >= 3.11
- 在环境变量 MINT_API_KEY 中设置 MinT API key
安装(选择一种方式安装 MinT):
pip install -q python-dotenv
# 推荐(GitHub):
pip install -q git+https://github.com/MindLab-Research/mindlab-toolkit.git
# 或者如果已经在 pip 上:
pip install -q mindlab-toolkit

参数(环境变量)
| 环境变量 | 默认值 | API 类型 | 说明 |
|---|---|---|---|
| MINT_BASE_MODEL | Qwen/Qwen3-0.6B | - | 基础模型名称 |
| MINT_LORA_RANK | 8 | LoraConfig | LoRA rank(最大 64) |
| MINT_RL_STEPS | 3 | - | 外层训练步数 |
| MINT_RL_BATCH | 2 | - | 每步题目数量 |
| MINT_RL_GROUP | 4 | - | 每题采样数量 |
| MINT_RL_MAX_TOKENS | 64 | SamplingParams | 最大生成 token 数 |
| MINT_RL_TEMPERATURE | 0.8 | SamplingParams | 采样随机性 |
| MINT_RL_LR | 5e-5 | AdamParams | 学习率 |
| MINT_SYSTEM_PROMPT_PATH | - | - | system prompt 文件路径 |
| MINT_PRINT_TOKEN_COUNTS | - | - | 设为 1 打印 token 计数 |
完整脚本
from __future__ import annotations
import os
import random
import re
import time
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
import mint
from mint import types
# --- Configuration (overridable via environment variables; defaults match the table above) ---
_ENV = os.environ

BASE_MODEL = _ENV.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B")
LORA_RANK = int(_ENV.get("MINT_LORA_RANK", "8"))
RL_STEPS = int(_ENV.get("MINT_RL_STEPS", "3"))
BATCH_SIZE = int(_ENV.get("MINT_RL_BATCH", "2"))
GROUP_SIZE = int(_ENV.get("MINT_RL_GROUP", "4"))
MAX_TOKENS = int(_ENV.get("MINT_RL_MAX_TOKENS", "64"))
TEMPERATURE = float(_ENV.get("MINT_RL_TEMPERATURE", "0.8"))
LEARNING_RATE = float(_ENV.get("MINT_RL_LR", "5e-5"))
SYSTEM_PROMPT_PATH = _ENV.get("MINT_SYSTEM_PROMPT_PATH", "")
PRINT_TOKEN_COUNTS = _ENV.get("MINT_PRINT_TOKEN_COUNTS", "").lower() in {"1", "true"}

# Optional system prompt, loaded from a file only when the path is set and exists.
SYSTEM_PROMPT = ""
if SYSTEM_PROMPT_PATH:
    _prompt_path = Path(SYSTEM_PROMPT_PATH)
    if _prompt_path.exists():
        SYSTEM_PROMPT = _prompt_path.read_text().strip()
def extract_answer(response: str) -> str | None:
match = re.search(r"-?\d+", response)
return match.group(0) if match else None
def generate_problem() -> tuple[str, str]:
a, b = random.randint(10, 199), random.randint(10, 199)
return f"What is {a} * {b}?", str(a * b)
def compute_reward(response: str, correct: str) -> float:
    """Binary reward: 1.0 when the first integer in *response* equals *correct*, else 0.0."""
    extracted = extract_answer(response)
    if extracted == correct:
        return 1.0
    return 0.0
def build_prompt(question: str) -> str:
    """Format *question* as a Q/A prompt, prefixed by the optional system prompt."""
    base = f"Question: {question}\nAnswer:"
    return f"{SYSTEM_PROMPT}\n\n{base}" if SYSTEM_PROMPT else base
def validate_loss_inputs(inputs: dict) -> bool:
    """Check that the four per-token loss arrays exist, are non-empty, and share one length."""
    required = ("target_tokens", "weights", "logprobs", "advantages")
    try:
        lengths = {len(inputs[key]) for key in required}
    except KeyError:
        return False
    # Exactly one length, and that length is non-zero.
    return 0 not in lengths and len(lengths) == 1
def save_weights_with_retry(training_client, service_client, base_model: str, name: str, retries: int = 3):
    """Save trainer weights for sampling and build a sampling client.

    Retries with capped exponential backoff (1s, 2s, 4s, ... max 10s); after
    *retries* failed attempts, raises RuntimeError chained to the last error.
    """
    last_attempt = retries - 1
    for attempt in range(retries):
        try:
            saved = training_client.save_weights_for_sampler(name=name).result()
            return service_client.create_sampling_client(model_path=saved.path, base_model=base_model)
        except Exception as exc:
            if attempt == last_attempt:
                raise RuntimeError(f"Failed after {retries} attempts: {exc}") from exc
            time.sleep(min(1.0 * (2 ** attempt), 10.0))
def main() -> None:
    """Run the minimal RL loop: sample rollouts, compute group advantages, update LoRA weights."""
    service_client = mint.ServiceClient()
    training_client = service_client.create_lora_training_client(
        base_model=BASE_MODEL, rank=LORA_RANK, train_mlp=True, train_attn=True, train_unembed=True
    )
    tokenizer = training_client.get_tokenizer()
    print(f"Config: model={BASE_MODEL}, steps={RL_STEPS}, batch={BATCH_SIZE}, group={GROUP_SIZE}, lr={LEARNING_RATE}")
    for step in range(RL_STEPS):
        # Refresh the sampling client each step so rollouts use the latest trained weights.
        sampling_client = save_weights_with_retry(training_client, service_client, BASE_MODEL, f"rl-step-{step}")
        training_datums: list[types.Datum] = []
        all_rewards: list[float] = []
        for _ in range(BATCH_SIZE):
            question, answer = generate_problem()
            prompt = build_prompt(question)
            prompt_tokens = tokenizer.encode(prompt)
            if PRINT_TOKEN_COUNTS:
                print(f"[step {step+1}] prompt_tokens={len(prompt_tokens)}")
            # Sample GROUP_SIZE completions for the same question in one request.
            result = sampling_client.sample(
                prompt=types.ModelInput.from_ints(tokens=prompt_tokens),
                num_samples=GROUP_SIZE,
                sampling_params=types.SamplingParams(
                    max_tokens=MAX_TOKENS, temperature=TEMPERATURE, stop_token_ids=[tokenizer.eos_token_id]
                ),
            ).result()
            group_rewards, group_responses, group_logprobs = [], [], []
            for seq in result.sequences:
                response_text = tokenizer.decode(seq.tokens)
                reward = compute_reward(response_text, answer)
                group_rewards.append(reward)
                group_responses.append(list(seq.tokens))
                # Fall back to zeros when the sampler returns no logprobs.
                group_logprobs.append(list(seq.logprobs or [0.0] * len(seq.tokens)))
            all_rewards.extend(group_rewards)
            # Group-relative advantage: each reward minus the group's mean reward.
            mean_reward = sum(group_rewards) / len(group_rewards)
            advantages = [r - mean_reward for r in group_rewards]
            if all(adv == 0.0 for adv in advantages):
                # All samples scored the same -> no learning signal; skip this question.
                continue
            for response_tokens, logprobs, adv in zip(group_responses, group_logprobs, advantages):
                if not response_tokens:
                    continue
                # Next-token alignment: inputs are tokens[:-1], targets are tokens[1:].
                full_tokens = prompt_tokens + response_tokens
                input_tokens, target_tokens = full_tokens[:-1], full_tokens[1:]
                prefix_len = len(prompt_tokens) - 1
                # Prompt-prefix positions get zero weight/advantage so only the
                # response tokens contribute to the importance-sampling loss.
                loss_inputs = {
                    "target_tokens": [int(t) for t in target_tokens],
                    "weights": [0.0] * prefix_len + [1.0] * len(response_tokens),
                    "logprobs": [0.0] * prefix_len + [float(lp) for lp in logprobs],
                    "advantages": [0.0] * prefix_len + [adv] * len(response_tokens),
                }
                if validate_loss_inputs(loss_inputs):
                    training_datums.append(
                        types.Datum(model_input=types.ModelInput.from_ints(tokens=input_tokens), loss_fn_inputs=loss_inputs)
                    )
        if training_datums:
            training_client.forward_backward(training_datums, loss_fn="importance_sampling").result()
            training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()
        avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0.0
        print(f"Step {step+1}/{RL_STEPS}: accuracy={accuracy:.1%}, avg_reward={avg_reward:.3f}, datums={len(training_datums)}")
        # Save a checkpoint every step (recommended for long-running training).
        if training_datums:
            step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
            print(f" Checkpoint: {step_ckpt.path}")
    checkpoint = training_client.save_state(name="mini-rl-final").result()
    print(f"Saved checkpoint: {checkpoint.path}")
if __name__ == "__main__":
    main()

Checkpoint 与续训
长时间训练(尤其是 235B 等大模型)建议每步保存 checkpoint,以便从失败点恢复:
# 每步保存 checkpoint
step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
print(f"Checkpoint: {step_ckpt.path}")
# 从 checkpoint 续训,并保留 optimizer state:
training_client = service_client.create_lora_training_client(
base_model=BASE_MODEL,
rank=LORA_RANK,
train_mlp=True,
train_attn=True,
train_unembed=True,
)
training_client.load_state_with_optimizer(checkpoint_path).result()

常见坑
datums=0
如果 datums=0,意味着同一题的多条答案 reward 全一样,advantage 全为 0,代码会跳过更新。要获得学习信号,可以增加多样性:
- 增大 MINT_RL_GROUP
- 增大 MINT_RL_TEMPERATURE
- 使用更细粒度的 reward(例如 partial credit)
save_weights_for_sampler 失败
脚本包含指数退避重试逻辑。如果 3 次尝试后仍然失败,检查:
- 服务器连接
- MinT 服务器磁盘空间
下一步:从 toy loop 到真实任务
当你准备把最小脚本扩展到更贴近真实任务的训练闭环,可以参考:RL 最佳实践