API 参考
TrainingClient
模型训练操作的核心接口
Forward/Backward 操作
仅 Forward
forward(data, loss_fn, loss_fn_config=None)
forward_async(data, loss_fn, loss_fn_config=None)计算 loss,不计算 gradient(评估模式)。
Forward-Backward
forward_backward(data, loss_fn, loss_fn_config=None)
forward_backward_async(data, loss_fn, loss_fn_config=None)计算 loss 和 gradient,用于训练。
内置 Loss 函数:
"cross_entropy""importance_sampling""ppo""cispo""dro"
自定义 Loss 函数
forward_backward_custom(data, custom_loss_fn)
forward_backward_custom_async(data, custom_loss_fn)使用自定义 loss 函数,签名:
def custom_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict]优化
optim_step(params)
optim_step_async(params)使用 Adam optimizer 更新模型参数。
示例:
optim_future = training_client.optim_step(
types.AdamParams(learning_rate=1e-4)
)状态管理
保存
save_state(name)
save_state_async(name)保存 weight 和 optimizer 状态(用于恢复训练)。
save_weights_for_sampler(name)
save_weights_for_sampler_async(name)保存为 inference 优化的 weight(更快,无 optimizer 状态)。
加载
load_state(path)
load_state_async(path)仅加载 weight。
load_state_with_optimizer(path)
load_state_with_optimizer_async(path)加载 weight 和 optimizer 状态。
创建 Sampling Client
create_sampling_client()从当前 weight 创建 SamplingClient(不保存)。
save_weights_and_get_sampling_client(name)保存 weight 并返回 SamplingClient,一步完成。
工具方法
get_info()获取模型配置和元数据。
get_tokenizer()获取模型的 tokenizer,用于编码/解码文本。
示例
training_client = service_client.create_lora_training_client(
base_model="Qwen/Qwen3-4B-Instruct-2507"
)
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
sampling_client = training_client.save_weights_and_get_sampling_client("my-model")