Mind Lab Toolkit (MinT)

Multi-Turn RL

Multi-turn RL trains a model to interact with an environment across multiple steps. Unlike single-turn RL (one prompt, one response), the model generates responses, receives observations, and generates again — learning a policy over the full trajectory.

MinT uses mint.recipe for multi-turn RL. The core abstractions:

  1. MessageEnv — defines the environment at the message level
  2. EnvFromMessageEnv — bridges messages to tokens
  3. Prefix extension — merges multiple turns into one training Datum
  4. GRPO — group advantage normalization via loss_fn="importance_sampling"

Architecture

                   ┌──────────────────┐
                   │   MessageEnv     │
                   │  (your logic)    │
                   └────────┬─────────┘
                            │ messages
                   ┌────────▼─────────┐
                   │ EnvFromMessageEnv │
                   │  (token bridge)  │
                   └────────┬─────────┘
                            │ tokens + mask
               ┌────────────▼────────────┐
               │   do_single_rollout()   │
               │  sample → step → repeat │
               └────────────┬────────────┘
                            │ Trajectory
               ┌────────────▼────────────┐
               │  trajectory_to_data()   │
               │  prefix extension +     │
               │  token masking          │
               └────────────┬────────────┘
                            │ Datum
               ┌────────────▼────────────┐
               │   GRPO training step    │
               │  importance_sampling    │
               └─────────────────────────┘
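
The sample → step → repeat loop in the diagram can be sketched with a toy environment and a scripted policy. This is a minimal illustration of the control flow, not MinT's do_single_rollout(): ToyEnv, scripted_policy, and rollout are hypothetical names, and plain dicts stand in for MessageStepResult.

```python
import asyncio

class ToyEnv:
    """Hypothetical stand-in for a MessageEnv: ends when the reply contains 'Answer:'."""

    async def initial_observation(self):
        return [{"role": "user", "content": "What is 2 + 3?"}]

    async def step(self, message):
        text = message["content"]
        if "Answer:" in text:
            correct = text.strip().endswith("5")
            return {"reward": 1.0 if correct else -0.5, "done": True, "next": []}
        # Any other reply is treated as a tool call and gets a canned result.
        return {"reward": 0.0, "done": False,
                "next": [{"role": "user", "content": "Result: 5"}]}

async def scripted_policy(history):
    """Hypothetical policy: call the tool on the first turn, then answer."""
    n_assistant = sum(m["role"] == "assistant" for m in history)
    text = "calc(2 + 3)" if n_assistant == 0 else "Answer: 5"
    return {"role": "assistant", "content": text}

async def rollout(env, policy, max_turns=4):
    history = await env.initial_observation()
    rewards = []
    for _ in range(max_turns):
        action = await policy(history)      # sample
        history.append(action)
        result = await env.step(action)     # step
        rewards.append(result["reward"])
        if result["done"]:
            break
        history.extend(result["next"])      # repeat with the new observation
    return history, rewards

history, rewards = asyncio.run(rollout(ToyEnv(), scripted_policy))
```

Here the episode takes two turns: the tool call earns no reward, and the final answer ends the episode with the terminal reward.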

MessageEnv

Define your environment by subclassing MessageEnv. Two methods:

  • initial_observation() — returns the starting messages (system prompt + first user message)
  • step(message) — receives the model's response, returns reward + next observation
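
import re

from mint.recipe import Message, MessageEnv, MessageStepResult

# _extract_content() and _safe_calc() are helper functions that are not shown on this page.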

class CalculatorEnv(MessageEnv):
    def __init__(self, question: str, answer: float):
        self.question = question
        self.answer = answer
        self.turns = 0

    async def initial_observation(self) -> list[Message]:
        return [
            {"role": "system", "content": "You can use calc(expr) to evaluate math expressions. Give your final answer as: Answer: <number>"},
            {"role": "user", "content": self.question},
        ]

    async def step(self, message: Message) -> MessageStepResult:
        content = _extract_content(message)  # handle str or list
        self.turns += 1

        # Check for final answer
        match = re.search(r"Answer:\s*([\d.]+)", content)
        if match:
            correct = abs(float(match.group(1)) - self.answer) < 0.01
            return MessageStepResult(
                reward=1.0 if correct else -0.5,
                episode_done=True,
                next_messages=[],
                metrics={"correct": float(correct), "turns": self.turns},
            )

        # Enforce the turn limit before serving another tool call; otherwise a
        # model that keeps calling calc() would never reach the timeout branch.
        if self.turns >= 3:
            return MessageStepResult(
                reward=-1.0, episode_done=True, next_messages=[],
                metrics={"timeout": 1.0},
            )

        # Check for tool call
        calc_match = re.search(r"calc\((.+?)\)", content)
        if calc_match:
            result = _safe_calc(calc_match.group(1))
            return MessageStepResult(
                reward=0.0,
                episode_done=False,
                next_messages=[{"role": "user", "content": f"Result: {result}"}],
                metrics={},
            )

        return MessageStepResult(
            reward=0.0, episode_done=True, next_messages=[],
            metrics={"invalid_format": 1.0},
        )

Content type: Qwen3 models may return message["content"] as a list of content blocks instead of a string. Always handle both types in your step() method.

Prefix Extension

When a conversation has multiple turns, each turn's observation is a prefix of the next. trajectory_to_data() detects this and merges all turns into a single training Datum:

Turn 1: [sys prompt] [user question] [assistant response₁]
Turn 2: [sys prompt] [user question] [assistant response₁] [user followup] [assistant response₂]
        └─────────────── prefix of turn 2 ───────────────┘

Merged Datum:
  tokens:  [sys] [user] [resp₁] [followup] [resp₂]
  mask:    [ 0 ] [ 0  ] [  1  ] [   0    ] [  1  ]
             └─ env ─┘   └ agent ┘  └─ env ─┘  └ agent ┘

This gives O(T) compute instead of O(T^2) — each token appears once in the training datum, not once per turn.
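
The merge can be sketched over plain token lists. This illustrates the idea only and is not MinT's trajectory_to_data(): merge_turns and the (observation, action) tuple layout are assumptions made for the example, and a real implementation would presumably start a new Datum rather than raise when the prefix property fails.

```python
def merge_turns(turns):
    """Merge (observation_tokens, action_tokens) turns into one token list and mask.

    Each observation must start with the previous turn's observation + action.
    """
    tokens, mask = [], []
    for obs, action in turns:
        if obs[:len(tokens)] != tokens:
            raise ValueError("observation does not extend the previous turn")
        new_env = obs[len(tokens):]       # only the environment tokens not seen yet
        tokens += new_env
        mask += [0.0] * len(new_env)
        tokens += action                  # agent tokens are the ones that get trained
        mask += [1.0] * len(action)
    return tokens, mask

turn1 = ([1, 2, 3], [10, 11])                     # [sys, user] -> resp1
turn2 = ([1, 2, 3, 10, 11, 4, 5], [12, 13, 14])   # previous turn + followup -> resp2
tokens, mask = merge_turns([turn1, turn2])
```

The merged sequence holds 10 tokens, whereas processing the two turns as separate sequences would cover 5 + 10 = 15.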

Token Masking

The mask field marks which tokens receive gradients:

  • Agent tokens (model-generated responses): mask = 1.0 — trained
  • Environment tokens (system prompt, user messages): mask = 0.0 — frozen

This ensures the model only learns from its own actions, not from copying environment text.

Dataset Wiring

Connect your MessageEnv to the training loop with three classes:

from dataclasses import dataclass
from collections.abc import Sequence
import chz
import mint.recipe as recipe
from mint.recipe import EnvFromMessageEnv, get_tokenizer

@dataclass(frozen=True)
class CalculatorEnvGroupBuilder(recipe.rl.types.EnvGroupBuilder):
    question: str
    answer: float
    group_size: int
    renderer_name: str
    model_name: str

    async def make_envs(self) -> Sequence[recipe.rl.types.Env]:
        tok = get_tokenizer(self.model_name)
        rend = recipe.renderers.get_renderer(self.renderer_name, tok)
        return [
            EnvFromMessageEnv(
                renderer=rend,
                message_env=CalculatorEnv(self.question, self.answer),
                max_trajectory_tokens=2048,
                max_generation_tokens=256,
            )
            for _ in range(self.group_size)
        ]

    def logging_tags(self) -> list[str]:
        return ["calculator"]

class CalculatorDataset(recipe.rl.types.RLDataset):
    def __init__(self, problems, batch_size, group_size, renderer_name, model_name):
        self.problems = problems
        self.batch_size = batch_size
        self.group_size = group_size
        self.renderer_name = renderer_name
        self.model_name = model_name

    def __len__(self):
        return max(1, len(self.problems) // self.batch_size)

    def get_batch(self, batch_idx):
        start = (batch_idx * self.batch_size) % len(self.problems)
        batch = self.problems[start : start + self.batch_size]
        return [
            CalculatorEnvGroupBuilder(q, a, self.group_size, self.renderer_name, self.model_name)
            for q, a in batch
        ]

@chz.chz
class CalculatorDatasetBuilder(recipe.rl.types.RLDatasetBuilder):
    batch_size: int = 2
    group_size: int = 4
    renderer_name: str = ""
    model_name: str = ""

    async def __call__(self):
        problems = [
            ("What is 12 * 34?", 408.0),
            ("What is 99 + 101?", 200.0),
            ("What is 256 / 8?", 32.0),
            ("What is 7 * 13?", 91.0),
        ]
        return CalculatorDataset(
            problems, self.batch_size, self.group_size,
            self.renderer_name, self.model_name,
        ), None
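
The batch arithmetic in CalculatorDataset can be checked without MinT. The sketch below reproduces only the index logic from __len__ and get_batch; n_batches and batch_slice are hypothetical helper names introduced for the example.

```python
problems = [
    ("What is 12 * 34?", 408.0),
    ("What is 99 + 101?", 200.0),
    ("What is 256 / 8?", 32.0),
    ("What is 7 * 13?", 91.0),
]
batch_size, group_size = 2, 4

def n_batches(problems, batch_size):
    # Same rule as CalculatorDataset.__len__: at least one batch.
    return max(1, len(problems) // batch_size)

def batch_slice(problems, batch_size, batch_idx):
    # Same rule as CalculatorDataset.get_batch: the start index wraps around.
    start = (batch_idx * batch_size) % len(problems)
    return problems[start:start + batch_size]

num_batches = n_batches(problems, batch_size)
first = batch_slice(problems, batch_size, 0)
wrapped = batch_slice(problems, batch_size, 2)   # an index past the end wraps to the start
episodes_per_step = len(first) * group_size      # one env per (problem, group member)
```

With these defaults each step rolls out 2 problems × 4 environments = 8 episodes, which is consistent with the episode count reported under Verified Run.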

Training

import asyncio
import mint.recipe as recipe

MODEL = "Qwen/Qwen3-0.6B"
renderer_name = recipe.get_recommended_renderer_name(MODEL)

config = recipe.rl.train.Config(
    learning_rate=1e-5,
    dataset_builder=CalculatorDatasetBuilder(
        batch_size=2,
        group_size=4,
        renderer_name=renderer_name,
        model_name=MODEL,
    ),
    model_name=MODEL,
    renderer_name=renderer_name,
    lora_rank=16,
    max_tokens=256,
    temperature=0.8,
    kl_penalty_coef=0.0,
    log_path="/tmp/mint-multiturn-rl",
    max_steps=3,
    save_every=999,
    eval_every=999,
)

asyncio.run(recipe.rl.train.main(config=config))

View full source: recipes/multi_turn_rl.py

Verified Run

On Qwen/Qwen3-0.6B with LoRA rank 16, 4 problems, group_size=4, 1 training step:

Metric                                     Value
Total episodes                             8
Tokens per turn (action)                   256
Tokens per turn (observation)              ~46
Entropy                                    0.31
KL (sample vs train)                       0.010
Training step time                         4.4 s
Total time (incl. sampling + checkpoint)   ~100 s

The model hits max_tokens on every episode (Qwen3-0.6B uses extended thinking, consuming token budget on <think> blocks before reaching Answer:). With more training steps, the model learns to use calc() and produce shorter responses.

Under the Hood

trajectory_to_data() converts a Trajectory (list of Transitions) into a list of Datums. Each Datum's loss_fn_inputs contains:

Field           Type         Purpose
target_tokens   int list     Next-token prediction targets (left-shifted)
logprobs        float list   Log-probabilities under the sampling policy (for the importance sampling ratio)
advantages      float list   GRPO advantages: reward - group_mean_reward, zero for environment tokens
mask            float list   1.0 for agent tokens, 0.0 for environment tokens

The training loss is importance sampling (GRPO): for each agent token, the loss is weighted by the advantage and the ratio of current policy probability to sampling policy probability. Environment tokens (mask=0) contribute zero loss.
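
A numeric sketch of the two pieces, assuming the per-token objective is -(ratio × advantage) averaged over masked tokens. The exact reduction, and whether MinT's importance_sampling loss applies any clipping, is not stated on this page, so the function names and the normalization below are assumptions.

```python
import math

def group_advantages(rewards):
    """Advantage per trajectory: reward minus the group's mean reward."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

def importance_sampling_loss(train_logprobs, sample_logprobs, advantages, mask):
    """Masked mean of -(p_train / p_sample) * advantage over agent tokens."""
    total, count = 0.0, 0.0
    for lp_new, lp_old, adv, m in zip(train_logprobs, sample_logprobs, advantages, mask):
        ratio = math.exp(lp_new - lp_old)   # p_train / p_sample for this token
        total += -ratio * adv * m           # environment tokens (m = 0) add nothing
        count += m
    return total / count if count else 0.0

adv = group_advantages([1.0, -0.5, -0.5, -1.0])   # one group of four rollouts

# First rollout: two environment tokens followed by two agent tokens.
mask = [0.0, 0.0, 1.0, 1.0]
token_adv = [adv[0] * m for m in mask]            # advantage is zero on environment tokens
sample_lp = [-1.0, -1.0, -0.7, -0.2]

# When the training policy equals the sampling policy, every ratio is 1.
loss_on_policy = importance_sampling_loss(sample_lp, sample_lp, token_adv, mask)
```

The advantages in a group sum to zero, so rollouts that beat the group mean are reinforced and the rest are pushed down. Changing the log-probabilities of environment tokens leaves the loss unchanged.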

Import pattern: Use flat imports like from mint.recipe import MessageEnv, EnvFromMessageEnv, get_tokenizer for environment helpers. Use recipe.rl.train.main() for the high-level training entry point.
