API Reference

TrainingClient

Core interface for model training operations

Forward/Backward Operations

Forward Pass Only

forward(data, loss_fn, loss_fn_config=None)
forward_async(data, loss_fn, loss_fn_config=None)

Compute loss without gradients (evaluation mode).

Forward-Backward Pass

forward_backward(data, loss_fn, loss_fn_config=None)
forward_backward_async(data, loss_fn, loss_fn_config=None)

Compute loss and gradients for training.

Built-in Loss Functions:

  • "cross_entropy"
  • "importance_sampling"
  • "ppo"
  • "cispo"
  • "dro"
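For intuition about what these loss names refer to, the `"ppo"` entry corresponds to the standard PPO clipped-surrogate objective. A minimal sketch of that objective is below; it is conceptual only and not necessarily the exact form the service implements:

```python
import torch

# Conceptual sketch of the standard PPO clipped-surrogate loss -- for
# intuition only, not necessarily the service's exact implementation.
def ppo_loss(new_logprobs, old_logprobs, advantages, clip_eps=0.2):
    # Importance ratio between the current policy and the sampling policy.
    ratio = torch.exp(new_logprobs - old_logprobs)
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    # Maximize the clipped surrogate, i.e. minimize its negation.
    return -torch.min(ratio * advantages, clipped * advantages).mean()
```

When the ratio drifts outside `[1 - clip_eps, 1 + clip_eps]`, the clipped term caps the incentive to move further, which is what stabilizes policy updates.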

Custom Loss Function

forward_backward_custom(data, custom_loss_fn)
forward_backward_custom_async(data, custom_loss_fn)

Use a custom loss function with signature:

def custom_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict]
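As an illustration, a custom loss matching this signature might average per-token negative log-probability over the batch. Only the signature above comes from the reference; how a real loss would use the `Datum` fields is an assumption:

```python
import torch

# Illustrative custom loss matching the required signature. `data` is unused
# here as a placeholder -- a real loss would read fields from each Datum.
def custom_loss(data, logprobs):
    # Mean negative log-probability per sequence, averaged over the batch.
    per_seq_nll = torch.stack([-lp.mean() for lp in logprobs])
    loss = per_seq_nll.mean()
    metrics = {"nll": loss.item(), "num_sequences": len(logprobs)}
    return loss, metrics
```

Such a function would then be passed as the second argument to `forward_backward_custom(data, custom_loss)`.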

Optimization

optim_step(params)
optim_step_async(params)

Update model parameters using the Adam optimizer.

Example:

optim_future = training_client.optim_step(
    types.AdamParams(learning_rate=1e-4)
)
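For intuition, the update applied each step is standard Adam with bias correction. The sketch below shows one scalar update; it is conceptual only, since the service performs the actual update server-side:

```python
# One Adam update for a scalar parameter (standard Adam with bias correction).
# Conceptual sketch only -- the service applies this server-side.
def adam_step(param, grad, m, v, t, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad          # first-moment (mean) estimate
    v = b2 * v + (1 - b2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - b1 ** t)             # bias correction, step t >= 1
    v_hat = v / (1 - b2 ** t)
    param -= lr * m_hat / (v_hat ** 0.5 + eps)
    return param, m, v
```

The bias-correction terms matter early in training: at step 1 the raw moment estimates are scaled-down versions of the gradient, and dividing by `1 - b ** t` restores their magnitude.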

State Management

Saving

save_state(name)
save_state_async(name)

Save weights and optimizer state (for resuming training).

save_weights_for_sampler(name)
save_weights_for_sampler_async(name)

Save weights optimized for inference (faster, no optimizer state).

Loading

load_state(path)
load_state_async(path)

Load weights only.

load_state_with_optimizer(path)
load_state_with_optimizer_async(path)

Load weights and optimizer state.

Sampling Client Creation

create_sampling_client()

Create a SamplingClient from current weights (without saving).

save_weights_and_get_sampling_client(name)

Save weights and return a SamplingClient in one operation.

Utilities

get_info()

Retrieve model configuration and metadata.

get_tokenizer()

Get the model’s tokenizer for encoding/decoding text.

Example Pattern

training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-4B-Instruct-2507"
)
 
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
 
sampling_client = training_client.save_weights_and_get_sampling_client("my-model")