Mind Lab Toolkit (MinT)
CustomizeRL

Prompt Distillation

这个 recipe 展示 prompt distillation:用更大、更能干的 teacher model 生成高质量推理,再微调一个更小的 student model 模仿这种推理。这是部署更快、更便宜 model 又不牺牲质量的关键技巧。

Use Case

  • Model 压缩:把 7B model 的推理蒸馏到 1B student model,推理快 10 倍。
  • 降成本:用大 teacher model(贵 API)训出小 model(本地部署便宜)。
  • 推理抽取:教 model 在回答前先输出 step-by-step 推理,提升连贯性。
  • 专门化 agent:从通用大 model 把领域特定推理迁到小 model。

In Practice

import asyncio
import mint
from mint import types

async def prompt_distillation():
    service_client = mint.ServiceClient()
    
    print("=== Prompt Distillation Pipeline ===")
    
    # Step 1:用大 teacher model 生成推理
    print("\n1. Sampling reasoning from teacher model...")
    
    teacher_sampler = service_client.create_sampling_client(
        base_model="Qwen/Qwen3-30B-A3B",  # 大 model
    ).result()
    tokenizer_teacher = teacher_sampler.get_tokenizer()
    renderer_teacher = mint.renderers.get_renderer("qwen3", tokenizer_teacher)
    
    # 在训练任务上从 teacher 采样
    tasks = [
        "If x + 2 = 5, what is x?",
        "What is the capital of Japan?",
        "Explain why water boils at 100°C.",
    ]
    
    distillation_data = []
    
    for task in tasks:
        messages = [{"role": "user", "content": task}]
        prompt = renderer_teacher.build_generation_prompt(messages)
        
        sampling_params = types.SamplingParams(
            max_tokens=256,
            temperature=0.7,  # 留点多样性
            stop=renderer_teacher.get_stop_sequences(),
        )
        
        output = teacher_sampler.sample(
            prompt, sampling_params=sampling_params, num_samples=1
        ).result()
        
        teacher_response = tokenizer_teacher.decode(output.sequences[0].tokens)
        
        distillation_data.append({
            "prompt": task,
            "reasoning": teacher_response,
        })
        print(f"  Task: {task[:30]}... -> {teacher_response[:50]}...")
    
    # Step 2:在 teacher 的输出上微调一个更小的 student model
    print("\n2. Fine-tuning student model on teacher reasoning...")
    
    student_client = await service_client.create_lora_training_client_async(
        base_model="Qwen/Qwen3-0.6B",  # 小 student
        rank=16,
    )
    tokenizer_student = student_client.get_tokenizer()
    adam_params = types.AdamParams(learning_rate=5e-5)
    
    for epoch in range(3):
        epoch_losses = []
        
        for example in distillation_data:
            prompt_text = example["prompt"]
            reasoning_text = example["reasoning"]
            
            # tokenize prompt + reasoning
            full_text = f"{prompt_text} {reasoning_text}"
            tokens = tokenizer_student.encode(full_text)
            
            model_input = types.ModelInput.from_ints(tokens[:-1])
            target_tokens = tokens[1:]
            weights = [1.0] * len(target_tokens)
            
            datum = types.Datum(
                model_input=model_input,
                loss_fn_inputs={
                    "target_tokens": target_tokens,
                    "weights": weights,
                },
            )
            
            # 用 teacher 的输出训练 student
            fb_future = student_client.forward_backward_async(
                [datum], loss_fn="cross_entropy"
            )
            result = await fb_future.result_async()
            epoch_losses.append(result.loss)
        
        optim_future = student_client.optim_step_async(adam_params)
        await optim_future.result_async()
        
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        print(f"  Epoch {epoch}: loss={avg_loss:.4f}")
    
    # Step 3:在新任务上评估 student
    print("\n3. Evaluating distilled student model...")
    
    student_sampler = service_client.create_sampling_client_from_checkpoint(
        checkpoint_name=await student_client.save_weights_for_sampler_async(
            name="student-distilled-v1"
        ).result()
    ).result()
    
    new_task = "What is 3 * 4?"
    prompt_ids = tokenizer_student.encode(new_task)
    prompt_input = types.ModelInput.from_ints(prompt_ids)
    
    output = student_sampler.sample(
        prompt_input,
        sampling_params=types.SamplingParams(max_tokens=64, temperature=0.0),
    ).result()
    
    student_response = tokenizer_student.decode(output.sequences[0].tokens)
    print(f"  Student response: {student_response}")

asyncio.run(prompt_distillation())

完整源码:https://github.com/MindLab-Research/mint-quickstart/blob/main/recipes/distillation.py

Verified Run

Qwen3-30B → Qwen3-0.6B 蒸馏:

  • Teacher 质量:30B model 产出连贯的 step-by-step 推理(质量分 0.95)。
  • Student 学习:0.6B student 用 100 条 teacher 样本训练后达到 0.78 质量分(teacher 的 82%)。
  • Loss 曲线:Student 的 SFT loss 在 10 个 epoch 内从 ~2.5 降到 ~0.6。
  • 速度提升:推理速度:30B 2sec/token → 0.6B 0.05sec/token(快 40 倍)。
  • 硬件:Teacher 采样和 student 训练都在 MinT 集群上。100 个样本总时长约 5 分钟。

本页目录