Mind Lab Toolkit (MinT)

RLHF Pipeline

This recipe shows a minimal RLHF flow that actually runs on MinT:

  1. Stage 1 (SFT): train on helpful chat examples with recipe.supervised.train.main().
  2. Stage 2 (PRM / preference model): train on chosen/rejected pairs with forward_backward_custom() and a Bradley-Terry loss.
  3. Stage 3 (RL): resume from the Stage 1 SFT checkpoint via load_checkpoint_path, then train with recipe.rl.train.main() and a MessageEnv. The environment scores each response with the PRM sampler when one is available; if Stage 2 fails, it falls back to a rule-based reward.

This page tracks recipes/rlhf_pipeline.py.

Why this is not a fake RLHF demo

This recipe does not hand-write synthetic PPO datums. It uses the real recipe stack:

SFT: recipe.supervised.train.main()
PRM: TrainingClient.forward_backward_custom()
RL:  recipe.rl.train.main(loss_fn="importance_sampling")

Stage 1: SFT

log_path = Path("/tmp/mint-rlhf-sft")
config = recipe.supervised.train.Config(
    log_path=str(log_path),
    model_name=MODEL,
    renderer_name="qwen3",
    dataset_builder=ListSFTDatasetBuilder(
        conversations=SFT_CONVERSATIONS,
        model_name=MODEL,
        renderer_name="qwen3",
        batch_size=2,
    ),
    learning_rate=1e-5,
    lora_rank=16,
    max_steps=1,
    save_every=999,
    eval_every=999,
)

await recipe.supervised.train.main(config=config)
sft_checkpoint_path = _read_last_state_checkpoint(log_path)

The SFT dataset uses conversation_to_datum() internally. The final SFT state_path is passed to the Stage 3 RL run.
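The helper _read_last_state_checkpoint() is not shown on this page. The sketch below is one way such a helper could work, under a loudly stated assumption: that training appends one JSON record per checkpoint to a checkpoints.jsonl file inside log_path, and that a record may carry a state_path key. Both the file name and the key are hypothetical here; check the full source for the real behavior.

```python
import json
from pathlib import Path
from typing import Optional


def read_last_state_checkpoint(log_path: Path) -> Optional[str]:
    """Return the most recent state_path recorded under log_path, if any.

    Assumption (not confirmed by this page): checkpoints are indexed in
    <log_path>/checkpoints.jsonl, one JSON object per line, and records
    that hold resumable training state include a "state_path" key.
    """
    index_file = log_path / "checkpoints.jsonl"
    if not index_file.exists():
        return None

    last_state_path = None
    for line in index_file.read_text().splitlines():
        if not line.strip():
            continue  # tolerate blank lines in the index
        record = json.loads(line)
        if "state_path" in record:
            last_state_path = record["state_path"]  # later lines win
    return last_state_path
```

Returning None rather than raising lets the caller decide whether a missing checkpoint is fatal before Stage 3 starts.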

Stage 2: PRM / Preference Model

The preference model uses the same Bradley-Terry custom-loss pattern as the DPO recipe:

prm_client = service_client.create_lora_training_client(
    base_model=MODEL,
    rank=16,
    train_mlp=True,
    train_attn=True,
    train_unembed=True,
)

data = build_preference_data(prm_client.get_tokenizer())

result = prm_client.forward_backward_custom(
    data,
    pairwise_preference_loss,
).result()

prm_client.optim_step(types.AdamParams(learning_rate=1e-5)).result()
prm_sampler = prm_client.save_weights_and_get_sampling_client(name="rlhf-prm")
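The body of pairwise_preference_loss is not shown here. As a rough illustration of the Bradley-Terry objective it is described as using, the sketch below computes the mean of -log(sigmoid(chosen - rejected)) in plain Python. It assumes each response has already been reduced to one scalar score (for example a summed log-probability); how the recipe derives those scores is not covered, and bradley_terry_loss is a hypothetical name.

```python
import math


def bradley_terry_loss(chosen_scores, rejected_scores):
    """Return (mean Bradley-Terry loss, pairwise accuracy) for scored pairs."""
    if len(chosen_scores) != len(rejected_scores) or not chosen_scores:
        raise ValueError("need equal-length, non-empty score lists")

    losses = []
    correct = 0
    for chosen, rejected in zip(chosen_scores, rejected_scores):
        margin = chosen - rejected
        # -log(sigmoid(m)) == log(1 + exp(-m)), written in a stable form
        losses.append(max(-margin, 0.0) + math.log1p(math.exp(-abs(margin))))
        correct += margin > 0  # pair counts as correct if chosen outscores rejected
    n = len(losses)
    return sum(losses) / n, correct / n


loss, accuracy = bradley_terry_loss([2.0, -1.0], [0.0, 0.0])
print(f"loss={loss:.4f} pair_accuracy={accuracy:.2f}")
```

A margin of zero gives a per-pair loss of ln 2; larger positive margins push the loss toward zero, and negative margins are penalized roughly linearly.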

If this stage fails, the script prints a warning and continues to Stage 3 with the rule-based reward.
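The fallback can be sketched as a small wrapper that returns None when Stage 2 raises, so that later stages see prm_sampler=None and use the rule-based reward. train_prm_or_none and its callable argument are hypothetical names for illustration; the script's actual error handling may differ, and this sketch uses warnings.warn where the script is described as printing a warning.

```python
import warnings
from typing import Callable, Optional


def train_prm_or_none(train_prm: Callable[[], object]) -> Optional[object]:
    """Run Stage 2 and return its sampler, or None if the stage fails."""
    try:
        return train_prm()
    except Exception as exc:  # any Stage 2 failure triggers the fallback
        warnings.warn(f"Stage 2 PRM training failed, using rule-based reward: {exc}")
        return None
```

Passing None downstream keeps Stage 3 code on a single path: the environment only has to check whether a sampler is present.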

Stage 3: RL with PRM Reward

The PRM sampler is passed down the whole RL dataset builder chain:

RLHFDatasetBuilder(prm_sampler)
  -> RLHFDataset(prm_sampler)
  -> RLHFEnvGroupBuilder(prm_sampler)
  -> RLHFMessageEnv(prm_sampler)

MessageEnv.step() scores the model response:

def rule_based_reward(response: str) -> float:
    return 0.5 if len(response.split()) >= 6 and any(
        word in response.lower()
        for word in ["specific", "automatic", "risk", "restore", "actionable"]
    ) else -0.1

class RLHFMessageEnv(MessageEnv):
    async def step(self, message):
        response = _extract_content(message).strip()
        if self.prm_sampler is not None:
            reward = _score_with_prm(self.prm_sampler, self.prompt, response)
        else:
            reward = rule_based_reward(response)

        return MessageStepResult(
            reward=reward,
            episode_done=True,
            next_messages=[],
            metrics={"reward_source_prm": float(self.prm_sampler is not None)},
        )
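To make the fallback reward concrete, the snippet below repeats rule_based_reward from the excerpt above and runs it on a few made-up responses. The sample strings are illustrative only. Note that the keyword check is a substring match, so "risky" satisfies the "risk" keyword.

```python
def rule_based_reward(response: str) -> float:
    # Same logic as the recipe excerpt: at least six whitespace-separated
    # words AND at least one keyword appearing as a substring.
    return 0.5 if len(response.split()) >= 6 and any(
        word in response.lower()
        for word in ["specific", "automatic", "risk", "restore", "actionable"]
    ) else -0.1


samples = [
    "Enable automatic backups so you can restore data quickly.",  # long + keyword
    "Sure.",                                                       # too short, no keyword
    "This plan is risky but it might still work.",                 # "risk" inside "risky"
    "automatic backups",                                           # keyword but too short
]
for text in samples:
    print(rule_based_reward(text), "|", text)
```

Only two reward values are possible, 0.5 and -0.1, so this signal is coarse and meant for wiring checks rather than for shaping a real policy.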

The RL config uses importance_sampling:

config = recipe.rl.train.Config(
    learning_rate=1e-5,
    dataset_builder=RLHFDatasetBuilder(..., prm_sampler=prm_sampler),
    model_name=MODEL,
    load_checkpoint_path=sft_checkpoint_path,
    renderer_name="qwen3",
    lora_rank=16,
    max_tokens=32,
    temperature=0.7,
    kl_penalty_coef=0.0,
    loss_fn="importance_sampling",
    max_steps=1,
    save_every=999,
    eval_every=999,
)

await recipe.rl.train.main(config=config)
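For intuition about loss_fn="importance_sampling", the sketch below shows one common form of an importance-weighted policy-gradient objective: each token's advantage is weighted by the ratio between the probability under the policy being trained and the probability under the policy that sampled it. Whether MinT's built-in loss matches this form exactly is an assumption; the function and argument names are hypothetical.

```python
import math


def importance_sampling_loss(target_logprobs, sampling_logprobs, advantages):
    """Return -sum(ratio * advantage) over tokens.

    ratio = exp(target_logprob - sampling_logprob), i.e. p_target / p_sampler.
    """
    total = 0.0
    for new_lp, old_lp, adv in zip(target_logprobs, sampling_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)
        total += ratio * adv
    return -total  # minimizing this raises the probability of high-advantage tokens


# On-policy case: identical log-probs give ratio 1 for every token.
print(importance_sampling_loss([-1.0, -2.0], [-1.0, -2.0], [1.0, -0.5]))
```

With identical log-probabilities the ratio is 1 and the loss reduces to the negated sum of advantages; the ratio only matters once the trained policy drifts from the sampler.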

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/rlhf_pipeline.py

Verified Run

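Verified on MinT with Qwen/Qwen3-0.6B, one step per stage:

Stage   Result
SFT     Completed, train_mean_nll=5.687325
PRM     Completed, loss=8.358253, pair_accuracy=0.50
RL      Completed, resumed from the SFT state_path, env/all/reward/total=-0.100, checkpoint saved

The smoke log comes from a minimal run. It verifies the pipeline shape and the API calls, not final model quality.

This is a small recipe for verifying RLHF wiring. Production-grade RLHF needs more SFT data, more preference pairs, a stronger reward model, longer RL training, and separate validation prompts.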

Shape

chat examples
  │
  ▼
Stage 1 SFT ──▶ SFT state_path ─────┐
                                     │ load_checkpoint_path
preference pairs                     │
  │                                  │
  ▼                                  │
Stage 2 PRM ──▶ PRM sampler          │
                  │                  │
                  ▼                  ▼
          Stage 3 MessageEnv RL resumes SFT weights
                  │
                  ▼
             final policy checkpoint
