Mind Lab Toolkit (MinT)

RLHF Pipeline

This recipe demonstrates a minimal but real RLHF flow on MinT in three stages:

  1. Stage 1 (SFT): train on helpful chat examples with recipe.supervised.train.main().
  2. Stage 2 (PRM / preference model): train on chosen/rejected pairs with forward_backward_custom() and a Bradley-Terry loss.
  3. Stage 3 (RL): resume from the Stage 1 SFT checkpoint via load_checkpoint_path, then train with recipe.rl.train.main() using a MessageEnv. The environment scores responses with the PRM sampler when one is available and falls back to a rule-based reward if Stage 2 fails.
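As a rough illustration of how the three stages depend on each other, the control flow, including the Stage 2 fallback, can be sketched in plain Python. run_sft, run_prm, run_rl and PipelineResult are hypothetical names standing in for the recipe's stage code (which is async and calls MinT); this is a sketch of the dependency structure, not the script's actual code.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class PipelineResult:
    sft_checkpoint_path: str
    prm_sampler: Optional[object]  # None means Stage 3 used the rule-based reward


def run_pipeline(
    run_sft: Callable[[], str],
    run_prm: Callable[[], object],
    run_rl: Callable[[str, Optional[object]], None],
) -> PipelineResult:
    """Run the three stages; a Stage 2 failure downgrades to the rule-based reward."""
    # Stage 1 must succeed: its state_path is what Stage 3 resumes from.
    sft_checkpoint_path = run_sft()

    # Stage 2 is optional: any exception is reported and the pipeline continues.
    try:
        prm_sampler = run_prm()
    except Exception as exc:
        print(f"warning: PRM stage failed ({exc}); using rule-based reward")
        prm_sampler = None

    # Stage 3 always resumes from the SFT checkpoint, with or without a PRM.
    run_rl(sft_checkpoint_path, prm_sampler)
    return PipelineResult(sft_checkpoint_path, prm_sampler)
```

Stage 1 is the only hard dependency in this shape: Stage 3 always needs its checkpoint, while the PRM only changes where rewards come from.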

This page matches recipes/rlhf_pipeline.py.

Why This Is Different From a Fake RLHF Demo

The recipe does not hand-build synthetic PPO datums. It uses the real recipe stack:

SFT: recipe.supervised.train.main()
PRM: TrainingClient.forward_backward_custom()
RL:  recipe.rl.train.main() with loss_fn="importance_sampling" set in its Config

Stage 1: SFT

log_path = Path("/tmp/mint-rlhf-sft")
config = recipe.supervised.train.Config(
    log_path=str(log_path),
    model_name=MODEL,
    renderer_name="qwen3",
    dataset_builder=ListSFTDatasetBuilder(
        conversations=SFT_CONVERSATIONS,
        model_name=MODEL,
        renderer_name="qwen3",
        batch_size=2,
    ),
    learning_rate=1e-5,
    lora_rank=16,
    max_steps=1,
    save_every=999,
    eval_every=999,
)

await recipe.supervised.train.main(config=config)
sft_checkpoint_path = _read_last_state_checkpoint(log_path)

Under the hood, the SFT dataset builds its training examples with conversation_to_datum(). The final SFT state_path is passed to Stage 3 as load_checkpoint_path.
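A minimal sketch of picking up the last checkpoint, assuming (unverified) that the SFT log directory contains a checkpoints.jsonl file with one JSON record per save and a state_path key on state saves. read_last_state_checkpoint is a hypothetical stand-in for the recipe's _read_last_state_checkpoint; the real file name and record format may differ.

```python
import json
from pathlib import Path
from typing import Optional


def read_last_state_checkpoint(log_path: Path) -> Optional[str]:
    """Return the state_path of the most recent checkpoint record, if any.

    ASSUMPTION: log_path contains checkpoints.jsonl with one JSON object per
    line, some of which carry a "state_path" key. This mirrors the idea of the
    recipe's _read_last_state_checkpoint, not its actual code.
    """
    index = log_path / "checkpoints.jsonl"
    if not index.exists():
        return None
    last = None
    for line in index.read_text().splitlines():
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        if "state_path" in record:
            last = record["state_path"]  # later records override earlier ones
    return last
```

Whatever string this returns is the value Stage 3 would receive as load_checkpoint_path.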

Stage 2: PRM / Preference Model

The preference model uses the same Bradley-Terry custom-loss pattern as the DPO recipe:

prm_client = service_client.create_lora_training_client(
    base_model=MODEL,
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)

data = build_preference_data(prm_client.get_tokenizer())

result = prm_client.forward_backward_custom(
    data,
    pairwise_preference_loss,
).result()

prm_client.optim_step(types.AdamParams(learning_rate=1e-5)).result()
prm_sampler = prm_client.save_weights_and_get_sampling_client(name="rlhf-prm")

If this stage fails, the script prints a warning and continues to Stage 3 with a rule-based reward.
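The objective behind pairwise_preference_loss is the Bradley-Terry loss, -log σ(s_chosen - s_rejected), which pushes the chosen response's score above the rejected one's. The sketch below shows only that arithmetic on precomputed scalar scores, together with a pair accuracy in the spirit of the pair_accuracy metric reported in the verified run. bradley_terry_loss is a hypothetical helper; the recipe's loss runs on model outputs inside forward_backward_custom(), and how it reduces them to one score per response is not shown on this page.

```python
import math
from typing import Sequence, Tuple


def bradley_terry_loss(
    chosen_scores: Sequence[float], rejected_scores: Sequence[float]
) -> Tuple[float, float]:
    """Mean Bradley-Terry loss and pair accuracy over preference pairs.

    Each score is a scalar for one response (for example a summed log-probability);
    how that scalar is derived is outside this sketch.
    """
    if len(chosen_scores) != len(rejected_scores) or not chosen_scores:
        raise ValueError("need the same non-zero number of chosen and rejected scores")
    total_loss, correct = 0.0, 0
    for c, r in zip(chosen_scores, rejected_scores):
        margin = c - r
        # -log(sigmoid(margin)) == log(1 + exp(-margin)), computed stably
        total_loss += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
        correct += margin > 0  # ties count as incorrect
    n = len(chosen_scores)
    return total_loss / n, correct / n
```

With identical scores the loss is log 2 (about 0.693) and the pair is not counted as correct, which is the starting point of an untrained comparison.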

Stage 3: RL with PRM Reward

The PRM sampler is passed through the RL dataset builder chain:

RLHFDatasetBuilder(prm_sampler)
  ▼
RLHFDataset(prm_sampler)
  ▼
RLHFEnvGroupBuilder(prm_sampler)
  ▼
RLHFMessageEnv(prm_sampler)
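Each layer of the chain stores prm_sampler and hands it to the object it creates, so the env at the bottom can pick its reward source. A minimal sketch with hypothetical *Sketch classes; get_batch, make_envs and group_size are illustrative names, not necessarily the interfaces the recipe's RLHF* classes implement (those are presumably async and MinT-specific).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EnvSketch:
    prompt: str
    prm_sampler: Optional[object]  # None selects the rule-based reward


@dataclass
class EnvGroupBuilderSketch:
    prompt: str
    prm_sampler: Optional[object]
    group_size: int

    def make_envs(self) -> List[EnvSketch]:
        # Every env in a group sees the same prompt and the same reward source.
        return [EnvSketch(self.prompt, self.prm_sampler) for _ in range(self.group_size)]


@dataclass
class DatasetSketch:
    prompts: List[str]
    prm_sampler: Optional[object]
    group_size: int = 2

    def get_batch(self, index: int, batch_size: int) -> List[EnvGroupBuilderSketch]:
        chunk = self.prompts[index * batch_size:(index + 1) * batch_size]
        return [EnvGroupBuilderSketch(p, self.prm_sampler, self.group_size) for p in chunk]


@dataclass
class DatasetBuilderSketch:
    prompts: List[str]
    prm_sampler: Optional[object]

    def __call__(self) -> DatasetSketch:
        # The builder is the single entry point; the sampler flows down from here.
        return DatasetSketch(self.prompts, self.prm_sampler)
```

Passing the sampler explicitly at each level, rather than through a global, keeps the Stage 2 fallback local: building the chain with prm_sampler=None is all it takes to switch every env to the rule-based reward.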

RLHFMessageEnv.step() scores each model response, using the PRM sampler when it exists and the rule-based reward otherwise:

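def rule_based_reward(response: str) -> float:
    # Fallback reward: +0.5 for a response of at least six words that
    # mentions one of the keywords, otherwise -0.1.
    return 0.5 if len(response.split()) >= 6 and any(
        word in response.lower()
        for word in ["specific", "automatic", "risk", "restore", "actionable"]
    ) else -0.1

class RLHFMessageEnv(MessageEnv):
    # self.prompt and self.prm_sampler are set when the env is constructed (not shown).
    async def step(self, message):
        response = _extract_content(message).strip()
        if self.prm_sampler is not None:
            # Stage 2 succeeded: score the response with the preference model.
            reward = _score_with_prm(self.prm_sampler, self.prompt, response)
        else:
            # Stage 2 failed: use the rule-based fallback.
            reward = rule_based_reward(response)

        # Single-turn episode: one response, one reward, no follow-up messages.
        return MessageStepResult(
            reward=reward,
            episode_done=True,
            next_messages=[],
            metrics={"reward_source_prm": float(self.prm_sampler is not None)},
        )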
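To exercise the step logic without MinT, here is a self-contained stand-in. StepResultSketch, SingleTurnEnvSketch and the prm_score callable are hypothetical substitutes for MessageStepResult, RLHFMessageEnv and _score_with_prm, and the message is assumed to be a dict with a "content" key (the recipe's _extract_content may accept richer message types). rule_based_reward reproduces the recipe's rule.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class StepResultSketch:
    # Stand-in for MessageStepResult with the same four fields used above.
    reward: float
    episode_done: bool
    next_messages: List[dict] = field(default_factory=list)
    metrics: Dict[str, float] = field(default_factory=dict)


def rule_based_reward(response: str) -> float:
    # Same rule as the recipe: at least six words plus at least one keyword.
    keywords = ["specific", "automatic", "risk", "restore", "actionable"]
    long_enough = len(response.split()) >= 6
    return 0.5 if long_enough and any(w in response.lower() for w in keywords) else -0.1


class SingleTurnEnvSketch:
    def __init__(self, prompt: str, prm_score: Optional[Callable[[str, str], float]] = None):
        self.prompt = prompt
        self.prm_score = prm_score  # hypothetical stand-in for _score_with_prm(sampler, ...)

    async def step(self, message: dict) -> StepResultSketch:
        response = message.get("content", "").strip()
        if self.prm_score is not None:
            reward = self.prm_score(self.prompt, response)
        else:
            reward = rule_based_reward(response)
        # One turn per episode: done immediately, no follow-up messages.
        return StepResultSketch(
            reward=reward,
            episode_done=True,
            metrics={"reward_source_prm": float(self.prm_score is not None)},
        )


result = asyncio.run(
    SingleTurnEnvSketch("How do I recover a deleted file?").step(
        {"content": "Give a specific, actionable fix for the bug."}
    )
)
print(result.reward, result.metrics)
```

The reward_source_prm metric makes it visible in training logs which reward path was actually taken.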

The RL config resumes from the SFT checkpoint via load_checkpoint_path and selects the importance_sampling loss:

config = recipe.rl.train.Config(
    learning_rate=1e-5,
    dataset_builder=RLHFDatasetBuilder(..., prm_sampler=prm_sampler),
    model_name=MODEL,
    load_checkpoint_path=sft_checkpoint_path,
    renderer_name="qwen3",
    lora_rank=16,
    max_tokens=32,
    temperature=0.7,
    kl_penalty_coef=0.0,
    loss_fn="importance_sampling",
    max_steps=1,
    save_every=999,
    eval_every=999,
)

await recipe.rl.train.main(config=config)

View full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/rlhf_pipeline.py

Verified Run

Verified on MinT with Qwen/Qwen3-0.6B, one step per stage:

Stage   Result
SFT     Completed, train_mean_nll=5.687325
PRM     Completed, loss=8.358253, pair_accuracy=0.50
RL      Completed, resumed from SFT state_path, env/all/reward/total=-0.100, checkpoint saved

These results come from a minimal smoke run. They confirm that the stages, APIs, and SFT-to-RL checkpoint handoff work end to end; they say nothing about final model quality.

This is a small recipe for validating the RLHF wiring. For a production RLHF run, use more SFT data, many preference pairs, a stronger reward model, longer RL training, and separate validation prompts.

Shape

chat examples
  │
  ▼
Stage 1 SFT ──▶ SFT state_path ─────┐
                                     │ load_checkpoint_path
preference pairs                     │
  │                                  │
  ▼                                  │
Stage 2 PRM ──▶ PRM sampler          │
                  │                  │
                  ▼                  ▼
          Stage 3 MessageEnv RL resumes SFT weights
                       │
                       ▼
             final policy checkpoint
