Mind Lab Toolkit (MinT)
API 参考

TrainingClient

模型训练操作的核心接口

Forward/Backward 操作

仅 Forward

forward(data, loss_fn, loss_fn_config=None)
forward_async(data, loss_fn, loss_fn_config=None)

计算 loss,不计算 gradient(评估模式)。

Forward-Backward

forward_backward(data, loss_fn, loss_fn_config=None)
forward_backward_async(data, loss_fn, loss_fn_config=None)

计算 loss 和 gradient,用于训练。

内置 Loss 函数

  • "cross_entropy"
  • "importance_sampling"
  • "ppo"
  • "cispo"
  • "dro"

自定义 Loss 函数

forward_backward_custom(data, custom_loss_fn)
forward_backward_custom_async(data, custom_loss_fn)

使用自定义 loss 函数,签名:

def custom_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict]

优化

optim_step(params)
optim_step_async(params)

使用 Adam optimizer 更新模型参数。

示例

optim_future = training_client.optim_step(
    types.AdamParams(learning_rate=1e-4)
)

状态管理

保存

save_state(name)
save_state_async(name)

保存 weight 和 optimizer 状态(用于恢复训练)。

save_weights_for_sampler(name)
save_weights_for_sampler_async(name)

保存为 inference 优化的 weight(更快,无 optimizer 状态)。

加载

load_state(path)
load_state_async(path)

仅加载 weight。

load_state_with_optimizer(path)
load_state_with_optimizer_async(path)

加载 weight 和 optimizer 状态。

创建 Sampling Client

create_sampling_client()

从当前 weight 创建 SamplingClient(不保存)。

save_weights_and_get_sampling_client(name)

保存 weight 并返回 SamplingClient,一步完成。

工具方法

get_info()

获取模型配置和元数据。

get_tokenizer()

获取模型的 tokenizer,用于编码/解码文本。

示例

training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-4B-Instruct-2507"
)

fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))

sampling_client = training_client.save_weights_and_get_sampling_client("my-model")

本页目录