Mind Lab Toolkit (MinT)

Chat RL

This page corresponds to demos/rl/adapters/preference_chat.py in mint-quickstart.

What this demo does

  • This is RL-2 Preference Chat: training runs through the shared demos/rl/rl_core.py loop plus a task adapter.
  • It runs RL on general chat prompts, scoring responses with a preference-proxy reward (length / structure / lexical diversity).
  • For each prompt it samples MINT_RL_GROUP responses, computes group-relative advantages, and updates the LoRA via importance_sampling.

Expected output

  • It first prints Model: ..., then each step prints Step N: avg_reward=..., datums=..., and finally it prints Saved: ...

Common pitfalls

  • The reward is heuristic and easy to game; for production, replace compute_reward() with a preference model or a rule system.
  • If the reward curve is flat (or datums=0), tune MINT_RL_TEMPERATURE / MINT_RL_GROUP / MINT_RL_BATCH and check your prompt diversity.
  • Prompt construction prefers tokenizer.apply_chat_template and otherwise falls back to "User: ...\nAssistant:"; if your model is template-sensitive, adjust make_prompt().
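Swapping out the heuristic is a one-function change: compute_reward() only has to map (response, sample) to a float. A toy rule-based sketch is below; the keywords and weights are illustrative, not part of the demo, and a production system would call a preference model here instead:

```python
def keyword_reward(
    response: str,
    sample: str,
    keywords: tuple[str, ...] = ("because", "therefore"),
) -> float:
    """Toy rule system: reward non-empty responses, with a bonus for
    each reasoning keyword they contain, capped at 1.0."""
    if not response.strip():
        return 0.0
    hits = sum(1 for kw in keywords if kw in response.lower())
    return min(0.5 + 0.25 * hits, 1.0)

print(keyword_reward("The sky is blue because of Rayleigh scattering.", ""))  # 0.75
```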

Prerequisites

  • Python >= 3.11
  • Set MINT_API_KEY (or configure it via .env)

How to run

export MINT_API_KEY=sk-...
python demos/rl/adapters/preference_chat.py

Parameters (environment variables)

  • MINT_BASE_MODEL: default Qwen/Qwen3-0.6B
  • MINT_LORA_RANK: default 16
  • MINT_RL_STEPS: default 10
  • MINT_RL_BATCH: default 8
  • MINT_RL_GROUP: default 4
  • MINT_RL_LR: default 1e-4
  • MINT_RL_MAX_TOKENS: default 128
  • MINT_RL_TEMPERATURE: default 1.0
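Each variable is read with os.environ.get and a string default, then cast; the script's __main__ block does exactly this for steps, batch, and max tokens. A standalone sketch of the same pattern, using the documented defaults:

```python
import os

# Ensure a clean slate for this demo, then read with the documented default.
os.environ.pop("MINT_RL_GROUP", None)
group = int(os.environ.get("MINT_RL_GROUP", "4"))   # unset -> falls back to 4
lr = float(os.environ.get("MINT_RL_LR", "1e-4"))    # -> 0.0001

# An `export MINT_RL_GROUP=8` in the shell would surface here as an override.
os.environ["MINT_RL_GROUP"] = "8"
overridden = int(os.environ.get("MINT_RL_GROUP", "4"))  # -> 8
print(group, overridden, lr)
```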

Full script

#!/usr/bin/env python3
"""RL-2 Preference Chat — adapter for rl_core.

Reward: proxy helpfulness score (length + structure + diversity).
Run:  python demos/rl/adapters/preference_chat.py
"""

from __future__ import annotations

import os
import random
import sys
from collections.abc import Mapping
from pathlib import Path
from typing import Any

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402

PROMPTS = [
    "Explain what a variable is in programming.",
    "Write a short poem about the ocean.",
    "What are three benefits of exercise?",
    "Describe how to make a cup of tea.",
    "Why is the sky blue?",
    "Give tips for better sleep.",
    "What is machine learning?",
    "How do plants make food?",
]


def _coerce_chat_template_tokens(tokenized: Any) -> list[int]:
    if isinstance(tokenized, Mapping):
        if "input_ids" not in tokenized:
            raise TypeError("chat template output is missing `input_ids`")
        tokenized = tokenized["input_ids"]
    elif hasattr(tokenized, "input_ids"):
        tokenized = getattr(tokenized, "input_ids")
    if hasattr(tokenized, "tolist"):
        tokenized = tokenized.tolist()
    if isinstance(tokenized, tuple):
        tokenized = list(tokenized)
    if not isinstance(tokenized, list):
        raise TypeError(
            "chat template output must be a token list or mapping with `input_ids`"
        )
    if tokenized and isinstance(tokenized[0], list):
        tokenized = tokenized[0]
    return [int(token) for token in tokenized]


class PreferenceChatAdapter(RLAdapter):

    def build_dataset(self) -> list[str]:
        return [random.choice(PROMPTS) for _ in range(50)]

    def make_prompt(self, sample: str, tokenizer: Any) -> list[int]:
        messages = [{"role": "user", "content": sample}]
        if hasattr(tokenizer, "apply_chat_template"):
            return _coerce_chat_template_tokens(
                tokenizer.apply_chat_template(
                    messages, tokenize=True, add_generation_prompt=True
                )
            )
        return tokenizer.encode(f"User: {sample}\nAssistant:")

    def compute_reward(self, response: str, sample: str) -> float:
        r = 0.0
        words = len(response.split())
        if 20 <= words <= 100:
            r += 0.4
        elif 10 <= words < 20 or 100 < words <= 150:
            r += 0.2
        if response.count(".") >= 2:
            r += 0.3
        unique_words = len(set(response.lower().split()))
        if words > 0 and unique_words > words * 0.5:
            r += 0.3
        return min(r, 1.0)

    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        avg = sum(rewards) / len(rewards) if rewards else 0.0
        print(f"Step {step}: avg_reward={avg:.3f}, datums={num_datums}")


if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "10"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "8"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "128"))
    run_grpo(PreferenceChatAdapter(), cfg)
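To see how the three reward components combine, here is the same scoring logic as compute_reward() applied standalone to an example response:

```python
response = (
    "Exercise improves mood. It strengthens the heart. "
    "It also helps you sleep better at night."
)
r = 0.0
words = len(response.split())          # 15 words -> partial length credit (0.2)
if 20 <= words <= 100:
    r += 0.4
elif 10 <= words < 20 or 100 < words <= 150:
    r += 0.2
if response.count(".") >= 2:           # three sentences -> structure credit (0.3)
    r += 0.3
unique_words = len(set(response.lower().split()))
if words > 0 and unique_words > words * 0.5:  # mostly distinct tokens -> diversity credit (0.3)
    r += 0.3
print(min(r, 1.0))  # 0.8
```

Note that length, structure, and diversity are all trivially gameable (e.g. padding with varied filler sentences), which is why the pitfalls section recommends a preference model for production.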

Next steps

  • The run ends by printing Saved: <path>; use that path as model_path when creating a sampling client:
import mint

service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste the Saved path>")
  • For sampling and tokenizer details, see: /using-the-api/saving-and-loading/api-reference/sampling-client
