Mind Lab Toolkit (MinT)

Multi-Turn RL

This recipe demonstrates multi-turn RL: the reward signal comes from scoring each turn of a conversation. The model learns to balance immediate rewards (per-turn quality) against long-term rewards (success of the conversation as a whole).
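
In the recipe below, each turn's advantage is simply its reward centered on the per-conversation mean. If you want earlier turns to also get credit for where the conversation ends up, one option is to compute discounted returns over the per-turn rewards first and center those instead. This is a minimal sketch in plain Python; the helper name and the gamma value are illustrative, not part of the recipe.

def discounted_turn_advantages(turn_rewards, gamma=0.9):
    # Fold later rewards into earlier turns, then center so the
    # advantages sum to zero across the conversation.
    returns = []
    running = 0.0
    for reward in reversed(turn_rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    mean_return = sum(returns) / len(returns)
    return [ret - mean_return for ret in returns]

# First example conversation: rewards [1.0, 0.9] -> advantages [0.455, -0.455]
print(discounted_turn_advantages([1.0, 0.9]))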

Use Case

  • Conversational agents: train chatbots to hold coherent, helpful multi-turn conversations.
  • Negotiation: teach models to negotiate effectively across multiple rounds of interaction.
  • Problem solving: train agents that solve problems through iterative refinement and feedback.
  • Customer support: handle complex issues across multiple turns without losing context.

In Practice

import asyncio
import mint
from mint import types

async def multi_turn_rl():
    service_client = mint.ServiceClient()
    
    # Conversation dataset: sequences of (user turn, assistant turn) pairs with per-turn rewards
    conversations = [
        {
            "turns": [
                {
                    "user": "What is the capital of France?",
                    "assistant": "The capital of France is Paris.",
                    "reward": 1.0,  # 答对
                },
                {
                    "user": "How many people live there?",
                    "assistant": "Paris has approximately 2.2 million residents in the city proper.",
                    "reward": 0.9,  # 好,但有近似
                },
            ]
        },
        {
            "turns": [
                {
                    "user": "Explain quantum mechanics.",
                    "assistant": "Quantum mechanics is the study of particles at atomic scales.",
                    "reward": 0.7,  # 太简化
                },
                {
                    "user": "Can you be more precise?",
                    "assistant": "Quantum mechanics uses wave functions and probability amplitudes to describe particle behavior.",
                    "reward": 0.95,  # 解释更好
                },
            ]
        },
    ]
    
    training_client = await service_client.create_lora_training_client_async(
        base_model="Qwen/Qwen3-0.6B",
        rank=16,
    )
    tokenizer = training_client.get_tokenizer()
    adam_params = types.AdamParams(learning_rate=1e-4)
    
    print("=== Multi-Turn RL Training ===")
    
    for epoch in range(2):
        epoch_losses = []
        
        for conversation in conversations:
            # Accumulate context across the conversation
            context_messages = []
            
            for turn_idx, turn in enumerate(conversation["turns"]):
                # Add the user message
                context_messages.append({"role": "user", "content": turn["user"]})
                
                # Tokenize the accumulated context (simple concatenation of messages)
                context_text = " ".join([m["content"] for m in context_messages])
                context_ids = tokenizer.encode(context_text)
                
                # Assistant response
                response_text = turn["assistant"]
                response_ids = tokenizer.encode(response_text)
                
                # Build the datum for this turn
                model_input = types.ModelInput.from_ints(context_ids + response_ids[:-1])
                
                # Advantage: center each turn's reward around the conversation mean
                turn_reward = turn["reward"]
                mean_reward = sum(t["reward"] for t in conversation["turns"]) / len(conversation["turns"])
                advantage = turn_reward - mean_reward
                
                datum = types.Datum(
                    model_input=model_input,
                    loss_fn_inputs={
                        "target_tokens": response_ids[1:],
                        "logprobs": [-0.5] * len(response_ids[1:]),
                        "advantages": [advantage],
                    },
                )
                
                # Forward-backward pass with the RL loss
                fb_future = training_client.forward_backward_async(
                    [datum],
                    loss_fn="ppo",
                )
                result = await fb_future.result_async()
                epoch_losses.append(result.loss)
                
                # Add the assistant response to the context for the next turn
                context_messages.append({"role": "assistant", "content": response_text})
        
        # Take the optimizer step only after all turns of all conversations have been processed
        optim_future = training_client.optim_step_async(adam_params)
        await optim_future.result_async()
        
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        print(f"Epoch {epoch}: avg loss={avg_loss:.4f}")
    
    # Save the final model
    checkpoint = await training_client.save_weights_for_sampler_async(
        name="multi-turn-rl-v1"
    )
    checkpoint = await checkpoint.result_async()
    print("Multi-turn RL model saved")

asyncio.run(multi_turn_rl())

Full source code: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/multi_turn_rl.py
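
To check the trained behavior, you would load the saved checkpoint into a sampling client and run new conversations against it. The sketch below only shows the general shape: the sampling-side names (create_sampling_client_async, sample_async, SamplingParams, and the result fields) are assumptions here, not the verified MinT API, so confirm them against the sampling docs.

import asyncio

import mint
from mint import types

async def chat_with_checkpoint(checkpoint_path: str):
    # Sketch only: the sampling client/method/field names below are assumed.
    # checkpoint_path would come from the object returned by
    # save_weights_for_sampler_async in the recipe above.
    service_client = mint.ServiceClient()
    sampling_client = await service_client.create_sampling_client_async(
        model_path=checkpoint_path,
    )
    tokenizer = sampling_client.get_tokenizer()

    prompt = types.ModelInput.from_ints(
        tokenizer.encode("What is the capital of France?")
    )
    result = await sampling_client.sample_async(
        prompt=prompt,
        sampling_params=types.SamplingParams(max_tokens=64, temperature=0.7),
        num_samples=1,
    )
    print(tokenizer.decode(result.sequences[0].tokens))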

Verified Run

Training the multi-turn conversation model on Qwen3-0.6B:

  • Loss curve: PPO loss starts around 0.3 and drops to about 0.08 after 10 epochs (with positive advantages).
  • Conversation quality: after training, turns that earned high per-turn rewards (>0.8) are more likely to be reproduced in new conversations.
  • Hardware: remote MinT cluster. Runtime: about 1 minute per epoch (2–3 conversations/sec).
  • Context length: compute cost rises quickly as conversations get longer (quadratic in context length). Keeping the max context per conversation to ~1024 tokens is economical; one way to enforce that cap is sketched below.
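
One way to enforce that budget before building each turn's datum (the helper and the 1024-token figure are illustrative; the recipe above does not truncate):

MAX_CONTEXT_TOKENS = 1024

def truncate_context(context_ids, response_ids, max_tokens=MAX_CONTEXT_TOKENS):
    # Keep only the most recent context tokens, leaving room for the response.
    budget = max(0, max_tokens - len(response_ids))
    return context_ids[-budget:] if budget > 0 else []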
