Chat Supervised Fine-Tuning
This tutorial demonstrates supervised fine-tuning (SFT) on chat-formatted data using MinT.
What You’ll Learn
- Load and process multi-turn conversations from HuggingFace
- Apply chat templates for model-specific formatting
- Implement loss masking to train only on assistant responses
- Run the SFT training loop
- Evaluate and sample from the fine-tuned model
Datasets
We support two chat datasets:
| Dataset | Source | Size | Train On | Use Case |
|---|---|---|---|---|
| no_robots | HuggingFaceH4/no_robots | 9.5K train + 500 test | All assistant messages | Quick experiments |
| tulu3 | allenai/tulu-3-sft-mixture | 939K samples | Last assistant message | Large-scale training |
SFT vs RL
| Method | Training Signal | When to Use |
|---|---|---|
| SFT | Expert demonstrations (input → output pairs) | When you have high-quality labeled data |
| RL | Reward signal (correct/incorrect) | When defining correctness programmatically is easier than providing examples |
SFT is simpler: no sampling required, no reward function to design. The model learns to predict the next token on training data.
Step 0: Setup
Install required packages:
pip install -q datasets transformers mint
Load your API key:
import os
from dotenv import load_dotenv
load_dotenv()
# MinT uses MINT_API_KEY for authentication
if os.environ.get('MINT_API_KEY'):
print("API key loaded")
else:
print("WARNING: MINT_API_KEY not found!")Connect to MinT:
import mint
from mint import types
service_client = mint.ServiceClient()
print("Connected to MinT")Step 1: Configuration
Configure your training run:
# ========== CONFIGURATION ==========
# Dataset: "no_robots" or "tulu3"
DATASET = "no_robots"
# Model
BASE_MODEL = "Qwen/Qwen3-0.6B"
LORA_RANK = 16
# Training
NUM_STEPS = 50 if DATASET == "no_robots" else 100
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
MAX_LENGTH = 2048
# Loss masking: which tokens to train on
# "all_assistant" = train on all assistant responses
# "last_assistant" = train only on the final assistant response
TRAIN_ON = "all_assistant" if DATASET == "no_robots" else "last_assistant"
print(f"Dataset: {DATASET}")
print(f"Model: {BASE_MODEL}")
print(f"Steps: {NUM_STEPS}, Batch: {BATCH_SIZE}, LR: {LEARNING_RATE}")
print(f"Max length: {MAX_LENGTH}, Train on: {TRAIN_ON}")Parameter choices:
- all_assistant: For smaller datasets like no_robots, train on every assistant turn to maximize data utilization
- last_assistant: For large datasets like tulu3, training only on the final response is sufficient and faster
Step 2: Tokenizer & Chat Template
Why Chat Templates Matter
Language models don’t understand “roles” natively. A chat template converts structured messages into a flat token sequence that the model can process:
[User message] → <|im_start|>user\nHello!<|im_end|>
[Assistant] → <|im_start|>assistant\nHi there!<|im_end|>
Different models use different templates. We use HuggingFace’s apply_chat_template() to handle this automatically.
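To see exactly what markup your model expects, you can render a toy conversation yourself (a minimal sketch; it loads the same Qwen tokenizer used in the next cell):
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
rendered = tok.apply_chat_template(
    [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there!"},
    ],
    tokenize=False,
)
print(rendered)  # shows the <|im_start|>/<|im_end|> markup Qwen uses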
Load the Tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
print(f"Vocab size: {tokenizer.vocab_size:,}")
print(f"EOS token: {tokenizer.eos_token!r} (id={tokenizer.eos_token_id})")Tokenize with Loss Masking
The key insight: we want to compute loss only on assistant tokens. This means:
- User messages: weight = 0 (don’t train on these)
- Assistant messages: weight = 1 (train on these)
import numpy as np
def tokenize_conversation(
messages: list[dict],
tokenizer,
max_length: int,
train_on: str = "all_assistant",
) -> tuple[list[int], np.ndarray]:
"""
Tokenize a conversation and compute loss weights.
Returns:
input_ids: Token IDs for the full conversation
weights: Loss weights (1.0 for tokens to train on, 0.0 otherwise)
"""
# Tokenize message by message to track boundaries
all_tokens = []
all_weights = []
for i, msg in enumerate(messages):
# Build partial conversation up to this message
partial = messages[:i+1]
# Apply chat template
text = tokenizer.apply_chat_template(
partial,
tokenize=False,
add_generation_prompt=False,
)
tokens = tokenizer.encode(text, add_special_tokens=False)
        # Find new tokens added by this message. (This assumes the chat
        # template is prefix-stable: rendering messages[:i+1] reproduces
        # the tokens of messages[:i] as a prefix.)
        prev_len = len(all_tokens)
        new_tokens = tokens[prev_len:]
# Determine weight for this message
is_assistant = msg.get("role") == "assistant"
is_last = (i == len(messages) - 1)
if train_on == "all_assistant":
weight = 1.0 if is_assistant else 0.0
elif train_on == "last_assistant":
weight = 1.0 if (is_assistant and is_last) else 0.0
else:
weight = 1.0 # train on all tokens
all_tokens.extend(new_tokens)
all_weights.extend([weight] * len(new_tokens))
# Truncate to max_length
if len(all_tokens) > max_length:
all_tokens = all_tokens[:max_length]
all_weights = all_weights[:max_length]
    return all_tokens, np.array(all_weights, dtype=np.float32)
Test the tokenization:
demo_messages = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there! How can I help you?"},
]
tokens, weights = tokenize_conversation(demo_messages, tokenizer, MAX_LENGTH, TRAIN_ON)
print(f"Tokens: {len(tokens)}")
print(f"Trainable tokens: {int(weights.sum())}")
print(f"\nDecoded: {tokenizer.decode(tokens)!r}")Step 3: Create Training Datum
MinT expects data in Datum format. For next-token prediction:
- model_input: tokens[:-1] (all tokens except the last)
- target_tokens: tokens[1:] (all tokens except the first)
- weights: weights[1:] (aligned with targets)
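For example, on a toy four-token sequence (made-up token IDs and weights), the shift looks like this:
tokens  = [101, 5, 9, 2]        # hypothetical token IDs
weights = [0.0, 0.0, 1.0, 1.0]  # loss mask from tokenize_conversation
inputs, targets, w = tokens[:-1], tokens[1:], weights[1:]
# inputs  = [101, 5, 9]
# targets = [5, 9, 2]    each predicted from the prefix ending at that input
# w       = [0.0, 1.0, 1.0]  -> loss counted only on the last two predictions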
def conversation_to_datum(
messages: list[dict],
tokenizer,
max_length: int,
train_on: str,
) -> types.Datum:
"""
Convert a conversation to a training Datum.
The model predicts token[i+1] from token[0:i+1], so:
- input = tokens[:-1]
- target = tokens[1:]
- weights = weights[1:] (shifted to align with targets)
"""
tokens, weights = tokenize_conversation(messages, tokenizer, max_length, train_on)
if len(tokens) < 2:
raise ValueError("Conversation too short")
# Next-token prediction format
input_tokens = tokens[:-1]
target_tokens = tokens[1:]
target_weights = weights[1:] # Shift weights to align with targets
return types.Datum(
model_input=types.ModelInput.from_ints(input_tokens),
loss_fn_inputs={
"target_tokens": list(target_tokens),
"weights": target_weights.tolist(),
},
    )
Test the conversion:
datum = conversation_to_datum(demo_messages, tokenizer, MAX_LENGTH, TRAIN_ON)
print(f"Input length: {datum.model_input.length}")
print(f"Target length: {len(datum.loss_fn_inputs['target_tokens'])}")
print(f"Trainable tokens: {sum(datum.loss_fn_inputs['weights']):.0f}")Step 4: Load Dataset
We provide a simple dataset wrapper that handles batching and shuffling:
from datasets import load_dataset
from dataclasses import dataclass
import random
@dataclass
class ChatDataset:
"""Simple chat dataset with batching."""
data: list[list[dict]] # List of conversations (each is list of messages)
index: int = 0
def get_batch(self, batch_size: int) -> list[list[dict]]:
"""Get next batch of conversations."""
batch = []
for _ in range(batch_size):
if self.index >= len(self.data):
self.index = 0 # Wrap around
random.shuffle(self.data)
batch.append(self.data[self.index])
self.index += 1
return batch
def __len__(self) -> int:
return len(self.data)
def load_chat_dataset(dataset_name: str, seed: int = 42) -> tuple[ChatDataset, ChatDataset]:
"""Load train and test datasets."""
random.seed(seed)
if dataset_name == "no_robots":
ds = load_dataset("HuggingFaceH4/no_robots")
train_data = [row["messages"] for row in ds["train"]]
test_data = [row["messages"] for row in ds["test"]]
elif dataset_name == "tulu3":
ds = load_dataset("allenai/tulu-3-sft-mixture", split="train")
ds = ds.shuffle(seed=seed)
all_data = [row["messages"] for row in ds]
# Split: first 1024 for test, rest for train
test_data = all_data[:1024]
train_data = all_data[1024:]
else:
raise ValueError(f"Unknown dataset: {dataset_name}")
random.shuffle(train_data)
    return ChatDataset(train_data), ChatDataset(test_data)
Load and inspect:
train_dataset, test_dataset = load_chat_dataset(DATASET)
print(f"Train: {len(train_dataset)} conversations")
print(f"Test: {len(test_dataset)} conversations")
# Show sample
sample = train_dataset.get_batch(1)[0]
print(f"\nSample conversation ({len(sample)} messages):")
for msg in sample[:3]: # Show first 3 messages
content = msg['content'][:80] + "..." if len(msg['content']) > 80 else msg['content']
print(f" [{msg['role']}]: {content}")Step 5: Create Training Client
Create a LoRA training client. LoRA (Low-Rank Adaptation) allows efficient fine-tuning by training small adapter matrices instead of all model weights.
training_client = service_client.create_lora_training_client(
base_model=BASE_MODEL,
rank=LORA_RANK,
train_mlp=True,
train_attn=True,
train_unembed=True,
)
print(f"Training client created: {BASE_MODEL}")
print(f"LoRA rank: {LORA_RANK}")LoRA Parameters:
- rank: Size of the low-rank matrices. Higher = more capacity, slower training
- train_mlp: Train feed-forward (MLP) layers
- train_attn: Train attention layers
- train_unembed: Train output projection
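For a rough sense of scale (a back-of-the-envelope sketch with hypothetical layer dimensions, not MinT's actual internals), a rank-r adapter on a d_out × d_in weight matrix adds r × (d_in + d_out) trainable parameters:
d_in, d_out, r = 1024, 1024, 16      # hypothetical layer dims; r = LORA_RANK
adapter_params = r * (d_in + d_out)  # B is d_out x r, A is r x d_in
print(f"{adapter_params:,} trainable params per adapted matrix")  # 32,768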
Step 6: Training Loop
The training loop follows a simple pattern:
for each step:
1. Get batch of conversations
2. Convert to Datums (tokenize + compute weights)
3. forward_backward: compute loss and gradients
4. optim_step: update model weights
metrics_history = []
print(f"Starting SFT training: {NUM_STEPS} steps")
print(f"Batch: {BATCH_SIZE}, LR: {LEARNING_RATE}")
print()
for step in range(NUM_STEPS):
# Get batch of conversations
batch = train_dataset.get_batch(BATCH_SIZE)
# Convert to Datums
datums = []
for messages in batch:
try:
datum = conversation_to_datum(messages, tokenizer, MAX_LENGTH, TRAIN_ON)
datums.append(datum)
except Exception as e:
continue # Skip malformed conversations
if not datums:
continue
# Forward-backward pass: compute loss and gradients
fwdbwd_result = training_client.forward_backward(
datums,
loss_fn="cross_entropy",
).result()
# Compute loss from logprobs
total_loss = 0.0
total_weight = 0.0
for i, out in enumerate(fwdbwd_result.loss_fn_outputs):
logprobs = out['logprobs']
if hasattr(logprobs, 'tolist'):
logprobs = logprobs.tolist()
w = datums[i].loss_fn_inputs['weights']
if hasattr(w, 'tolist'):
w = w.tolist()
for lp, wt in zip(logprobs, w):
total_loss += -lp * wt
total_weight += wt
loss = total_loss / max(total_weight, 1)
# Optimization step: update weights
training_client.optim_step(
types.AdamParams(learning_rate=LEARNING_RATE)
).result()
metrics_history.append({"step": step, "loss": loss})
if step % 10 == 0 or step == NUM_STEPS - 1:
print(f"Step {step:3d}: loss={loss:.4f}")
print("\nTraining complete!")
print(f"Initial loss: {metrics_history[0]['loss']:.4f}")
print(f"Final loss: {metrics_history[-1]['loss']:.4f}")Understanding the loss:
- Loss is the negative log-likelihood (NLL) of correct tokens
- Lower loss = model is more confident in predicting the right tokens
- Expect loss to decrease over training
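Here is a toy calculation (with made-up logprobs) of the exact formula the training loop uses:
logprobs = [-2.3, -0.7, -0.1]  # log p(correct token) at three positions
weights  = [0.0, 1.0, 1.0]     # only the last two are assistant tokens
nll = sum(-lp * w for lp, w in zip(logprobs, weights)) / sum(weights)
print(f"{nll:.2f}")  # (0.7 + 0.1) / 2 = 0.40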
Step 7: Evaluate
Compute NLL on the test set to check for overfitting:
# Evaluate on test set
test_batch = test_dataset.get_batch(min(32, len(test_dataset)))
test_datums = []
for messages in test_batch:
try:
datum = conversation_to_datum(messages, tokenizer, MAX_LENGTH, TRAIN_ON)
test_datums.append(datum)
except Exception:
continue
if test_datums:
# Forward pass only (no gradients)
forward_result = training_client.forward(
test_datums,
loss_fn="cross_entropy",
).result()
# Compute loss from logprobs
total_loss = 0.0
total_weight = 0.0
for i, out in enumerate(forward_result.loss_fn_outputs):
logprobs = out['logprobs']
if hasattr(logprobs, 'tolist'):
logprobs = logprobs.tolist()
w = test_datums[i].loss_fn_inputs['weights']
if hasattr(w, 'tolist'):
w = w.tolist()
for lp, wt in zip(logprobs, w):
total_loss += -lp * wt
total_weight += wt
test_loss = total_loss / max(total_weight, 1)
print(f"Test NLL: {test_loss:.4f}")
else:
print("No valid test samples")Interpreting results:
- Test loss close to train loss = good generalization
- Test loss much higher than train loss = overfitting (reduce training steps or increase data)
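To quantify the gap directly, compare the two numbers computed above (assumes both the training loop and the evaluation cell ran, so metrics_history and test_loss exist):
train_loss = metrics_history[-1]['loss']
print(f"train={train_loss:.4f}  test={test_loss:.4f}  gap={test_loss - train_loss:+.4f}")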
Step 8: Visualize
Plot the training curve:
import matplotlib.pyplot as plt
steps = [m['step'] for m in metrics_history]
losses = [m['loss'] for m in metrics_history]
plt.figure(figsize=(10, 5))
plt.plot(steps, losses, 'b-', linewidth=2)
plt.xlabel('Step')
plt.ylabel('Loss (NLL)')
plt.title(f'{DATASET.upper()} SFT Training')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f'{DATASET}_training.png', dpi=150)
plt.show()
Step 9: Generate Sample
Test the fine-tuned model by generating a response:
# Save weights and get sampling client
sampling_client = training_client.save_weights_and_get_sampling_client(
name=f"{DATASET}-sft-demo"
)
# Create a test prompt
test_messages = [
{"role": "user", "content": "Write a haiku about programming."}
]
# Apply chat template for generation
prompt_text = tokenizer.apply_chat_template(
test_messages,
tokenize=False,
add_generation_prompt=True,
)
prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
# Sample from the model
sample_result = sampling_client.sample(
prompt=types.ModelInput.from_ints(prompt_tokens),
num_samples=1,
sampling_params=types.SamplingParams(
max_tokens=128,
temperature=0.7,
stop_token_ids=[tokenizer.eos_token_id],
),
).result()
response = tokenizer.decode(sample_result.sequences[0].tokens)
print("User: Write a haiku about programming.")
print(f"Assistant: {response}")Sampling parameters:
- temperature=0.7: Controls randomness. Lower = more deterministic, higher = more creative
- stop_token_ids: Stop generating when the EOS token is produced
- max_tokens: Maximum tokens to generate
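To see the effect of temperature directly, draw several samples from the same prompt (a small sketch reusing the sampling client above; at temperature=0.7 the completions should differ from run to run):
multi = sampling_client.sample(
    prompt=types.ModelInput.from_ints(prompt_tokens),
    num_samples=3,
    sampling_params=types.SamplingParams(
        max_tokens=128,
        temperature=0.7,
        stop_token_ids=[tokenizer.eos_token_id],
    ),
).result()
for i, seq in enumerate(multi.sequences):
    print(f"--- sample {i} ---")
    print(tokenizer.decode(seq.tokens))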
Step 10: Save Checkpoint
Save the final checkpoint for later use or continued training:
checkpoint = training_client.save_state(name=f"{DATASET}-sft-final").result()
print(f"Checkpoint saved: {checkpoint.path}")To resume training later:
resumed_client = service_client.create_training_client_from_state_with_optimizer(
checkpoint.path
)
Summary
| Component | Implementation |
|---|---|
| Dataset | no_robots (9.5K) or tulu3 (939K) from HuggingFace |
| Tokenization | HuggingFace tokenizer with chat template |
| Loss masking | Train on assistant messages only |
| Training | cross_entropy loss with LoRA |
| Checkpointing | save_state() for weights + optimizer |
Key API Methods
# Setup
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(base_model=...)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Training
training_client.forward_backward(datums, loss_fn="cross_entropy")
training_client.optim_step(types.AdamParams(learning_rate=...))
# Evaluation
training_client.forward(datums, loss_fn="cross_entropy") # No gradients
# Inference
sampling_client = training_client.save_weights_and_get_sampling_client(name=...)
sampling_client.sample(prompt, num_samples, sampling_params)
# Checkpointing
checkpoint = training_client.save_state(name=...)