RLHF Pipeline
This recipe shows a minimal RLHF flow that actually runs on MinT:
- Stage 1 (SFT): train on helpful chat examples with recipe.supervised.train.main().
- Stage 2 (PRM / preference model): train on chosen/rejected pairs with forward_backward_custom() and a Bradley-Terry loss.
- Stage 3 (RL): restore from the Stage 1 SFT checkpoint via load_checkpoint_path, then train with recipe.rl.train.main() and MessageEnv. The environment scores each response with the PRM sampler when one is available; if Stage 2 fails, it falls back to a rule-based reward.

This page tracks recipes/rlhf_pipeline.py.
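
The stage ordering and the Stage 2 fallback described above can be sketched as plain control flow. This is a minimal illustration, not the recipe's code: `run_pipeline` and the three `fake_*` stages are hypothetical stand-ins, so the example runs without MinT.

```python
import warnings
from typing import Callable, Optional


def run_pipeline(
    run_sft: Callable[[], str],
    run_prm: Callable[[], object],
    run_rl: Callable[[str, Optional[object]], str],
) -> str:
    """Run the three stages in order; a Stage 2 failure is non-fatal."""
    sft_checkpoint_path = run_sft()  # Stage 1 has to succeed
    try:
        prm_sampler = run_prm()  # Stage 2 is allowed to fail
    except Exception as exc:
        warnings.warn(f"PRM stage failed ({exc}); using rule-based reward")
        prm_sampler = None  # None tells Stage 3 to use the fallback reward
    return run_rl(sft_checkpoint_path, prm_sampler)


# Hypothetical stand-in stages, used only to exercise the control flow.
def fake_sft() -> str:
    return "sft-state"


def fake_prm_fails() -> object:
    raise RuntimeError("no preference data")


def fake_rl(checkpoint: str, prm_sampler: Optional[object]) -> str:
    source = "prm" if prm_sampler is not None else "rule_based"
    return f"rl resumed from {checkpoint} with {source} reward"


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # keep the demo output quiet
    summary = run_pipeline(fake_sft, fake_prm_fails, fake_rl)
print(summary)
```

Passing `None` for the sampler keeps the fallback decision in one place: Stage 3 only has to check whether a sampler exists.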

Why this is not a fake RLHF demo
This recipe does not hand-write synthetic PPO datums. It uses the real recipe stack:
SFT: recipe.supervised.train.main()
PRM: TrainingClient.forward_backward_custom()
RL: recipe.rl.train.main(loss_fn="importance_sampling")

Stage 1: SFT
log_path = Path("/tmp/mint-rlhf-sft")
config = recipe.supervised.train.Config(
    log_path=str(log_path),
    model_name=MODEL,
    renderer_name="qwen3",
    dataset_builder=ListSFTDatasetBuilder(
        conversations=SFT_CONVERSATIONS,
        model_name=MODEL,
        renderer_name="qwen3",
        batch_size=2,
    ),
    learning_rate=1e-5,
    lora_rank=16,
    max_steps=1,     # smoke run: a single training step
    save_every=999,  # interval far beyond max_steps
    eval_every=999,  # interval far beyond max_steps
)
await recipe.supervised.train.main(config=config)
# Path of the last saved training state; Stage 3 resumes from it.
sft_checkpoint_path = _read_last_state_checkpoint(log_path)

The SFT dataset uses conversation_to_datum() internally. The final SFT state_path is passed to Stage 3 RL.
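
The `_read_last_state_checkpoint` helper is not shown on this page. The sketch below is one hypothetical way to recover the last state path from a log directory. It assumes a JSON Lines index file named `checkpoints.jsonl` whose records carry a `state_path` field. Both names are assumptions made for illustration, as are the function name `read_last_state_path` and the example paths; the recipe's actual file layout may differ.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def read_last_state_path(
    log_path: Path, filename: str = "checkpoints.jsonl"
) -> Optional[str]:
    """Return the state_path of the last record that has one, else None."""
    index_file = log_path / filename  # assumed file name, see lead-in
    if not index_file.exists():
        return None
    last = None
    for line in index_file.read_text().splitlines():
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        if "state_path" in record:  # assumed field name
            last = record["state_path"]
    return last


# Usage with a throwaway directory and made-up records.
with tempfile.TemporaryDirectory() as tmp:
    log_dir = Path(tmp)
    (log_dir / "checkpoints.jsonl").write_text(
        json.dumps({"name": "000000", "state_path": "example://state/000000"}) + "\n"
        + json.dumps({"name": "final", "state_path": "example://state/final"}) + "\n"
    )
    found = read_last_state_path(log_dir)
    missing = read_last_state_path(log_dir / "does-not-exist")
print(found)
```

Returning `None` instead of raising lets the caller decide whether a missing checkpoint is fatal; for this pipeline it would be, since Stage 3 depends on the SFT weights.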

Stage 2: PRM / Preference Model
The preference model uses the same Bradley-Terry custom-loss pattern as the DPO recipe:
# LoRA training client for the preference model.
prm_client = service_client.create_lora_training_client(
    base_model=MODEL,
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)
data = build_preference_data(prm_client.get_tokenizer())
# One forward/backward pass with the custom pairwise (Bradley-Terry) loss.
result = prm_client.forward_backward_custom(
    data,
    pairwise_preference_loss,
).result()
prm_client.optim_step(types.AdamParams(learning_rate=1e-5)).result()
# Export the trained weights as a sampler that Stage 3 uses for scoring.
prm_sampler = prm_client.save_weights_and_get_sampling_client(name="rlhf-prm")

If this stage fails, the script prints a warning and continues to Stage 3 with the rule-based reward.
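
For readers unfamiliar with the Bradley-Terry objective, here is a standalone sketch of the pairwise loss on plain floats. It is not the recipe's `pairwise_preference_loss`, whose body is not shown on this page. The name `bradley_terry_loss` is hypothetical, and the scores are assumed to be one scalar per response (for example a summed sequence log-probability).

```python
import math
from typing import Sequence, Tuple


def bradley_terry_loss(
    chosen_scores: Sequence[float], rejected_scores: Sequence[float]
) -> Tuple[float, float]:
    """Return (mean of -log(sigmoid(chosen - rejected)), pair accuracy)."""
    if len(chosen_scores) != len(rejected_scores) or not chosen_scores:
        raise ValueError("need equal-length, non-empty score lists")
    total = 0.0
    correct = 0
    for chosen, rejected in zip(chosen_scores, rejected_scores):
        margin = chosen - rejected
        # -log(sigmoid(m)) equals softplus(-m); this form avoids overflow in exp().
        total += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
        correct += margin > 0  # pair counts as correct when chosen outscores rejected
    n = len(chosen_scores)
    return total / n, correct / n


# Made-up scores: the first pair is ranked correctly, the second is not.
loss, pair_accuracy = bradley_terry_loss([-10.0, -30.0], [-12.0, -25.0])
print(loss, pair_accuracy)
```

When both responses in a pair get the same score, the per-pair loss is log 2 (about 0.693), which is a useful baseline when reading the loss value of a short run.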

Stage 3: RL with PRM Reward
The PRM sampler is passed all the way down the RL dataset builder chain:
RLHFDatasetBuilder(prm_sampler)
  ↓
RLHFDataset(prm_sampler)
  ↓
RLHFEnvGroupBuilder(prm_sampler)
  ↓
RLHFMessageEnv(prm_sampler)

MessageEnv.step() scores the model response:
def rule_based_reward(response: str) -> float:
    # 0.5 for responses of at least six words that contain a keyword, else -0.1.
    return 0.5 if len(response.split()) >= 6 and any(
        word in response.lower()
        for word in ["specific", "automatic", "risk", "restore", "actionable"]
    ) else -0.1

class RLHFMessageEnv(MessageEnv):
    async def step(self, message):
        response = _extract_content(message).strip()
        # Prefer the PRM score; use the rule-based reward when no sampler exists.
        if self.prm_sampler is not None:
            reward = _score_with_prm(self.prm_sampler, self.prompt, response)
        else:
            reward = rule_based_reward(response)
        return MessageStepResult(
            reward=reward,
            episode_done=True,  # single-turn episode
            next_messages=[],
            metrics={"reward_source_prm": float(self.prm_sampler is not None)},
        )

The RL config uses importance_sampling:
config = recipe.rl.train.Config(
    learning_rate=1e-5,
    dataset_builder=RLHFDatasetBuilder(..., prm_sampler=prm_sampler),
    model_name=MODEL,
    load_checkpoint_path=sft_checkpoint_path,  # resume from the Stage 1 SFT state
    renderer_name="qwen3",
    lora_rank=16,
    max_tokens=32,
    temperature=0.7,
    kl_penalty_coef=0.0,
    loss_fn="importance_sampling",
    max_steps=1,
    save_every=999,
    eval_every=999,
)
await recipe.rl.train.main(config=config)

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/rlhf_pipeline.py
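
This page does not define what `loss_fn="importance_sampling"` computes. As a rough guide, the sketch below assumes the common token-level form, in which each token's advantage is weighted by the ratio between the current policy's probability and the probability under the sampler that generated the token. The function name, argument names, and numbers are illustrative assumptions, and MinT's exact definition may differ.

```python
import math
from typing import Sequence


def importance_sampling_loss(
    target_logprobs: Sequence[float],
    sampling_logprobs: Sequence[float],
    advantages: Sequence[float],
) -> float:
    """Return -(sum over tokens of ratio * advantage), ratio = p_current / p_sampler."""
    if not (len(target_logprobs) == len(sampling_logprobs) == len(advantages)):
        raise ValueError("all inputs must have the same length")
    loss = 0.0
    for new_lp, old_lp, adv in zip(target_logprobs, sampling_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # 1.0 while the policy has not moved
        loss -= ratio * adv
    return loss


# On-policy case: identical log-probs, so every ratio is exactly 1.
on_policy = importance_sampling_loss([-1.0, -2.0], [-1.0, -2.0], [0.5, 0.5])
# Slightly off-policy: the current policy assigns more mass to the sampled token.
off_policy = importance_sampling_loss([-0.9], [-1.0], [1.0])
print(on_policy, off_policy)
```

The ratio corrects for the gap between the policy that sampled the response and the policy being updated; when the two agree, the loss reduces to the plain policy-gradient surrogate.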
Verified Run
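Verified on MinT with Qwen/Qwen3-0.6B, one step per stage:

| Stage | Result |
|---|---|
| SFT | Completed, train_mean_nll=5.687325 |
| PRM | Completed, loss=8.358253, pair_accuracy=0.50 |
| RL | Completed, restored from the SFT state_path, env/all/reward/total=-0.100, checkpoint saved |

The smoke log behind these numbers comes from a single minimal run. It validates shapes and API wiring, not final model quality.
This is a small recipe for verifying RLHF wiring. Production RLHF needs more SFT data, more preference pairs, a stronger reward model, longer RL training, and separate validation prompts.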
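
The rule-based fallback reward from Stage 3 is simple enough to check offline. The sketch below restates `rule_based_reward` from this page with named intermediate values (same behavior) and scores three made-up responses. A response that fails either check scores -0.1, the same number as env/all/reward/total in the RL row, although the table alone does not show which reward source produced that value.

```python
def rule_based_reward(response: str) -> float:
    """0.5 for a response of six or more words that contains a keyword, else -0.1."""
    keywords = ["specific", "automatic", "risk", "restore", "actionable"]
    long_enough = len(response.split()) >= 6
    has_keyword = any(word in response.lower() for word in keywords)
    return 0.5 if long_enough and has_keyword else -0.1


# Made-up responses covering the three cases.
long_with_keyword = rule_based_reward(
    "Enable automatic backups and test a restore every month."
)
too_short = rule_based_reward("Use automatic backups.")
long_without_keyword = rule_based_reward(
    "It depends on many things in your particular setup."
)
print(long_with_keyword, too_short, long_without_keyword)
```

Keyword matching is substring-based, so "risky" counts as a match for "risk". The reward checks only length and vocabulary, which is why the page treats it as a fallback rather than a quality signal.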
Shape
chat examples
      │
      ▼
Stage 1 SFT ──▶ SFT state_path ─────┐
                                    │ load_checkpoint_path
preference pairs                    │
      │                             │
      ▼                             │
Stage 2 PRM ──▶ PRM sampler         │
                     │              │
                     ▼              ▼
          Stage 3 MessageEnv RL resumes SFT weights
                     │
                     ▼
          final policy checkpoint