Multi-Turn RL
Multi-turn RL trains a model to interact with an environment across multiple steps. Unlike single-turn RL (one prompt, one response), the model generates responses, receives observations, and generates again — learning a policy over the full trajectory.
MinT uses mint.recipe for multi-turn RL. The core abstractions:
- MessageEnv: defines the environment at the message level
- EnvFromMessageEnv: bridges messages to tokens
- Prefix extension: merges multiple turns into one training Datum
- GRPO: group advantage normalization via loss_fn="importance_sampling"
Architecture
   ┌───────────────────┐
   │    MessageEnv     │
   │   (your logic)    │
   └─────────┬─────────┘
             │ messages
   ┌─────────▼─────────┐
   │ EnvFromMessageEnv │
   │  (token bridge)   │
   └─────────┬─────────┘
             │ tokens + mask
┌────────────▼────────────┐
│   do_single_rollout()   │
│ sample → step → repeat  │
└────────────┬────────────┘
             │ Trajectory
┌────────────▼────────────┐
│  trajectory_to_data()   │
│   prefix extension +    │
│      token masking      │
└────────────┬────────────┘
             │ Datum
┌────────────▼────────────┐
│   GRPO training step    │
│   importance_sampling   │
└─────────────────────────┘

MessageEnv
Define your environment by subclassing MessageEnv. Two methods:
- initial_observation(): returns the starting messages (system prompt + first user message)
- step(message): receives the model's response, returns reward + next observation
import re

from mint.recipe import Message, MessageEnv, MessageStepResult

# _extract_content() and _safe_calc() are helpers defined outside this snippet.


class CalculatorEnv(MessageEnv):
    def __init__(self, question: str, answer: float):
        self.question = question
        self.answer = answer
        self.turns = 0

    async def initial_observation(self) -> list[Message]:
        return [
            {"role": "system", "content": "You can use calc(expr) to evaluate math expressions. Give your final answer as: Answer: <number>"},
            {"role": "user", "content": self.question},
        ]

    async def step(self, message: Message) -> MessageStepResult:
        content = _extract_content(message)  # handle str or list
        self.turns += 1

        # Check for final answer. The pattern accepts an optional sign and
        # one decimal point, so float() cannot fail on input such as "1.2.3".
        match = re.search(r"Answer:\s*(-?\d+(?:\.\d+)?)", content)
        if match:
            correct = abs(float(match.group(1)) - self.answer) < 0.01
            return MessageStepResult(
                reward=1.0 if correct else -0.5,
                episode_done=True,
                next_messages=[],
                metrics={"correct": float(correct), "turns": self.turns},
            )

        # Check for tool call: feed the result back as the next user message
        calc_match = re.search(r"calc\((.+?)\)", content)
        if calc_match:
            result = _safe_calc(calc_match.group(1))
            return MessageStepResult(
                reward=0.0,
                episode_done=False,
                next_messages=[{"role": "user", "content": f"Result: {result}"}],
                metrics={},
            )

        # Neither an answer nor a tool call after 3 turns: penalize and stop
        if self.turns >= 3:
            return MessageStepResult(
                reward=-1.0, episode_done=True, next_messages=[],
                metrics={"timeout": 1.0},
            )

        # Neither an answer nor a tool call: end the episode with no reward
        return MessageStepResult(
            reward=0.0, episode_done=True, next_messages=[],
            metrics={"invalid_format": 1.0},
        )

Content type: Qwen3 models may return message["content"] as a list of content blocks instead of a string. Always handle both types in your step() method.
Prefix Extension
When a conversation has multiple turns, each turn's observation is a prefix of the next. trajectory_to_data() detects this and merges all turns into a single training Datum:
Turn 1: [sys prompt] [user question] [assistant response₁]
Turn 2: [sys prompt] [user question] [assistant response₁] [user followup] [assistant response₂]
        └─────────────── prefix of turn 2 ───────────────┘

Merged Datum:
  tokens: [sys] [user] [resp₁] [followup] [resp₂]
  mask:   [ 0 ] [ 0  ] [  1  ] [   0    ] [  1  ]
          └── env ───┘ └agent┘ └─ env ──┘ └agent┘

This gives O(T) compute instead of O(T^2): each token appears once in the training datum, not once per turn.
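The merge can be sketched in plain Python. This illustrates the idea only and is not the trajectory_to_data() implementation: the Turn type and merge_turns function are hypothetical names, and the token IDs are toy integers.

```python
from dataclasses import dataclass


@dataclass
class Turn:
    observation: list[int]  # full context the model saw on this turn
    action: list[int]       # tokens the model generated on this turn


def merge_turns(turns: list[Turn]) -> tuple[list[int], list[float]]:
    """Merge turns into one token sequence plus a 0/1 agent mask.

    Each turn's observation must start with everything seen so far
    (previous observation followed by the previous action).
    """
    tokens: list[int] = []
    mask: list[float] = []
    for turn in turns:
        if turn.observation[: len(tokens)] != tokens:
            raise ValueError("observation does not extend the previous turn")
        new_env = turn.observation[len(tokens):]  # only the unseen suffix
        tokens += new_env
        mask += [0.0] * len(new_env)              # environment tokens
        tokens += turn.action
        mask += [1.0] * len(turn.action)          # agent tokens
    return tokens, mask


turns = [
    Turn(observation=[1, 2], action=[10, 11]),
    Turn(observation=[1, 2, 10, 11, 3], action=[12]),
]
tokens, mask = merge_turns(turns)
print(tokens)  # [1, 2, 10, 11, 3, 12]
print(mask)    # [0.0, 0.0, 1.0, 1.0, 0.0, 1.0]
```

Trained separately, these two turns would cover 4 + 6 = 10 tokens; merged, they cover 6. This sketch simply raises when an observation does not extend the previous turn, which is a simplification chosen for the example.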
Token Masking
The mask field marks which tokens receive gradients:
- Agent tokens (model-generated responses): mask = 1.0 (trained)
- Environment tokens (system prompt, user messages): mask = 0.0 (frozen)
This ensures the model only learns from its own actions, not from copying environment text.
Dataset Wiring
Connect your MessageEnv to the training loop with three classes:
from dataclasses import dataclass
from collections.abc import Sequence

import chz
import mint.recipe as recipe
from mint.recipe import EnvFromMessageEnv, get_tokenizer


@dataclass(frozen=True)
class CalculatorEnvGroupBuilder(recipe.rl.types.EnvGroupBuilder):
    question: str
    answer: float
    group_size: int
    renderer_name: str
    model_name: str

    async def make_envs(self) -> Sequence[recipe.rl.types.Env]:
        tok = get_tokenizer(self.model_name)
        rend = recipe.renderers.get_renderer(self.renderer_name, tok)
        return [
            EnvFromMessageEnv(
                renderer=rend,
                message_env=CalculatorEnv(self.question, self.answer),
                max_trajectory_tokens=2048,
                max_generation_tokens=256,
            )
            for _ in range(self.group_size)
        ]

    def logging_tags(self) -> list[str]:
        return ["calculator"]


class CalculatorDataset(recipe.rl.types.RLDataset):
    def __init__(self, problems, batch_size, group_size, renderer_name, model_name):
        self.problems = problems
        self.batch_size = batch_size
        self.group_size = group_size
        self.renderer_name = renderer_name
        self.model_name = model_name

    def __len__(self):
        return max(1, len(self.problems) // self.batch_size)

    def get_batch(self, batch_idx):
        start = (batch_idx * self.batch_size) % len(self.problems)
        batch = self.problems[start : start + self.batch_size]
        return [
            CalculatorEnvGroupBuilder(q, a, self.group_size, self.renderer_name, self.model_name)
            for q, a in batch
        ]


@chz.chz
class CalculatorDatasetBuilder(recipe.rl.types.RLDatasetBuilder):
    batch_size: int = 2
    group_size: int = 4
    renderer_name: str = ""
    model_name: str = ""

    async def __call__(self):
        problems = [
            ("What is 12 * 34?", 408.0),
            ("What is 99 + 101?", 200.0),
            ("What is 256 / 8?", 32.0),
            ("What is 7 * 13?", 91.0),
        ]
        return CalculatorDataset(
            problems, self.batch_size, self.group_size,
            self.renderer_name, self.model_name,
        ), None

Training
import asyncio

import mint.recipe as recipe

MODEL = "Qwen/Qwen3-0.6B"
renderer_name = recipe.get_recommended_renderer_name(MODEL)

config = recipe.rl.train.Config(
    learning_rate=1e-5,
    dataset_builder=CalculatorDatasetBuilder(
        batch_size=2,
        group_size=4,
        renderer_name=renderer_name,
        model_name=MODEL,
    ),
    model_name=MODEL,
    renderer_name=renderer_name,
    lora_rank=16,
    max_tokens=256,
    temperature=0.8,
    kl_penalty_coef=0.0,
    log_path="/tmp/mint-multiturn-rl",
    max_steps=3,
    save_every=999,
    eval_every=999,
)

asyncio.run(recipe.rl.train.main(config=config))

View full source: recipes/multi_turn_rl.py
Verified Run
On Qwen/Qwen3-0.6B with LoRA rank 16, 4 problems, group_size=4, 1 training step:
| Metric | Value |
|---|---|
| Total episodes | 8 |
| Tokens per turn (action) | 256 |
| Tokens per turn (observation) | ~46 |
| Entropy | 0.31 |
| KL (sample vs train) | 0.010 |
| Training step time | 4.4s |
| Total time (incl. sampling + checkpoint) | ~100s |
The model hits max_tokens on every episode (Qwen3-0.6B uses extended thinking, consuming token budget on <think> blocks before reaching Answer:). With more training steps, the model learns to use calc() and produce shorter responses.
Under the Hood
trajectory_to_data() converts a Trajectory (list of Transitions) into a list of Datums. Each Datum's loss_fn_inputs contains:
| Field | Type | Purpose |
|---|---|---|
| target_tokens | int list | Next-token prediction targets (left-shifted) |
| logprobs | float list | Log-probabilities under the sampling policy (for importance sampling ratio) |
| advantages | float list | GRPO advantages: reward - group_mean_reward, zero for environment tokens |
| mask | float list | 1.0 for agent tokens, 0.0 for environment tokens |
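As a rough illustration of the advantages and mask rows, the sketch below computes reward - group_mean_reward for one group of rollouts and spreads a rollout's value over its agent tokens only. The function names are hypothetical and the numbers are made up; the recipe's own code may differ in details.

```python
def group_advantages(rewards: list[float]) -> list[float]:
    """Center each rollout's total reward on the group mean."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


def per_token_advantages(advantage: float, mask: list[float]) -> list[float]:
    """Broadcast one rollout's advantage to its agent tokens only."""
    return [advantage * m for m in mask]


# One group of 4 rollouts for the same question (group_size=4).
rewards = [1.0, -0.5, -0.5, 1.0]
advs = group_advantages(rewards)
print(advs)  # [0.75, -0.75, -0.75, 0.75]

# Mask for the first rollout: environment tokens are 0.0, agent tokens 1.0.
mask = [0.0, 0.0, 1.0, 1.0, 0.0, 1.0]
print(per_token_advantages(advs[0], mask))  # [0.0, 0.0, 0.75, 0.75, 0.0, 0.75]
```

If every rollout in a group receives the same reward, all advantages are zero and that group contributes no gradient.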
The training loss is importance sampling (GRPO): for each agent token, the loss is weighted by the advantage and the ratio of current policy probability to sampling policy probability. Environment tokens (mask=0) contribute zero loss.
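One plausible reading of this per-token objective is sketched below. The reduction (sum rather than mean) and the absence of clipping are assumptions; what loss_fn="importance_sampling" actually does internally is not specified here, and importance_sampling_loss is a hypothetical name.

```python
import math


def importance_sampling_loss(
    train_logprobs: list[float],   # log-probs under the current policy
    sample_logprobs: list[float],  # log-probs recorded at sampling time
    advantages: list[float],
    mask: list[float],
) -> float:
    """Sum over tokens of -mask * ratio * advantage."""
    loss = 0.0
    for lp_train, lp_sample, adv, m in zip(
        train_logprobs, sample_logprobs, advantages, mask
    ):
        ratio = math.exp(lp_train - lp_sample)  # p_current / p_sampling
        loss -= m * ratio * adv
    return loss


# On-policy case: both policies agree, so every ratio is exactly 1.
logprobs = [-1.0, -2.0, -0.5, -0.7]
loss = importance_sampling_loss(
    train_logprobs=logprobs,
    sample_logprobs=logprobs,
    advantages=[0.0, 0.0, 0.75, 0.75],
    mask=[0.0, 0.0, 1.0, 1.0],
)
print(loss)  # -1.5
```

Minimizing this quantity raises the probability of agent tokens with positive advantage and lowers it for those with negative advantage, while masked tokens drop out of the sum entirely.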
Import pattern: Use flat imports like from mint.recipe import MessageEnv, EnvFromMessageEnv, get_tokenizer for environment helpers. Use recipe.rl.train.main() for the high-level training entry point.