Multi-Turn RL

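This recipe demonstrates multi-turn RL, where the reward signal comes from scoring each turn of a conversation. The model learns to balance immediate reward (single-turn quality) against long-term reward (success of the whole conversation).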
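
One common way to trade off immediate against long-term reward is a discounted return-to-go per turn. The sketch below is illustrative only and is not what this recipe's code does (the recipe centers per-turn rewards on the conversation mean). The function name `discounted_returns`, the `final_bonus` conversation-level reward, and the `gamma` value are all assumptions introduced for the example.

```python
from typing import List


def discounted_returns(
    turn_rewards: List[float],
    final_bonus: float = 0.0,
    gamma: float = 0.9,
) -> List[float]:
    """Return-to-go per turn: G_t = r_t + gamma * G_{t+1}.

    final_bonus is a hypothetical conversation-level success reward
    that is added to the last turn only.
    """
    returns = [0.0] * len(turn_rewards)
    running = 0.0
    last = len(turn_rewards) - 1
    # Walk backwards so each turn sees the discounted sum of later turns.
    for t in reversed(range(len(turn_rewards))):
        reward = turn_rewards[t] + (final_bonus if t == last else 0.0)
        running = reward + gamma * running
        returns[t] = running
    return returns


# Per-turn rewards taken from the first example conversation.
example = discounted_returns([1.0, 0.9], final_bonus=0.0, gamma=0.5)
```

With `gamma=0.0` every turn is credited only with its own reward; values closer to 1.0 credit early turns with more of what happens later in the conversation.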

Use Case

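  • Conversational agents: train a chatbot to hold coherent, useful multi-turn conversations.
  • Negotiation: teach the model to negotiate effectively over multiple turns.
  • Problem solving: train an agent to solve problems through iterative refinement and feedback.
  • Customer support: handle complex issues across multiple turns without losing context.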

In Practice

import asyncio
import mint
from mint import types

async def multi_turn_rl():
    service_client = mint.ServiceClient()
    
    # Conversation dataset: sequences of (user turn, assistant turn)
    conversations = [
        {
            "turns": [
                {
                    "user": "What is the capital of France?",
                    "assistant": "The capital of France is Paris.",
                    "reward": 1.0,  # correct answer
                },
                {
                    "user": "How many people live there?",
                    "assistant": "Paris has approximately 2.2 million residents in the city proper.",
                    "reward": 0.9,  # good, but approximate
                },
            ]
        },
        {
            "turns": [
                {
                    "user": "Explain quantum mechanics.",
                    "assistant": "Quantum mechanics is the study of particles at atomic scales.",
                    "reward": 0.7,  # oversimplified
                },
                {
                    "user": "Can you be more precise?",
                    "assistant": "Quantum mechanics uses wave functions and probability amplitudes to describe particle behavior.",
                    "reward": 0.95,  # better explanation
                },
            ]
        },
    ]
    
    training_client = await service_client.create_lora_training_client_async(
        base_model="Qwen/Qwen3-0.6B",
        rank=16,
    )
    tokenizer = training_client.get_tokenizer()
    adam_params = types.AdamParams(learning_rate=1e-4)
    
    print("=== Multi-Turn RL Training ===")
    
    for epoch in range(2):
        epoch_losses = []
        
        for conversation in conversations:
            # Accumulate context over the course of the conversation
            context_messages = []
            
            for turn_idx, turn in enumerate(conversation["turns"]):
                # Append the user message
                context_messages.append({"role": "user", "content": turn["user"]})
                
                # Tokenize the context (prompt); message contents are joined
                # with spaces, so role information is not encoded here
                context_text = " ".join([m["content"] for m in context_messages])
                context_ids = tokenizer.encode(context_text)
                
                # Assistant reply (completion)
                response_text = turn["assistant"]
                response_ids = tokenizer.encode(response_text)
                
                # Build the full sequence, then shift by one to get input/target pairs
                all_ids = context_ids + response_ids
                input_ids = all_ids[:-1]
                target_ids = all_ids[1:]
                
                model_input = types.ModelInput.from_ints(input_ids)
                
                # Advantage: center each turn's reward on the conversation mean
                turn_reward = turn["reward"]
                mean_reward = sum(t["reward"] for t in conversation["turns"]) / len(conversation["turns"])
                advantage = turn_reward - mean_reward
                
                # Every array must have the same length as model_input
                n_total = len(input_ids)
                n_context = len(context_ids) - 1  # prompt positions (no gradient)
                n_response = n_total - n_context  # response positions (gradient applied)
                
                datum = types.Datum(
                    model_input=model_input,
                    loss_fn_inputs={
                        "target_tokens": target_ids,
                        # -0.5 is a constant placeholder, not logprobs recorded at sampling time
                        "logprobs": [0.0] * n_context + [-0.5] * n_response,
                        "advantages": [0.0] * n_context + [advantage] * n_response,
                        "weights": [0.0] * n_context + [1.0] * n_response,
                    },
                )
                
                # Forward-backward pass with the RL loss
                fb_future = await training_client.forward_backward_async(
                    [datum],
                    loss_fn="ppo",
                )
                result = await fb_future.result_async()
                epoch_losses.append(result.metrics.get("loss:mean", 0.0))
                
                # Add the assistant reply to the context for the next turn
                context_messages.append({"role": "assistant", "content": response_text})
        
        # Take the optimizer step once all turns of all conversations have run
        optim_future = await training_client.optim_step_async(adam_params)
        await optim_future.result_async()
        
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        print(f"Epoch {epoch}: avg loss={avg_loss:.4f}")
    
    # Save the final model
    await training_client.save_weights_for_sampler_async(
        name="multi-turn-rl-v1"
    )
    print("Multi-turn RL model saved")

asyncio.run(multi_turn_rl())
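
The index arithmetic in the recipe (shift by one, then mask the prompt positions) is easy to get wrong, so here is a minimal standalone sketch of the same alignment in plain Python with no MinT dependency. The helper names `centered_advantages` and `build_loss_arrays` and the toy token ids are hypothetical and introduced only for illustration; the `-0.5` default mirrors the placeholder value used in the recipe.

```python
from typing import Dict, List


def centered_advantages(rewards: List[float]) -> List[float]:
    """Subtract the conversation mean from each turn's reward."""
    mean_reward = sum(rewards) / len(rewards)
    return [r - mean_reward for r in rewards]


def build_loss_arrays(
    context_ids: List[int],
    response_ids: List[int],
    advantage: float,
    old_logprob: float = -0.5,  # placeholder, as in the recipe
) -> Dict[str, list]:
    """Build per-position arrays aligned with the shifted input sequence."""
    if not context_ids or not response_ids:
        raise ValueError("context_ids and response_ids must be non-empty")
    all_ids = context_ids + response_ids
    input_ids = all_ids[:-1]   # position i predicts all_ids[i + 1]
    target_ids = all_ids[1:]
    # After the shift, the last context position predicts the first
    # response token, so only len(context_ids) - 1 positions are masked.
    n_context = len(context_ids) - 1
    n_response = len(response_ids)
    return {
        "input_ids": input_ids,
        "target_tokens": target_ids,
        "logprobs": [0.0] * n_context + [old_logprob] * n_response,
        "advantages": [0.0] * n_context + [advantage] * n_response,
        "weights": [0.0] * n_context + [1.0] * n_response,
    }


# Toy ids: a 3-token prompt followed by a 2-token response.
arrays = build_loss_arrays([11, 12, 13], [21, 22], advantage=0.05)
turn_advantages = centered_advantages([1.0, 0.9])
```

In this toy case the weights come out as `[0.0, 0.0, 1.0, 1.0]`: the position holding token `13` is trained because its target is the first response token `21`.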

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/multi_turn_rl.py
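
To make the `ppo` loss and metrics such as `ratio:mean` and `clipfrac` easier to interpret, the following is a sketch of the standard PPO clipped surrogate in plain Python. Whether MinT's `ppo` loss matches this term by term is an assumption; the function name `ppo_clip_stats`, the `clip_eps=0.2` default, and the definition of the clip fraction as the weighted share of tokens whose ratio falls outside `[1 - clip_eps, 1 + clip_eps]` are choices made for illustration.

```python
import math
from typing import List, Tuple


def ppo_clip_stats(
    new_logprobs: List[float],
    old_logprobs: List[float],
    advantages: List[float],
    weights: List[float],
    clip_eps: float = 0.2,  # assumed clip range, not taken from MinT
) -> Tuple[float, float, float]:
    """Return (mean loss, mean ratio, clip fraction) over weighted tokens."""
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("at least one token must have non-zero weight")
    loss_sum = ratio_sum = clipped_weight = 0.0
    for new_lp, old_lp, adv, w in zip(new_logprobs, old_logprobs, advantages, weights):
        if w == 0:
            continue  # masked prompt position
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic (minimum) surrogate; the loss is its negative.
        surrogate = min(ratio * adv, clipped_ratio * adv)
        loss_sum += -surrogate * w
        ratio_sum += ratio * w
        if ratio != clipped_ratio:
            clipped_weight += w
    return loss_sum / total_weight, ratio_sum / total_weight, clipped_weight / total_weight


# Current logprob equals the old one: ratio is 1 and nothing is clipped.
on_policy = ppo_clip_stats([-0.5, -0.5], [0.0, -0.5], [0.0, 1.0], [0.0, 1.0])
# Current logprob far from the -0.5 placeholder: the ratio leaves the clip range.
off_policy = ppo_clip_stats([-2.0], [-0.5], [1.0], [1.0])
```

Because the recipe fills `logprobs` with a constant `-0.5` rather than values recorded from the sampling policy, ratios far from 1 and a high clip fraction would be expected under this formulation. In an actual RL loop the old logprobs would normally come from the policy that generated the responses.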

Verified Run

Trained on Qwen3-0.6B with LoRA rank 16, 2 conversations, and the PPO loss:

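  • Turn 1 (13 tokens, 7 response tokens): loss:mean=-0.043, ratio:mean=1.04, clipfrac=0.86
  • Turn 2 (33 tokens, 14 response tokens): loss:mean=0.048, ratio:mean=0.59, clipfrac=0.93
  • Hardware: remote MinT cluster (mint-cn.macaron.xin).
  • Context length: compute cost rises faster than linearly as a conversation gets longer (quadratic complexity). Keeping the max context per conversation at about 1024 tokens is economical.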
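
One contributor to the cost growth is visible in the recipe itself: every turn re-encodes the whole conversation so far, so the total number of tokens processed grows roughly quadratically with the number of turns. The sketch below counts that total and shows one simple way to stay under a token budget by dropping the oldest turns. The helper names `tokens_processed` and `trim_to_budget` are hypothetical, and dropping whole turns is just one possible truncation policy.

```python
from typing import List


def tokens_processed(turn_lengths: List[int]) -> int:
    """Total tokens fed to the model when each turn re-encodes the full history."""
    total = 0
    context = 0
    for length in turn_lengths:
        context += length  # the history grows by this turn
        total += context   # and is processed again in full
    return total


def trim_to_budget(turns: List[List[int]], max_tokens: int = 1024) -> List[List[int]]:
    """Keep the most recent whole turns whose combined length fits the budget."""
    kept: List[List[int]] = []
    used = 0
    for turn in reversed(turns):
        if used + len(turn) > max_tokens:
            break
        kept.append(turn)
        used += len(turn)
    kept.reverse()  # restore chronological order
    return kept


four_turns = tokens_processed([100] * 4)   # 100 + 200 + 300 + 400
eight_turns = tokens_processed([100] * 8)  # doubling the turns more than triples the total
recent = trim_to_budget([[1] * 600, [2] * 300, [3] * 500], max_tokens=1024)
```

Dropping old turns loses information the model may need, so summarizing earlier context is a common alternative when the budget is tight.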
