RLHF Pipeline
This recipe demonstrates a minimal real RLHF flow on MinT:
- Stage 1 (SFT): train on helpful chat examples with recipe.supervised.train.main().
- Stage 2 (PRM / preference model): train on chosen/rejected pairs with forward_backward_custom() and a Bradley-Terry loss.
- Stage 3 (RL): resume from the Stage 1 SFT checkpoint with load_checkpoint_path, then train with recipe.rl.train.main() using a MessageEnv. The environment scores responses with the PRM sampler when available, or uses a rule-based fallback if Stage 2 fails.
This page matches recipes/rlhf_pipeline.py.
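
The control flow across the three stages, including the Stage 2 fallback, might look roughly like the sketch below. The run_pipeline function and its three callables are hypothetical stand-ins for the stage code, not names from the recipe; only the ordering and the "warn and continue with no PRM sampler" behavior come from the description above.

```python
import warnings
from typing import Callable, Optional


def run_pipeline(
    run_sft: Callable[[], str],
    train_prm: Callable[[], object],
    run_rl: Callable[[str, Optional[object]], str],
) -> str:
    """Run the three stages in order, tolerating a Stage 2 failure.

    All three callables are hypothetical placeholders for the real stages.
    """
    # Stage 1 must succeed: RL resumes from the checkpoint path it returns.
    sft_checkpoint_path = run_sft()

    # Stage 2 is best-effort: on failure, warn and continue without a PRM.
    prm_sampler: Optional[object]
    try:
        prm_sampler = train_prm()
    except Exception as exc:
        warnings.warn(f"PRM stage failed ({exc}); using rule-based reward")
        prm_sampler = None

    # Stage 3 receives the SFT checkpoint and the (possibly missing) sampler.
    return run_rl(sft_checkpoint_path, prm_sampler)
```

Passing None for the sampler is what lets the environment decide between PRM scoring and the rule-based reward at step time.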
Why This Is Different From a Fake RLHF Demo
The recipe does not hand-build synthetic PPO datums. It uses the real recipe stack:
- SFT: recipe.supervised.train.main()
- PRM: TrainingClient.forward_backward_custom()
- RL: recipe.rl.train.main() with loss_fn="importance_sampling" set in the config

Stage 1: SFT
log_path = Path("/tmp/mint-rlhf-sft")
config = recipe.supervised.train.Config(
    log_path=str(log_path),
    model_name=MODEL,
    renderer_name="qwen3",
    dataset_builder=ListSFTDatasetBuilder(
        conversations=SFT_CONVERSATIONS,
        model_name=MODEL,
        renderer_name="qwen3",
        batch_size=2,
    ),
    learning_rate=1e-5,
    lora_rank=16,
    max_steps=1,
    save_every=999,
    eval_every=999,
)
await recipe.supervised.train.main(config=config)
sft_checkpoint_path = _read_last_state_checkpoint(log_path)

The SFT dataset uses conversation_to_datum() under the hood. The final SFT state_path is then passed into Stage 3 RL.
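
The body of _read_last_state_checkpoint() is not shown on this page. As a rough illustration only, a helper of that kind could scan a checkpoint index under the log directory and return the newest state_path. The file name checkpoints.jsonl and the record layout below are assumptions made for the sketch, not a documented MinT format; check the full source for the real logic.

```python
import json
from pathlib import Path
from typing import Optional


def read_last_state_checkpoint(log_path: Path) -> Optional[str]:
    """Return the most recent state_path recorded under log_path, if any.

    ASSUMPTION: checkpoints are logged as JSON lines in
    log_path / "checkpoints.jsonl", each record carrying a "state_path" key.
    """
    index_file = log_path / "checkpoints.jsonl"  # hypothetical file name
    if not index_file.exists():
        return None

    last_state_path: Optional[str] = None
    for line in index_file.read_text().splitlines():
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        # Later records overwrite earlier ones, so the last match wins.
        if "state_path" in record:
            last_state_path = record["state_path"]
    return last_state_path
```

Returning None rather than raising lets the caller decide whether a missing checkpoint should abort the run before Stage 3.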
Stage 2: PRM / Preference Model
The preference model uses the same Bradley-Terry custom-loss pattern as the DPO recipe:
prm_client = service_client.create_lora_training_client(
    base_model=MODEL,
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)
data = build_preference_data(prm_client.get_tokenizer())
result = prm_client.forward_backward_custom(
    data,
    pairwise_preference_loss,
).result()
prm_client.optim_step(types.AdamParams(learning_rate=1e-5)).result()
prm_sampler = prm_client.save_weights_and_get_sampling_client(name="rlhf-prm")

If this stage fails, the script prints a warning and continues to Stage 3 with a rule-based reward.
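
For reference, the Bradley-Terry objective behind a pairwise preference loss can be written in a few lines. This is a generic sketch, not the recipe's pairwise_preference_loss: the function name, the plain-float inputs, and the return shape are assumptions, and how the recipe turns each chosen or rejected response into a scalar score is not shown here. Per pair, the loss is -log(sigmoid(score_chosen - score_rejected)), and pair accuracy is the fraction of pairs where the chosen response scores higher.

```python
import math
from typing import Sequence, Tuple


def bradley_terry_loss(
    chosen_scores: Sequence[float],
    rejected_scores: Sequence[float],
) -> Tuple[float, float]:
    """Return (mean loss, pair accuracy) for paired scalar scores.

    Hypothetical helper: the recipe's own loss works on training-client
    outputs, whose exact shape is not assumed here.
    """
    if len(chosen_scores) != len(rejected_scores) or not chosen_scores:
        raise ValueError("need equally sized, non-empty score lists")

    total_loss = 0.0
    correct = 0
    for chosen, rejected in zip(chosen_scores, rejected_scores):
        margin = chosen - rejected
        # -log(sigmoid(margin)) == softplus(-margin), computed stably.
        total_loss += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
        correct += margin > 0
    n = len(chosen_scores)
    return total_loss / n, correct / n
```

When the two scores are equal, the per-pair loss is log 2 and the pair does not count as correct, so an untrained model tends to sit near chance-level accuracy.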
Stage 3: RL with PRM Reward
The PRM sampler is passed through the RL dataset builder chain:
RLHFDatasetBuilder(prm_sampler)
↓
RLHFDataset(prm_sampler)
↓
RLHFEnvGroupBuilder(prm_sampler)
↓
RLHFMessageEnv(prm_sampler)

The MessageEnv.step() method scores the model response:
def rule_based_reward(response: str) -> float:
    return 0.5 if len(response.split()) >= 6 and any(
        word in response.lower()
        for word in ["specific", "automatic", "risk", "restore", "actionable"]
    ) else -0.1


class RLHFMessageEnv(MessageEnv):
    async def step(self, message):
        response = _extract_content(message).strip()
        if self.prm_sampler is not None:
            reward = _score_with_prm(self.prm_sampler, self.prompt, response)
        else:
            reward = rule_based_reward(response)
        return MessageStepResult(
            reward=reward,
            episode_done=True,
            next_messages=[],
            metrics={"reward_source_prm": float(self.prm_sampler is not None)},
        )

The RL config uses importance_sampling:
config = recipe.rl.train.Config(
    learning_rate=1e-5,
    dataset_builder=RLHFDatasetBuilder(..., prm_sampler=prm_sampler),
    model_name=MODEL,
    load_checkpoint_path=sft_checkpoint_path,
    renderer_name="qwen3",
    lora_rank=16,
    max_tokens=32,
    temperature=0.7,
    kl_penalty_coef=0.0,
    loss_fn="importance_sampling",
    max_steps=1,
    save_every=999,
    eval_every=999,
)
await recipe.rl.train.main(config=config)

View full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/rlhf_pipeline.py
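
In its usual form, an importance-sampling policy-gradient loss reweights each sampled token's advantage by the ratio of the current policy's probability to the sampler's probability. The sketch below shows that generic formulation on plain Python floats. It is not taken from MinT's implementation of loss_fn="importance_sampling"; the function name and argument names are hypothetical.

```python
import math
from typing import Sequence


def importance_sampling_loss(
    target_logprobs: Sequence[float],
    sampling_logprobs: Sequence[float],
    advantages: Sequence[float],
) -> float:
    """Return -sum(ratio * advantage) over tokens.

    Generic sketch: ratio = p_current(token) / p_sampler(token), computed
    from log-probabilities. Names and inputs are illustrative assumptions.
    """
    if not (len(target_logprobs) == len(sampling_logprobs) == len(advantages)):
        raise ValueError("all inputs must have the same length")

    loss = 0.0
    for lp_new, lp_old, adv in zip(target_logprobs, sampling_logprobs, advantages):
        ratio = math.exp(lp_new - lp_old)  # p_current / p_sampler
        loss -= ratio * adv
    return loss
```

When the policy being trained is identical to the sampler, every ratio is 1 and the loss reduces to the negated sum of advantages; the ratio only matters once the two have drifted apart.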
Verified Run
Verified on MinT with Qwen/Qwen3-0.6B, one step per stage:
| Stage | Result |
|---|---|
| SFT | Completed, train_mean_nll=5.687325 |
| PRM | Completed, loss=8.358253, pair_accuracy=0.50 |
| RL | Completed, resumed from SFT state_path, env/all/reward/total=-0.100, checkpoint saved |
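These numbers come from a minimal smoke run. They verify the shape of the pipeline and the API calls, not final model quality; for example, pair_accuracy=0.50 is chance level for a pairwise comparison.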
This is a small recipe for validating the RLHF wiring. For a production RLHF run, use more SFT data, many preference pairs, a stronger reward model, longer RL training, and separate validation prompts.
Shape
chat examples
     │
     ▼
Stage 1 SFT ──▶ SFT state_path ─────┐
                                    │ load_checkpoint_path
preference pairs                    │
     │                              │
     ▼                              │
Stage 2 PRM ──▶ PRM sampler         │
                     │              │
                     ▼              ▼
        Stage 3 MessageEnv RL  resumes SFT weights
                     │
                     ▼
          final policy checkpoint