Mind Lab Toolkit (MinT)

Mini RL Training (minimal runnable version)

This page walks through a minimal runnable RL training loop:

  • Task/environment: multiplication (Question → Answer, operands in the 10-199 range)
  • Reward: 1.0 for a correct answer, 0.0 otherwise
  • Rollout: sample multiple answers per question (group_size)
  • Update: build an RL Datum per sample, then call forward_backward(..., loss_fn="importance_sampling") + optim_step(); advantages use a group-mean baseline (see the worked example below)
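
The group-mean baseline is the heart of the update: within each group of sampled answers, an answer's advantage is its reward minus the group's mean reward. A worked example with illustrative rewards, showing the same arithmetic the script performs per question:

# group_size = 4 samples for one question; reward values are illustrative
group_rewards = [1.0, 0.0, 0.0, 1.0]
mean_reward = sum(group_rewards) / len(group_rewards)    # 0.5
advantages = [r - mean_reward for r in group_rewards]    # [0.5, -0.5, -0.5, 0.5]
# If every sample gets the same reward, all advantages are 0.0
# and the script skips the update for that question.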

Prerequisites

  • Python >= 3.11
  • A MinT API key set in the MINT_API_KEY environment variable (a quick check is sketched below)
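
To fail fast when the key is missing, you can add a quick check at the top of your script (a sketch; the full script below instead relies on load_dotenv() to populate the environment):

import os

# Assumes MINT_API_KEY was exported in your shell or placed in a .env file.
assert os.environ.get("MINT_API_KEY"), "Set MINT_API_KEY before running"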

Installation (pick one of the following ways to install MinT):

pip install -q python-dotenv
# Recommended (GitHub):
pip install -q git+https://github.com/MindLab-Research/mindlab-toolkit.git
# Or, if it is already on pip:
pip install -q mindlab-toolkit
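
A quick import check confirms the install worked (assumes the package exposes the mint module, as the full script below does):

import mint
from mint import types  # used throughout the script below

print("MinT import OK")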

Parameters (environment variables)

| Environment variable | Default | API type | Description |
| --- | --- | --- | --- |
| MINT_BASE_MODEL | Qwen/Qwen3-0.6B | - | Base model name |
| MINT_LORA_RANK | 8 | LoraConfig | LoRA rank (max 64) |
| MINT_RL_STEPS | 3 | - | Number of outer training steps |
| MINT_RL_BATCH | 2 | - | Questions per step |
| MINT_RL_GROUP | 4 | - | Samples per question |
| MINT_RL_MAX_TOKENS | 64 | SamplingParams | Maximum generated tokens |
| MINT_RL_TEMPERATURE | 0.8 | SamplingParams | Sampling randomness |
| MINT_RL_LR | 5e-5 | AdamParams | Learning rate |
| MINT_SYSTEM_PROMPT_PATH | - | - | Path to a system prompt file |
| MINT_PRINT_TOKEN_COUNTS | - | - | Set to 1 to print token counts |
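
With these defaults, each training step requests MINT_RL_BATCH × MINT_RL_GROUP completions; a quick sanity check of the per-step sampling budget:

# Defaults: 2 questions/step, 4 samples/question, 64 max tokens/sample
per_step_samples = 2 * 4                          # BATCH_SIZE * GROUP_SIZE = 8 completions
max_new_tokens_per_step = per_step_samples * 64   # at most 512 generated tokens per step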

Full script

from __future__ import annotations

import os
import random
import re
import time
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

import mint
from mint import types

# Config via env vars
BASE_MODEL = os.environ.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B")
LORA_RANK = int(os.environ.get("MINT_LORA_RANK", "8"))
RL_STEPS = int(os.environ.get("MINT_RL_STEPS", "3"))
BATCH_SIZE = int(os.environ.get("MINT_RL_BATCH", "2"))
GROUP_SIZE = int(os.environ.get("MINT_RL_GROUP", "4"))
MAX_TOKENS = int(os.environ.get("MINT_RL_MAX_TOKENS", "64"))
TEMPERATURE = float(os.environ.get("MINT_RL_TEMPERATURE", "0.8"))
LEARNING_RATE = float(os.environ.get("MINT_RL_LR", "5e-5"))
SYSTEM_PROMPT_PATH = os.environ.get("MINT_SYSTEM_PROMPT_PATH", "")
PRINT_TOKEN_COUNTS = os.environ.get("MINT_PRINT_TOKEN_COUNTS", "").lower() in ("1", "true")

# Load system prompt if provided
SYSTEM_PROMPT = ""
if SYSTEM_PROMPT_PATH and Path(SYSTEM_PROMPT_PATH).exists():
    SYSTEM_PROMPT = Path(SYSTEM_PROMPT_PATH).read_text().strip()


def extract_answer(response: str) -> str | None:
    match = re.search(r"-?\d+", response)
    return match.group(0) if match else None


def generate_problem() -> tuple[str, str]:
    a, b = random.randint(10, 199), random.randint(10, 199)
    return f"What is {a} * {b}?", str(a * b)


def compute_reward(response: str, correct: str) -> float:
    return 1.0 if extract_answer(response) == correct else 0.0


def build_prompt(question: str) -> str:
    if SYSTEM_PROMPT:
        return f"{SYSTEM_PROMPT}\n\nQuestion: {question}\nAnswer:"
    return f"Question: {question}\nAnswer:"


def validate_loss_inputs(inputs: dict) -> bool:
    required = ["target_tokens", "weights", "logprobs", "advantages"]
    if not all(k in inputs and len(inputs[k]) > 0 for k in required):
        return False
    lengths = [len(inputs[k]) for k in required]
    return len(set(lengths)) == 1


def save_weights_with_retry(training_client, service_client, base_model: str, name: str, retries: int = 3):
    for attempt in range(retries):
        try:
            path = training_client.save_weights_for_sampler(name=name).result().path
            return service_client.create_sampling_client(model_path=path, base_model=base_model)
        except Exception as e:
            if attempt < retries - 1:
                time.sleep(min(1.0 * (2 ** attempt), 10.0))
            else:
                raise RuntimeError(f"Failed after {retries} attempts: {e}") from e


def main() -> None:
    service_client = mint.ServiceClient()
    training_client = service_client.create_lora_training_client(
        base_model=BASE_MODEL, rank=LORA_RANK, train_mlp=True, train_attn=True, train_unembed=True
    )
    tokenizer = training_client.get_tokenizer()

    print(f"Config: model={BASE_MODEL}, steps={RL_STEPS}, batch={BATCH_SIZE}, group={GROUP_SIZE}, lr={LEARNING_RATE}")

    for step in range(RL_STEPS):
        sampling_client = save_weights_with_retry(training_client, service_client, BASE_MODEL, f"rl-step-{step}")
        training_datums: list[types.Datum] = []
        all_rewards: list[float] = []

        for _ in range(BATCH_SIZE):
            question, answer = generate_problem()
            prompt = build_prompt(question)
            prompt_tokens = tokenizer.encode(prompt)

            if PRINT_TOKEN_COUNTS:
                print(f"[step {step+1}] prompt_tokens={len(prompt_tokens)}")

            result = sampling_client.sample(
                prompt=types.ModelInput.from_ints(tokens=prompt_tokens),
                num_samples=GROUP_SIZE,
                sampling_params=types.SamplingParams(
                    max_tokens=MAX_TOKENS, temperature=TEMPERATURE, stop_token_ids=[tokenizer.eos_token_id]
                ),
            ).result()

            group_rewards, group_responses, group_logprobs = [], [], []
            for seq in result.sequences:
                response_text = tokenizer.decode(seq.tokens)
                reward = compute_reward(response_text, answer)
                group_rewards.append(reward)
                group_responses.append(list(seq.tokens))
                group_logprobs.append(list(seq.logprobs or [0.0] * len(seq.tokens)))

            all_rewards.extend(group_rewards)
            mean_reward = sum(group_rewards) / len(group_rewards)
            advantages = [r - mean_reward for r in group_rewards]

            if all(adv == 0.0 for adv in advantages):
                continue

            for response_tokens, logprobs, adv in zip(group_responses, group_logprobs, advantages):
                if not response_tokens:
                    continue
                full_tokens = prompt_tokens + response_tokens
                input_tokens, target_tokens = full_tokens[:-1], full_tokens[1:]
                prefix_len = len(prompt_tokens) - 1

                loss_inputs = {
                    "target_tokens": [int(t) for t in target_tokens],
                    "weights": [0.0] * prefix_len + [1.0] * len(response_tokens),
                    "logprobs": [0.0] * prefix_len + [float(lp) for lp in logprobs],
                    "advantages": [0.0] * prefix_len + [adv] * len(response_tokens),
                }

                if validate_loss_inputs(loss_inputs):
                    training_datums.append(
                        types.Datum(model_input=types.ModelInput.from_ints(tokens=input_tokens), loss_fn_inputs=loss_inputs)
                    )

        if training_datums:
            training_client.forward_backward(training_datums, loss_fn="importance_sampling").result()
            training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()

        avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0.0
        print(f"Step {step+1}/{RL_STEPS}: accuracy={accuracy:.1%}, avg_reward={avg_reward:.3f}, datums={len(training_datums)}")

        # Save a checkpoint every step (recommended for long runs)
        if training_datums:
            step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
            print(f"  Checkpoint: {step_ckpt.path}")

    checkpoint = training_client.save_state(name="mini-rl-final").result()
    print(f"Saved checkpoint: {checkpoint.path}")


if __name__ == "__main__":
    main()
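
Once training finishes, you can sanity-check the adapter by exporting sampler weights and grading fresh problems. A minimal sketch reusing only the calls and helpers from the script above (temperature 0.0 for greedy decoding is an assumption, not a MinT requirement):

def evaluate(training_client, service_client, tokenizer, n: int = 10) -> float:
    """Accuracy on n fresh multiplication problems, one greedy sample each."""
    path = training_client.save_weights_for_sampler(name="mini-rl-eval").result().path
    client = service_client.create_sampling_client(model_path=path, base_model=BASE_MODEL)
    correct = 0
    for _ in range(n):
        question, answer = generate_problem()
        tokens = tokenizer.encode(build_prompt(question))
        result = client.sample(
            prompt=types.ModelInput.from_ints(tokens=tokens),
            num_samples=1,
            sampling_params=types.SamplingParams(
                max_tokens=MAX_TOKENS, temperature=0.0, stop_token_ids=[tokenizer.eos_token_id]
            ),
        ).result()
        response = tokenizer.decode(result.sequences[0].tokens)
        if extract_answer(response) == answer:
            correct += 1
    return correct / n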

Checkpoints and resuming

For long runs (especially with large models such as 235B), it is recommended to save a checkpoint every step so training can resume from the point of failure:

# Save a checkpoint every step
step_ckpt = training_client.save_state(name=f"rl-step-{step}").result()
print(f"Checkpoint: {step_ckpt.path}")

# Resume from a checkpoint, preserving optimizer state:
training_client = service_client.create_lora_training_client(
    base_model=BASE_MODEL,
    rank=LORA_RANK,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)
# checkpoint_path is the .path returned by an earlier save_state(...) call
training_client.load_state_with_optimizer(checkpoint_path).result()

Common pitfalls

datums=0

If datums=0, every sampled answer for a question received the same reward, so all advantages are 0 and the code skips the update. To get a learning signal, increase diversity:

  • Increase MINT_RL_GROUP
  • Increase MINT_RL_TEMPERATURE
  • Use a finer-grained reward (e.g., partial credit; see the sketch below)
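
A partial-credit reward keeps some gradient signal even when no sample is exactly right. One possible sketch, reusing extract_answer from the script (the closeness-based scoring is an illustrative heuristic, not part of MinT):

def compute_partial_reward(response: str, correct: str) -> float:
    """1.0 for an exact match, partial credit for a numerically close answer."""
    extracted = extract_answer(response)
    if extracted is None:
        return 0.0
    if extracted == correct:
        return 1.0
    # Scale partial credit by relative error (illustrative choice).
    rel_err = abs(int(extracted) - int(correct)) / max(abs(int(correct)), 1)
    return max(0.0, 0.5 * (1.0 - rel_err))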

save_weights_for_sampler failures

The script includes retry logic with exponential backoff. If it still fails after 3 attempts, check:

  • Connectivity to the server
  • Disk space on the MinT server

Next step: from toy loop to real tasks

When you are ready to extend this minimal script into a training loop for more realistic tasks, see: RL Best Practices
