TrainingClient
Core interface for model training operations
Forward/Backward Operations
Forward Pass Only
forward(data, loss_fn, loss_fn_config=None)
forward_async(data, loss_fn, loss_fn_config=None)
Compute loss without gradients (evaluation mode).
Forward-Backward Pass
forward_backward(data, loss_fn, loss_fn_config=None)
forward_backward_async(data, loss_fn, loss_fn_config=None)
Compute loss and gradients for training.
Built-in Loss Functions:
"cross_entropy"
"importance_sampling"
"ppo"
"cispo"
"dro"
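To make the difference between the forward-only and forward-backward passes concrete, the following is a minimal pure-Python sketch for a single token under a cross-entropy loss. It is illustrative only: `softmax`, `forward_only`, and `forward_and_backward` are hypothetical helper names, not part of TrainingClient, and the service's actual loss implementations may differ.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward_only(logits, target):
    """Loss only: negative log-probability of the target token."""
    return -math.log(softmax(logits)[target])


def forward_and_backward(logits, target):
    """Loss plus the gradient of the loss with respect to the logits."""
    probs = softmax(logits)
    loss = -math.log(probs[target])
    # For cross-entropy over a softmax, d(loss)/d(logit_i) = p_i - 1[i == target].
    grad = [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]
    return loss, grad


logits = [2.0, 0.5, -1.0]
loss_eval = forward_only(logits, target=0)
loss_train, grad = forward_and_backward(logits, target=0)
```

Both functions return the same loss; only the second also produces the gradient that an optimizer step would consume.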
Custom Loss Function
forward_backward_custom(data, custom_loss_fn)
forward_backward_custom_async(data, custom_loss_fn)
Use a custom loss function with signature:
def custom_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict]
Optimization
optim_step(params)
optim_step_async(params)
Update model parameters using the Adam optimizer.
Example:
optim_future = training_client.optim_step(
    types.AdamParams(learning_rate=1e-4)
)
State Management
Saving
save_state(name)
save_state_async(name)
Save weights and optimizer state (for resuming training).
save_weights_for_sampler(name)
save_weights_for_sampler_async(name)
Save weights optimized for inference (faster, no optimizer state).
Loading
load_state(path)
load_state_async(path)
Load weights only.
load_state_with_optimizer(path)
load_state_with_optimizer_async(path)
Load weights and optimizer state.
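Why the optimizer state matters when resuming can be shown with a small pure-Python sketch. It assumes a textbook Adam update with the commonly used defaults (beta1=0.9, beta2=0.999, eps=1e-8), which are assumptions here and not confirmed values for this service. The checkpoint dictionaries and the `adam_step` helper are hypothetical and do not reflect the real on-disk format.

```python
import math


def adam_step(w, g, state, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    """One textbook Adam update for a scalar weight; returns (new_w, new_state)."""
    state = dict(state)  # copy so the caller's state is left untouched
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * g
    state["v"] = b2 * state["v"] + (1 - b2) * g * g
    m_hat = state["m"] / (1 - b1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - b2 ** state["t"])  # bias-corrected second moment
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), state


fresh = {"t": 0, "m": 0.0, "v": 0.0}
w, opt = adam_step(1.0, 0.5, fresh)

# Hypothetical checkpoints: one with optimizer state, one with weights only.
full_ckpt = {"weights": w, "optimizer": opt}
weights_only_ckpt = {"weights": w}

next_grad = -0.2
w_continued, _ = adam_step(w, next_grad, opt)  # uninterrupted training
w_resumed, _ = adam_step(full_ckpt["weights"], next_grad, full_ckpt["optimizer"])
w_reset, _ = adam_step(weights_only_ckpt["weights"], next_grad, fresh)
```

In this sketch, resuming from the full checkpoint reproduces the uninterrupted step exactly. Resuming from weights alone restarts Adam's moment estimates, so the next update differs (here it even moves in the opposite direction).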
Sampling Client Creation
create_sampling_client()
Create a SamplingClient from current weights (without saving).
save_weights_and_get_sampling_client(name)
Save weights and return a SamplingClient in one operation.
Utilities
get_info()
Retrieve model configuration and metadata.
get_tokenizer()
Get the model’s tokenizer for encoding/decoding text.
Example Pattern
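# service_client, types, and training_data are assumed to be defined earlier.
# Create a LoRA training client for the chosen base model.
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-4B-Instruct-2507"
)
# Compute loss and gradients, then apply one Adam update.
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
# Save the updated weights and get a client for sampling from them.
sampling_client = training_client.save_weights_and_get_sampling_client("my-model")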
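The pattern above extends naturally to several batches. The sketch below shows only the order of calls, using `StubTrainingClient`, a hypothetical stand-in that records method names and is not the real client. That the returned objects expose `.result()` is an assumption inferred from the `_future` variable names, and the plain dict passed to `optim_step` stands in for `types.AdamParams`.

```python
from concurrent.futures import Future


class StubTrainingClient:
    """Hypothetical stand-in that only records the order of calls."""

    def __init__(self):
        self.calls = []

    def _done(self, name):
        # Record the call and return an already-completed future.
        self.calls.append(name)
        fut = Future()
        fut.set_result(name)
        return fut

    def forward_backward(self, data, loss_fn, loss_fn_config=None):
        return self._done("forward_backward")

    def optim_step(self, params):
        return self._done("optim_step")

    def save_state(self, name):
        return self._done(f"save_state:{name}")


client = StubTrainingClient()
batches = [["batch-0"], ["batch-1"]]  # placeholder training data

for batch in batches:
    fwdbwd_future = client.forward_backward(batch, "cross_entropy")
    optim_future = client.optim_step({"learning_rate": 1e-4})
    # Wait for both operations before moving on to the next batch.
    fwdbwd_future.result()
    optim_future.result()

client.save_state("checkpoint-final")
```

Each batch gets one forward-backward pass followed by one optimizer step, and the state is saved once at the end so training could be resumed later.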