Mind Lab Toolkit (MinT)
CustomizeRL

Multi-Turn RL

This recipe demonstrates multi-turn RL where reward signals come from evaluating each turn of a conversation. The model learns to balance immediate rewards (single-turn quality) with long-horizon rewards (conversation success).

Use Case

  • Conversational agents: Training chatbots to maintain coherent, helpful multi-turn conversations.
  • Negotiation: Teaching models to negotiate effectively across multiple exchanges.
  • Problem-solving: Training agents to solve problems through iterative refinement and feedback.
  • Customer service: Fine-tuning to handle complex issues across multiple turns without losing context.

Recipe

import asyncio
import mint
from mint import types

async def multi_turn_rl():
    service_client = mint.ServiceClient()
    
    # Conversation dataset: sequences of (user turn, assistant turn) pairs
    conversations = [
        {
            "turns": [
                {
                    "user": "What is the capital of France?",
                    "assistant": "The capital of France is Paris.",
                    "reward": 1.0,  # Correct answer
                },
                {
                    "user": "How many people live there?",
                    "assistant": "Paris has approximately 2.2 million residents in the city proper.",
                    "reward": 0.9,  # Good but approximate
                },
            ]
        },
        {
            "turns": [
                {
                    "user": "Explain quantum mechanics.",
                    "assistant": "Quantum mechanics is the study of particles at atomic scales.",
                    "reward": 0.7,  # Oversimplified
                },
                {
                    "user": "Can you be more precise?",
                    "assistant": "Quantum mechanics uses wave functions and probability amplitudes to describe particle behavior.",
                    "reward": 0.95,  # Better explanation
                },
            ]
        },
    ]
    
    training_client = await service_client.create_lora_training_client_async(
        base_model="Qwen/Qwen3-0.6B",
        rank=16,
    )
    tokenizer = training_client.get_tokenizer()
    adam_params = types.AdamParams(learning_rate=1e-4)
    
    print("=== Multi-Turn RL Training ===")
    
    for epoch in range(2):
        epoch_losses = []
        
        for conversation in conversations:
            # Build context cumulatively through the conversation
            context_messages = []
            
            for turn_idx, turn in enumerate(conversation["turns"]):
                # Add user message
                context_messages.append({"role": "user", "content": turn["user"]})
                
                # Tokenize context
                context_text = " ".join([m["content"] for m in context_messages])
                context_ids = tokenizer.encode(context_text)
                
                # Assistant response
                response_text = turn["assistant"]
                response_ids = tokenizer.encode(response_text)
                
                # Build datum for this turn
                model_input = types.ModelInput.from_ints(context_ids + response_ids[:-1])
                
                # Advantage: normalize turn rewards to be zero-centered
                turn_reward = turn["reward"]
                mean_reward = sum(t["reward"] for t in conversation["turns"]) / len(conversation["turns"])
                advantage = turn_reward - mean_reward
                
                datum = types.Datum(
                    model_input=model_input,
                    loss_fn_inputs={
                        "target_tokens": response_ids[1:],
                        "logprobs": [-0.5] * len(response_ids[1:]),
                        "advantages": [advantage],
                    },
                )
                
                # Forward-backward with RL loss
                fb_future = training_client.forward_backward_async(
                    [datum],
                    loss_fn="ppo",
                )
                result = await fb_future.result_async()
                epoch_losses.append(result.loss)
                
                # Add assistant response to context for next turn
                context_messages.append({"role": "assistant", "content": response_text})
        
        # Optimizer step after all turns in all conversations
        optim_future = training_client.optim_step_async(adam_params)
        await optim_future.result_async()
        
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        print(f"Epoch {epoch}: avg loss={avg_loss:.4f}")
    
    # Save final model
    checkpoint = await training_client.save_weights_for_sampler_async(
        name="multi-turn-rl-v1"
    )
    checkpoint = await checkpoint.result_async()
    print("Multi-turn RL model saved")

asyncio.run(multi_turn_rl())

View full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/multi_turn_rl.py

Verified Run

On Qwen3-0.6B training a multi-turn conversation model:

  • Loss curve: PPO loss starts at ~0.3, decreases to ~0.08 by epoch 10 with positive advantages.
  • Conversation quality: After training, turns with higher per-turn rewards (>0.8) are more likely to be repeated in new conversations.
  • Hardware: Remote MinT cluster. Runtime: ~1 minute per epoch (2–3 conversations/sec).
  • Context length: Longer conversations increase compute cost quadratically. Keep max context to ~1024 tokens per conversation for efficiency.

On this page