Chat RL
This page covers demos/rl/adapters/preference_chat.py from mint-quickstart.
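What this demo does
- This is RL-2 Preference Chat: training runs through the shared demos/rl/rl_core.py loop plus a task adapter.
- It applies RL to general chat prompts and scores responses with a preference-proxy reward (length / structure / lexical diversity).
- For each prompt it samples MINT_RL_GROUP responses, computes group-relative advantages, and updates the LoRA weights with importance_sampling.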
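The group-relative advantage step can be sketched as follows. This is a minimal illustration, assuming the advantage is each reward minus the mean reward of its group; the actual computation in rl_core.py may differ (for example, it might also divide by the group's standard deviation), and the name group_relative_advantages is hypothetical.

```python
from statistics import fmean


def group_relative_advantages(rewards: list[float]) -> list[float]:
    """Center each reward on the mean of its group (hypothetical helper)."""
    if not rewards:
        return []
    baseline = fmean(rewards)
    return [r - baseline for r in rewards]


# One prompt, four sampled responses (MINT_RL_GROUP=4), each scored by the reward.
mixed = group_relative_advantages([1.0, 0.7, 0.4, 0.3])
# All four responses received the same reward, so none stands out from the group.
flat = group_relative_advantages([0.7, 0.7, 0.7, 0.7])

print(mixed)
print(flat)
```

Under this assumption, a group whose responses all receive the same reward yields all-zero advantages and carries no learning signal, which is one plausible reason for a flat reward curve.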
Expected output
- The script first prints Model: ..., then Step N: avg_reward=..., datums=... at every step, and finally Saved: ....
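Common pitfalls
- The reward is a heuristic and is easy to game; for production use, replace compute_reward() with a preference model or a rule-based system.
- If the reward curve is flat (or datums=0), tune MINT_RL_TEMPERATURE / MINT_RL_GROUP / MINT_RL_BATCH and check the diversity of your prompts.
- Prompt construction prefers tokenizer.apply_chat_template and otherwise falls back to "User: ...\nAssistant:"; if the model is sensitive to the template, edit make_prompt().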
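To make the first pitfall concrete, the sketch below copies the heuristic from compute_reward() into a standalone function and scores two responses. The name heuristic_reward and both sample strings are introduced here for illustration only; they are not part of the demo.

```python
def heuristic_reward(response: str) -> float:
    """Standalone copy of the demo's length/structure/diversity heuristic."""
    r = 0.0
    words = len(response.split())
    if 20 <= words <= 100:
        r += 0.4
    elif 10 <= words < 20 or 100 < words <= 150:
        r += 0.2
    if response.count(".") >= 2:
        r += 0.3
    unique_words = len(set(response.lower().split()))
    if words > 0 and unique_words > words * 0.5:
        r += 0.3
    return min(r, 1.0)


# 20 distinct tokens and two periods: no meaning, but every check passes.
gamed = "a. b. c d e f g h i j k l m n o p q r s t"
# A correct but concise answer: too short for the length bonus, only one period.
concise = "Rayleigh scattering makes the sky look blue."

print(heuristic_reward(gamed))
print(heuristic_reward(concise))
```

The meaningless string collects all three bonuses, while the concise answer earns only the diversity bonus. A policy optimized against this reward can drift toward the former.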
Prerequisites
- Python >= 3.11
- Set MINT_API_KEY (or configure it in .env)
How to run
export MINT_API_KEY=sk-...
python demos/rl/adapters/preference_chat.py
Parameters (environment variables)
- MINT_BASE_MODEL: default Qwen/Qwen3-0.6B
- MINT_LORA_RANK: default 16
- MINT_RL_STEPS: default 10
- MINT_RL_BATCH: default 8
- MINT_RL_GROUP: default 4
- MINT_RL_LR: default 1e-4
- MINT_RL_MAX_TOKENS: default 128
- MINT_RL_TEMPERATURE: default 1.0
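As a rough picture of how these variables map to typed settings, here is a sketch that reads them with the defaults listed above. DemoSettings and its field names are hypothetical stand-ins written for this page; the real RLConfig in rl_core.py may be structured differently.

```python
from __future__ import annotations

import os
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass
class DemoSettings:
    """Hypothetical stand-in for the demo's settings; not the real RLConfig."""

    base_model: str
    lora_rank: int
    steps: int
    batch: int
    group: int
    lr: float
    max_tokens: int
    temperature: float

    @classmethod
    def from_env(cls, env: Mapping[str, str] | None = None) -> DemoSettings:
        # Fall back to the process environment when no mapping is passed in.
        env = os.environ if env is None else env
        return cls(
            base_model=env.get("MINT_BASE_MODEL", "Qwen/Qwen3-0.6B"),
            lora_rank=int(env.get("MINT_LORA_RANK", "16")),
            steps=int(env.get("MINT_RL_STEPS", "10")),
            batch=int(env.get("MINT_RL_BATCH", "8")),
            group=int(env.get("MINT_RL_GROUP", "4")),
            lr=float(env.get("MINT_RL_LR", "1e-4")),
            max_tokens=int(env.get("MINT_RL_MAX_TOKENS", "128")),
            temperature=float(env.get("MINT_RL_TEMPERATURE", "1.0")),
        )


# Override two values; everything else keeps the documented default.
settings = DemoSettings.from_env(
    {"MINT_RL_STEPS": "3", "MINT_RL_TEMPERATURE": "0.8"}
)
print(settings)
```

Passing a plain dict instead of os.environ keeps the sketch easy to test without touching the real environment.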
Full script
#!/usr/bin/env python3
"""RL-2 Preference Chat: adapter for rl_core.

Reward: proxy helpfulness score (length + structure + diversity).
Run: python demos/rl/adapters/preference_chat.py
"""
from __future__ import annotations

import os
import random
import sys
from collections.abc import Mapping
from pathlib import Path
from typing import Any

# Make the shared rl_core module (one directory up) importable.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402

PROMPTS = [
    "Explain what a variable is in programming.",
    "Write a short poem about the ocean.",
    "What are three benefits of exercise?",
    "Describe how to make a cup of tea.",
    "Why is the sky blue?",
    "Give tips for better sleep.",
    "What is machine learning?",
    "How do plants make food?",
]


def _coerce_chat_template_tokens(tokenized: Any) -> list[int]:
    # Accept a mapping or object with `input_ids`, an array, a tuple, or a list.
    if isinstance(tokenized, Mapping):
        if "input_ids" not in tokenized:
            raise TypeError("chat template output is missing `input_ids`")
        tokenized = tokenized["input_ids"]
    elif hasattr(tokenized, "input_ids"):
        tokenized = getattr(tokenized, "input_ids")
    if hasattr(tokenized, "tolist"):
        tokenized = tokenized.tolist()
    if isinstance(tokenized, tuple):
        tokenized = list(tokenized)
    if not isinstance(tokenized, list):
        raise TypeError(
            "chat template output must be a token list or mapping with `input_ids`"
        )
    # Unwrap a batched result of the form [[...]].
    if tokenized and isinstance(tokenized[0], list):
        tokenized = tokenized[0]
    return [int(token) for token in tokenized]


class PreferenceChatAdapter(RLAdapter):
    def build_dataset(self) -> list[str]:
        return [random.choice(PROMPTS) for _ in range(50)]

    def make_prompt(self, sample: str, tokenizer: Any) -> list[int]:
        messages = [{"role": "user", "content": sample}]
        if hasattr(tokenizer, "apply_chat_template"):
            return _coerce_chat_template_tokens(
                tokenizer.apply_chat_template(
                    messages, tokenize=True, add_generation_prompt=True
                )
            )
        # Fallback for tokenizers without a chat template.
        return tokenizer.encode(f"User: {sample}\nAssistant:")

    def compute_reward(self, response: str, sample: str) -> float:
        r = 0.0
        # Length: full bonus for 20-100 words, partial bonus near that range.
        words = len(response.split())
        if 20 <= words <= 100:
            r += 0.4
        elif 10 <= words < 20 or 100 < words <= 150:
            r += 0.2
        # Structure: at least two periods.
        if response.count(".") >= 2:
            r += 0.3
        # Diversity: more than half of the words are unique.
        unique_words = len(set(response.lower().split()))
        if words > 0 and unique_words > words * 0.5:
            r += 0.3
        return min(r, 1.0)

    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        avg = sum(rewards) / len(rewards) if rewards else 0.0
        print(f"Step {step}: avg_reward={avg:.3f}, datums={num_datums}")


if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "10"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "8"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "128"))
    run_grpo(PreferenceChatAdapter(), cfg)
Next steps
- The last line the script prints is Saved: <path>; pass that path as model_path when creating a sampling client:
import mint
service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste the Saved path>")
- For details on sampling and tokenizers, see /using-the-api/saving-and-loading and /api-reference/sampling-client.