Checkpoints & Weights
MinT stores LoRA checkpoints server-side. This page covers the full checkpoint lifecycle: saving weights for inference, resuming training from saved state, managing checkpoint metadata, and downloading weights for local deployment or merging.
Concept
Training produces two kinds of checkpoints:
- Inference checkpoint (save_weights_for_sampler): LoRA weights optimized for sampling. Use it to create a SamplingClient for inference or evaluation.
- Training state (save_state): the full training state, including gradients, optimizer moments, and loss history. Use it to resume training from a checkpoint without losing momentum.
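For orientation, the two saves look like this in code. A minimal sketch, assuming a training_client like the one created in the Pattern below; the checkpoint names are illustrative:

# Inference checkpoint: returns a SamplingClient backed by the saved LoRA weights
sampling_client = training_client.save_weights_for_sampler(name="demo-inference").result()

# Training state: also keeps optimizer moments, so a later run can resume without losing momentum
training_client.save_state(name="demo-state").result()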
Checkpoints are identified by name on the server; they can be listed, given a TTL (time-to-live), published to the HuggingFace Hub, or downloaded locally. The flow:
Training loop:
forward_backward() -> optim_step() -> save_weights_for_sampler()
                                   -> save_state() (for resuming training)
                                   -> get_checkpoint_metadata()
                                   -> publish_checkpoint()

Pattern
import mint
from mint import types

service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-0.6B",
    rank=16,
)
tokenizer = training_client.get_tokenizer()

# Run a few steps
for step in range(10):
    # Simplified batch construction
    text = "Example training text for step {}".format(step)
    tokens = tokenizer.encode(text)
    model_input = types.ModelInput.from_ints(tokens[:-1])
    target_tokens = tokens[1:]
    weights = [1.0] * len(target_tokens)
    datum = types.Datum(
        model_input=model_input,
        loss_fn_inputs={"target_tokens": target_tokens, "weights": weights},
    )
    result = training_client.forward_backward([datum], loss_fn="cross_entropy").result()
    adam_params = types.AdamParams(learning_rate=5e-5)
    training_client.optim_step(adam_params).result()

    if step % 5 == 0:
        # Save weights for inference
        sampling_client = training_client.save_weights_for_sampler(
            name=f"checkpoint-step-{step}"
        ).result()
        print(f"Saved checkpoint at step {step}")
        # Save the full state for resuming training later
        training_client.save_state(name=f"state-step-{step}").result()

# Later: resume training from a saved state
checkpoint_state = "state-step-5"
resumed_client = service_client.create_lora_training_client_from_state(
    checkpoint_state=checkpoint_state
).result()

# Or create a sampling client from a saved checkpoint
sampling_client = service_client.create_sampling_client_from_checkpoint(
    checkpoint_name="checkpoint-step-5"
).result()

Full source: https://github.com/MindLab-Research/mint-quickstart/blob/main/advanced/checkpoint.py
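The resumed client can be driven with the same calls as the original one. A minimal continuation sketch, reusing the tokenizer and types from the Pattern above; the extra steps and the final checkpoint name are illustrative, not part of the quickstart source:

# Continue training with the restored optimizer state (no lost momentum)
for step in range(10, 15):
    text = "Example training text for step {}".format(step)
    tokens = tokenizer.encode(text)
    datum = types.Datum(
        model_input=types.ModelInput.from_ints(tokens[:-1]),
        loss_fn_inputs={"target_tokens": tokens[1:], "weights": [1.0] * (len(tokens) - 1)},
    )
    resumed_client.forward_backward([datum], loss_fn="cross_entropy").result()
    resumed_client.optim_step(types.AdamParams(learning_rate=5e-5)).result()

# Save a fresh inference checkpoint from the resumed run
resumed_client.save_weights_for_sampler(name="checkpoint-resumed-15").result()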
API Surface
| Method | Purpose | Returns |
|---|---|---|
| save_weights_for_sampler(name) | Save LoRA weights for inference | SamplingClient (ready to use) |
| save_state(name) | Save the full training state for resuming | None |
| create_lora_training_client_from_state(checkpoint_state) | Resume training from a saved state | TrainingClient |
| create_sampling_client_from_checkpoint(checkpoint_name) | Load a checkpoint for inference | SamplingClient |
| get_checkpoint_metadata(name) | Query checkpoint size, creation time, TTL | CheckpointMetadata |
| set_checkpoint_ttl(name, ttl_hours) | Set an expiration time on a checkpoint | None |
| publish_checkpoint(name, hub_id) | Publish a checkpoint to the HuggingFace Hub | None |
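The metadata, TTL, and publishing methods are not exercised in the Pattern above. A hedged sketch of how they might be called: the method names and parameters follow the table, but the assumptions that they hang off the service client and return plain values rather than futures, as well as the hub_id value, are illustrative guesses:

# Inspect a saved checkpoint: size, creation time, TTL (assumed to live on the service client)
metadata = service_client.get_checkpoint_metadata(name="checkpoint-step-5")
print(metadata)

# Let the checkpoint expire automatically after a week to reclaim server-side storage
service_client.set_checkpoint_ttl(name="checkpoint-step-5", ttl_hours=168)

# Publish to the HuggingFace Hub (needs a valid Hub token in the environment; hub_id is illustrative)
service_client.publish_checkpoint(name="checkpoint-step-5", hub_id="your-org/your-model")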
Caveats & Pitfalls
- Sampler out of sync after reloading weights: after you save new weights, an existing SamplingClient keeps sampling with the old ones. Always create a fresh sampling client from the new checkpoint (see the sketch after this list).
- State vs weights: save_state() is larger (it includes gradients and optimizer moments) but preserves training momentum; save_weights_for_sampler() is lightweight but drops the optimizer history. Use state for long training runs and weights for final deployment.
- Checkpoint naming: checkpoint names are user-defined strings. Use descriptive names such as "math-v1-step-100" to avoid confusion. Names must be unique within each checkpoint type (inference vs state).
- TTL defaults: saved checkpoints are kept indefinitely by default. Set a TTL so old checkpoints expire automatically and server-side storage is reclaimed.
- Hub publishing: publish_checkpoint() requires a valid HuggingFace Hub token in the environment. See Deployment: Publish to Hub.
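A minimal sketch of the first caveat (rebuilding the sampler after new weights are saved), using only calls from the API Surface table; the checkpoint names are illustrative:

# A client built from an earlier checkpoint keeps serving those weights
old_client = service_client.create_sampling_client_from_checkpoint(
    checkpoint_name="math-v1-step-100"
).result()

# After more training and a newer save, old_client is stale
training_client.save_weights_for_sampler(name="math-v1-step-200").result()

# Rebuild the sampling client from the newest checkpoint before evaluating
fresh_client = service_client.create_sampling_client_from_checkpoint(
    checkpoint_name="math-v1-step-200"
).result()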