Multi-Turn RL
This recipe demonstrates multi-turn RL where reward signals come from evaluating each turn of a conversation. The model learns to balance immediate rewards (single-turn quality) with long-horizon rewards (conversation success).
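Concretely, the recipe scores each assistant turn with a scalar reward and zero-centers it against the conversation's mean reward to form a per-turn advantage. For a conversation with $T$ turns and rewards $r_1, \dots, r_T$:

$$
A_t = r_t - \frac{1}{T}\sum_{i=1}^{T} r_i
$$

A discounted return $G_t = \sum_{k \ge 0} \gamma^k r_{t+k}$ is a standard alternative when later turns should explicitly influence credit for earlier ones; the listing below uses the simpler mean baseline.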
Use Cases
- Conversational agents: Training chatbots to maintain coherent, helpful multi-turn conversations.
- Negotiation: Teaching models to negotiate effectively across multiple exchanges.
- Problem-solving: Training agents to solve problems through iterative refinement and feedback.
- Customer service: Fine-tuning models to handle complex issues across multiple turns without losing context.
Recipe
```python
import asyncio

import mint
from mint import types


async def multi_turn_rl():
    service_client = mint.ServiceClient()

    # Conversation dataset: sequences of (user turn, assistant turn) pairs,
    # each with a scalar per-turn reward.
    conversations = [
        {
            "turns": [
                {
                    "user": "What is the capital of France?",
                    "assistant": "The capital of France is Paris.",
                    "reward": 1.0,  # Correct answer
                },
                {
                    "user": "How many people live there?",
                    "assistant": "Paris has approximately 2.2 million residents in the city proper.",
                    "reward": 0.9,  # Good but approximate
                },
            ]
        },
        {
            "turns": [
                {
                    "user": "Explain quantum mechanics.",
                    "assistant": "Quantum mechanics is the study of particles at atomic scales.",
                    "reward": 0.7,  # Oversimplified
                },
                {
                    "user": "Can you be more precise?",
                    "assistant": "Quantum mechanics uses wave functions and probability amplitudes to describe particle behavior.",
                    "reward": 0.95,  # Better explanation
                },
            ]
        },
    ]
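    # NOTE: the rewards above are hard-coded for illustration. In a live run
    # they would come from a per-turn scorer -- e.g. a reward model, heuristic
    # checks, or human labels -- applied to each assistant response.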
    training_client = await service_client.create_lora_training_client_async(
        base_model="Qwen/Qwen3-0.6B",
        rank=16,
    )
    tokenizer = training_client.get_tokenizer()
    adam_params = types.AdamParams(learning_rate=1e-4)

    print("=== Multi-Turn RL Training ===")
    for epoch in range(2):
        epoch_losses = []
        for conversation in conversations:
            # Build context cumulatively through the conversation.
            context_messages = []

            # Per-conversation baseline: the mean reward across all turns.
            turns = conversation["turns"]
            mean_reward = sum(t["reward"] for t in turns) / len(turns)

            for turn in turns:
                # Add the user message to the running context.
                context_messages.append({"role": "user", "content": turn["user"]})

                # Tokenize the context (a simple space-join; see the
                # chat-template sketch below for a role-aware alternative).
                context_text = " ".join(m["content"] for m in context_messages)
                context_ids = tokenizer.encode(context_text)

                # Tokenize the assistant response for this turn.
                response_text = turn["assistant"]
                response_ids = tokenizer.encode(response_text)

                # Build the datum: input is the context plus all but the last
                # response token; targets are the response shifted by one.
                model_input = types.ModelInput.from_ints(context_ids + response_ids[:-1])
                target_ids = response_ids[1:]

                # Advantage: zero-center the turn reward against the
                # conversation baseline, broadcast over the target tokens.
                advantage = turn["reward"] - mean_reward

                datum = types.Datum(
                    model_input=model_input,
                    loss_fn_inputs={
                        "target_tokens": target_ids,
                        # Placeholder sampler logprobs; in a live rollout these
                        # come from the policy that generated the response.
                        "logprobs": [-0.5] * len(target_ids),
                        "advantages": [advantage] * len(target_ids),
                    },
                )

                # Forward-backward with the PPO loss; gradients accumulate
                # until the optimizer step below.
                fb_future = training_client.forward_backward_async(
                    [datum],
                    loss_fn="ppo",
                )
                result = await fb_future.result_async()
                epoch_losses.append(result.loss)

                # Add the assistant response to the context for the next turn.
                context_messages.append({"role": "assistant", "content": response_text})

        # One optimizer step after all turns in all conversations.
        optim_future = training_client.optim_step_async(adam_params)
        await optim_future.result_async()

        avg_loss = sum(epoch_losses) / len(epoch_losses)
        print(f"Epoch {epoch}: avg loss={avg_loss:.4f}")

    # Save the final weights for sampling.
    checkpoint_future = await training_client.save_weights_for_sampler_async(
        name="multi-turn-rl-v1"
    )
    await checkpoint_future.result_async()
    print("Multi-turn RL model saved")


asyncio.run(multi_turn_rl())
```

View full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/multi_turn_rl.py
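The listing flattens the context with a plain space-join, which drops role boundaries. If `get_tokenizer()` returns a Hugging Face-style tokenizer with a chat template (an assumption; the recipe does not specify), a role-aware encoding is a drop-in replacement for building `context_ids`. A minimal sketch:

```python
# Sketch: role-aware context encoding. Assumes get_tokenizer() returns a
# Hugging Face tokenizer with a chat template -- not guaranteed by the recipe.
def encode_context(tokenizer, context_messages):
    # Renders user/assistant turns with the model's role markers and returns
    # token ids directly (tokenize=True); add_generation_prompt appends the
    # assistant prefix so the response tokens follow naturally.
    return tokenizer.apply_chat_template(
        context_messages,
        tokenize=True,
        add_generation_prompt=True,
    )
```

Role markers matter more as conversations grow, since a space-joined context gives the model no way to tell who said what.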
Verified Run
On Qwen3-0.6B, training a multi-turn conversation model:
- Loss curve: PPO loss starts at ~0.3 and decreases to ~0.08 by epoch 10 with positive advantages (the listing above runs 2 epochs for brevity; extend `range(2)` to reproduce).
- Conversation quality: After training, response patterns from high-reward turns (per-turn reward > 0.8) are more likely to be reproduced in new conversations.
- Hardware: Remote MinT cluster. Runtime: ~1 minute per epoch (2–3 conversations/sec).
- Context length: Attention cost grows quadratically with sequence length, so long conversations get expensive quickly. Keep the context to ~1024 tokens per conversation for efficiency; a truncation sketch follows below.
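One way to honor that cap is to keep only the most recent tokens of the flattened context before building each datum. A minimal sketch (`MAX_CONTEXT_TOKENS` is a name introduced here; the 1024 value mirrors the guidance above):

```python
MAX_CONTEXT_TOKENS = 1024  # per the guidance above; tune for your budget

def truncate_context(context_ids, max_tokens=MAX_CONTEXT_TOKENS):
    # Keep the most recent tokens; earlier turns fall out of the window first.
    return context_ids[-max_tokens:]

# Usage inside the recipe's turn loop:
#   context_ids = truncate_context(tokenizer.encode(context_text))
```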