DemoChat RL

Chat RL

This page documents demos/rl/adapters/preference_chat.py in mint-quickstart.

What this demo does

  • Runs RL-2 Preference Chat via the shared demos/rl/rl_core.py GRPO loop and a task-specific adapter.
  • Trains a LoRA over generic prompts using a proxy preference reward (length + structure + diversity).
  • Samples MINT_RL_GROUP completions per prompt, applies group-relative advantages, and updates via importance_sampling.

Expected output

  • Prints Model: ... and per-step Step N: avg_reward=... + datums=...; finishes with Saved: ....

Common gotchas

  • The reward is intentionally heuristic and gameable; for production alignment, replace compute_reward() with real preference judges or rule sets.
  • If rewards are flat (or datums=0), tune MINT_RL_TEMPERATURE/MINT_RL_GROUP/MINT_RL_BATCH and check prompt diversity.
  • Prompt formatting uses tokenizer.apply_chat_template when available, otherwise falls back to "User: ...\nAssistant:"; if your model is chat-template sensitive, adjust make_prompt().

Prerequisites

  • Python >= 3.11
  • MINT_API_KEY set (or configured via .env)

How to run

export MINT_API_KEY=sk-mint-...
python demos/rl/adapters/preference_chat.py

Parameters (env vars)

  • MINT_BASE_MODEL: default Qwen/Qwen3-0.6B
  • MINT_LORA_RANK: default 16
  • MINT_RL_STEPS: default 10
  • MINT_RL_BATCH: default 8
  • MINT_RL_GROUP: default 4
  • MINT_RL_LR: default 1e-4
  • MINT_RL_MAX_TOKENS: default 128
  • MINT_RL_TEMPERATURE: default 1.0

Full script

#!/usr/bin/env python3
"""RL-2 Preference Chat — adapter for rl_core.
 
Reward: proxy helpfulness score (length + structure + diversity).
Run:  python demos/rl/adapters/preference_chat.py
"""
 
from __future__ import annotations
 
import os
import random
import sys
from pathlib import Path
from typing import Any
 
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from rl_core import RLAdapter, RLConfig, run_grpo  # noqa: E402
 
PROMPTS = [
    "Explain what a variable is in programming.",
    "Write a short poem about the ocean.",
    "What are three benefits of exercise?",
    "Describe how to make a cup of tea.",
    "Why is the sky blue?",
    "Give tips for better sleep.",
    "What is machine learning?",
    "How do plants make food?",
]
 
 
class PreferenceChatAdapter(RLAdapter):
 
    def build_dataset(self) -> list[str]:
        return [random.choice(PROMPTS) for _ in range(50)]
 
    def make_prompt(self, sample: str, tokenizer: Any) -> list[int]:
        messages = [{"role": "user", "content": sample}]
        if hasattr(tokenizer, "apply_chat_template"):
            return list(tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True
            ))
        return tokenizer.encode(f"User: {sample}\nAssistant:")
 
    def compute_reward(self, response: str, sample: str) -> float:
        r = 0.0
        words = len(response.split())
        if 20 <= words <= 100:
            r += 0.4
        elif 10 <= words < 20 or 100 < words <= 150:
            r += 0.2
        if response.count(".") >= 2:
            r += 0.3
        unique_words = len(set(response.lower().split()))
        if words > 0 and unique_words > words * 0.5:
            r += 0.3
        return min(r, 1.0)
 
    def evaluate(self, step: int, rewards: list[float], num_datums: int) -> None:
        avg = sum(rewards) / len(rewards) if rewards else 0.0
        print(f"Step {step}: avg_reward={avg:.3f}, datums={num_datums}")
 
 
if __name__ == "__main__":
    cfg = RLConfig.from_env()
    cfg.steps = int(os.environ.get("MINT_RL_STEPS", "10"))
    cfg.batch = int(os.environ.get("MINT_RL_BATCH", "8"))
    cfg.max_tokens = int(os.environ.get("MINT_RL_MAX_TOKENS", "128"))
    run_grpo(PreferenceChatAdapter(), cfg)

Next steps

  • The final line prints Saved: <path>. You can load it in a new process:
import mint
 
service_client = mint.ServiceClient()
sampling_client = service_client.create_sampling_client(model_path="<paste Saved path>")
  • For sampling + tokenization details, see /using-the-api/saving-and-loading and /api-reference/sampling-client.