MinT Quickstart
What is MinT?
MinT (Mind Lab Toolkit) is an open infrastructure for training language models using:
- SFT (Supervised Fine-Tuning): Learn from labeled examples (input → output pairs)
- RL (Reinforcement Learning): Learn from reward signals (trial and error)
MinT uses LoRA (Low-Rank Adaptation) to efficiently fine-tune large models without modifying all parameters.
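To see why this is cheap, here is a toy NumPy sketch of the LoRA idea (NumPy is installed in Step 0; the shapes and values are illustrative, not MinT internals):
import numpy as np

# Instead of updating a full d x d weight matrix, LoRA trains two small
# matrices A (d x r) and B (r x d) with rank r much smaller than d.
d, r = 1024, 16                   # hidden size and LoRA rank (this tutorial uses rank=16)
W = np.zeros((d, d))              # frozen base weight (placeholder values)
A = np.random.randn(d, r) * 0.01  # trainable
B = np.zeros((r, d))              # trainable, initialized to zero
W_adapted = W + A @ B             # the effective weight the adapted model uses
print("trainable params:", A.size + B.size, "vs. full matrix:", W.size)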
What You’ll Learn
In this tutorial, we’ll train a model to solve multiplication problems using a two-stage approach:
- Stage 1 (SFT): Teach the model multiplication with labeled examples
- Stage 2 (RL): Load the SFT model and refine it with reward signals
This demonstrates the complete workflow: SFT → Save → Load → RL
Prerequisites
- Python >= 3.11
- A MinT API key
Step 0: Installation
Install the MinT SDK from the git repository:
pip install git+https://github.com/MindLab-Research/mindlab-toolkit.git python-dotenv matplotlib numpy
Step 1: Configure Your API Key
MinT requires an API key for authentication. There are two ways to set it up:
Option A: Using a .env file (Recommended)
Create a file named .env in your project directory:
MINT_API_KEY=sk-mint-your-api-key-here
Option B: Set environment variable directly
import os
os.environ['MINT_API_KEY'] = 'sk-mint-your-api-key-here'
Security Note: Never commit your API key to version control. Add .env to your .gitignore file.
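If you use the .env approach, load it at the top of your script with python-dotenv (installed in Step 0); a minimal sketch:
from dotenv import load_dotenv

load_dotenv()  # reads MINT_API_KEY (and anything else) from .env into os.environ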
Using Tinker SDK
If you have existing code that uses the Tinker SDK, you can point it at MinT. Install the SDK and set these environment variables:
pip install tinker
TINKER_BASE_URL=https://mint.macaron.im/
TINKER_API_KEY=<your-mint-api-key>
Note: Use your MinT API key (starts with sk-mint-). All code in this tutorial works with import tinker instead of import mint.
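For example, a hedged sketch of pointing an existing Tinker script at MinT (it assumes, per the note above, that tinker exposes the same ServiceClient entry point used later in this tutorial):
import os

os.environ["TINKER_BASE_URL"] = "https://mint.macaron.im/"
os.environ["TINKER_API_KEY"] = "sk-mint-your-api-key-here"

import tinker  # drop-in for `import mint` in the code below

service_client = tinker.ServiceClient()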
HuggingFace Mirror
If you have network issues accessing HuggingFace, set the mirror endpoint before importing mint:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import mint  # Must be after setting HF_ENDPOINT
Step 2: Connect to MinT Server
The ServiceClient is your entry point to MinT. It handles:
- Authentication with the server
- Creating training and sampling clients
- Querying available models
import mint
# Create the service client
service_client = mint.ServiceClient()
# List available models
print("Connected to MinT server!")
capabilities = service_client.get_server_capabilities()
for model in capabilities.supported_models:
print(f" - {model.model_name}")Core Concepts
Before we start training, let’s understand the key concepts:
| Component | Purpose |
|---|---|
| TrainingClient | Manages your LoRA adapter and handles training operations |
| SamplingClient | Generates text from your trained model |
| Datum | A single training example with model_input and loss_fn_inputs |
Training Loop Pattern
for each batch:
    forward_backward()  # Compute gradients
    optim_step()        # Update model weights
Loss Functions
- cross_entropy: For SFT - maximizes probability of correct tokens
- importance_sampling: For RL - weights updates by advantage
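As a rough mental model only (a conceptual sketch, not MinT's actual implementation), the per-token losses behave like this:
import math

weight = 1.0      # 1 for completion tokens, 0 for prompt tokens
p_target = 0.25   # model's probability of the target/sampled token

# cross_entropy (SFT): penalize low probability on the labeled token
ce_loss = -weight * math.log(p_target)

# importance_sampling (RL): scale the update by the advantage, corrected by the
# ratio between the current policy and the policy that generated the sample
advantage = 0.5   # reward - mean_reward
ratio = 0.9       # p_current / p_sampling for the sampled token
rl_loss = -weight * advantage * ratio

print(f"cross_entropy: {ce_loss:.2f}  importance_sampling: {rl_loss:.2f}")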
Stage 1: Supervised Fine-Tuning (SFT)
Goal: Teach the model to solve two-digit multiplication using labeled examples.
How SFT works:
- Show the model input-output pairs
- Model learns to predict the output given the input
- Loss = how surprised the model is by the correct answer
Input: "Question: What is 47 * 83?\nAnswer:"
Output: " 3901"Step 3: Create a Training Client
We’ll use Qwen/Qwen3-0.6B - a small but capable model perfect for learning.
LoRA Parameters:
- rank=16: Size of the low-rank matrices (higher = more capacity, slower)
- train_mlp=True: Train the feed-forward layers
- train_attn=True: Train the attention layers
- train_unembed=True: Train the output projection
from mint import types
BASE_MODEL = "Qwen/Qwen3-0.6B"
# Create a training client with LoRA configuration
training_client = service_client.create_lora_training_client(
base_model=BASE_MODEL,
rank=16, # LoRA rank - controls adapter capacity
train_mlp=True, # Train MLP (feed-forward) layers
train_attn=True, # Train attention layers
train_unembed=True, # Train the output projection
)
# Get the tokenizer - converts text to/from token IDs
tokenizer = training_client.get_tokenizer()
Step 4: Prepare Training Data
We need to convert our examples into Datum objects that MinT can process.
Key concept - Weights:
- weight=0: Don’t compute loss on this token (the prompt)
- weight=1: Compute loss on this token (the answer we want to learn)
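Before the real code, a toy illustration of the weighting and the shift-by-one next-token setup (the token IDs are made up; the actual code below uses the tokenizer):
# 3 prompt tokens (weight 0) followed by 3 completion tokens (weight 1)
toy_tokens = [101, 102, 103] + [7, 8, 9]
toy_weights = [0, 0, 0] + [1, 1, 1]

toy_inputs = toy_tokens[:-1]       # what the model reads
toy_targets = toy_tokens[1:]       # what it should predict at each position
toy_loss_mask = toy_weights[1:]    # loss is applied only to completion targets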
import random
from mint import types
random.seed(42)
def generate_sft_examples(n=100):
    """Generate two-digit multiplication examples for SFT."""
    examples = []
    for _ in range(n):
        a = random.randint(10, 99)
        b = random.randint(10, 99)
        examples.append({
            "question": f"What is {a} * {b}?",
            "answer": str(a * b)
        })
    return examples

def process_sft_example(example: dict, tokenizer) -> types.Datum:
    """Convert a training example into a Datum for MinT."""
    prompt = f"Question: {example['question']}\nAnswer:"
    completion = f" {example['answer']}"
    # Tokenize prompt and completion separately
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(completion, add_special_tokens=False)
    # Add EOS token so model learns when to stop
    completion_tokens = completion_tokens + [tokenizer.eos_token_id]
    # Create weights: 0 for prompt, 1 for completion
    prompt_weights = [0] * len(prompt_tokens)
    completion_weights = [1] * len(completion_tokens)
    all_tokens = prompt_tokens + completion_tokens
    all_weights = prompt_weights + completion_weights
    # For next-token prediction: shift by 1
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights = all_weights[1:]
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={
            "target_tokens": target_tokens,
            "weights": weights
        }
    )
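# Optional sanity check of the conversion helper (example values only; the exact
# token counts depend on the Qwen tokenizer):
_demo_datum = process_sft_example({"question": "What is 12 * 34?", "answer": "408"}, tokenizer)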
sft_examples = generate_sft_examples(100)
sft_data = [process_sft_example(ex, tokenizer) for ex in sft_examples]
Step 5: Train the Model (SFT)
The training loop:
- forward_backward(): Compute the loss and gradients
- optim_step(): Update model weights using the Adam optimizer
NUM_SFT_STEPS = 10
SFT_LEARNING_RATE = 5e-5
for step in range(NUM_SFT_STEPS):
    fwdbwd_result = training_client.forward_backward(
        data=sft_data,
        loss_fn="cross_entropy"
    ).result()
    training_client.optim_step(
        types.AdamParams(learning_rate=SFT_LEARNING_RATE)
    ).result()
Step 6: Test the SFT Model
import re
def extract_answer(response: str) -> str | None:
    """Extract the first numeric answer from response."""
    numbers = re.findall(r'\d+', response)
    return numbers[0] if numbers else None
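# Quick check of the helper on a sample string (illustrative input):
assert extract_answer("The answer is 1081.") == "1081"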
# Save weights and create sampling client
sft_sampling_client = training_client.save_weights_and_get_sampling_client(
name='arithmetic-sft'
)
# Test
prompt = "Question: What is 23 * 47?\nAnswer:"
prompt_tokens = types.ModelInput.from_ints(tokenizer.encode(prompt))
result = sft_sampling_client.sample(
prompt=prompt_tokens,
num_samples=1,
sampling_params=types.SamplingParams(
max_tokens=16,
temperature=0.0,
stop_token_ids=[tokenizer.eos_token_id]
)
).result()
response = tokenizer.decode(result.sequences[0].tokens)
print(f"Q: What is 23 * 47?")
print(f"A: {response.strip()}")Step 7: Save Checkpoint
Save the SFT model so we can load it and continue training with RL.
sft_checkpoint = training_client.save_state(name="arithmetic-sft-checkpoint").result()
print(f"Checkpoint saved to: {sft_checkpoint.path}")Stage 2: Reinforcement Learning (RL)
Goal: Load the SFT model and refine it using reward signals.
How RL differs from SFT:
- SFT: “Here’s the correct answer, learn it”
- RL: “Try different answers, I’ll tell you if you’re right or wrong”
RL Workflow:
1. Generate multiple responses per problem (exploration)
2. Compute rewards (1.0 = correct, 0.0 = wrong)
3. Compute advantages = reward - mean_reward
4. Train with importance_sampling loss
Step 8: Continue Training with RL
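If you are resuming RL in a fresh session rather than continuing in the same one, you would first reload the Step 7 checkpoint. A hedged sketch (it assumes the create_training_client_from_state_with_optimizer method listed in the summary restores both the LoRA weights and the optimizer state):
# Hypothetical resume path: restore the checkpoint saved in Step 7
rl_training_client = service_client.create_training_client_from_state_with_optimizer(
    sft_checkpoint.path
)
In this tutorial the SFT training client is still live, so we simply keep using it: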
rl_training_client = training_client  # Continue from SFT
Step 9: Define the Reward Function
def generate_rl_problem():
    """Generate harder problems for RL (10-199 vs SFT's 10-99)."""
    a = random.randint(10, 199)
    b = random.randint(10, 199)
    return f"What is {a} * {b}?", str(a * b)

def compute_reward(response: str, correct_answer: str) -> float:
    """Reward function: 1.0 if correct, 0.0 otherwise."""
    extracted = extract_answer(response)
    return 1.0 if extracted == correct_answer else 0.0
Step 10: RL Training Loop
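Before wiring up the full loop, here is a toy illustration of the group-relative advantages described in the RL workflow above (the rewards are made-up values for one group of four samples):
group_rewards = [1.0, 0.0, 0.0, 1.0]                    # 2 of 4 samples were correct
mean_reward = sum(group_rewards) / len(group_rewards)   # 0.5
advantages = [r - mean_reward for r in group_rewards]   # [0.5, -0.5, -0.5, 0.5]
# Correct samples get a positive advantage (reinforced); wrong ones get a
# negative advantage (discouraged). A sample at the group mean contributes nothing.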
import torch
from mint import TensorData
NUM_RL_STEPS = 10
BATCH_SIZE = 8
GROUP_SIZE = 8
RL_LEARNING_RATE = 2e-5
for step in range(NUM_RL_STEPS):
    # Save weights for sampling
    sampling_path = rl_training_client.save_weights_for_sampler(
        name=f"rl-step-{step}"
    ).result().path
    rl_sampling_client = service_client.create_sampling_client(
        model_path=sampling_path,
        base_model=BASE_MODEL
    )
    problems = [generate_rl_problem() for _ in range(BATCH_SIZE)]
    training_datums = []
    for question, answer in problems:
        prompt_text = f"Question: {question}\nAnswer:"
        prompt_tokens = tokenizer.encode(prompt_text)
        prompt_input = types.ModelInput.from_ints(prompt_tokens)
        # Sample multiple responses
        sample_result = rl_sampling_client.sample(
            prompt=prompt_input,
            num_samples=GROUP_SIZE,
            sampling_params=types.SamplingParams(
                max_tokens=16,
                temperature=0.7,
                stop_token_ids=[tokenizer.eos_token_id]
            )
        ).result()
        # Compute rewards and advantages
        group_rewards = []
        for seq in sample_result.sequences:
            response_text = tokenizer.decode(seq.tokens)
            reward = compute_reward(response_text, answer)
            group_rewards.append(reward)
        mean_reward = sum(group_rewards) / len(group_rewards)
        advantages = [r - mean_reward for r in group_rewards]
        # Create training datums with advantages
        for seq, adv in zip(sample_result.sequences, advantages):
            if len(seq.tokens) == 0 or adv == 0:
                continue
            full_tokens = prompt_tokens + list(seq.tokens)
            input_tokens = full_tokens[:-1]
            target_tokens = full_tokens[1:]
            weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(seq.tokens)
            logprobs = [0.0] * (len(prompt_tokens) - 1) + list(seq.logprobs or [0.0] * len(seq.tokens))
            full_advantages = [0.0] * (len(prompt_tokens) - 1) + [adv] * len(seq.tokens)
            datum = types.Datum(
                model_input=types.ModelInput.from_ints(tokens=input_tokens),
                loss_fn_inputs={
                    "target_tokens": TensorData.from_torch(torch.tensor(target_tokens, dtype=torch.int64)),
                    "weights": TensorData.from_torch(torch.tensor(weights, dtype=torch.float32)),
                    "logprobs": TensorData.from_torch(torch.tensor(logprobs, dtype=torch.float32)),
                    "advantages": TensorData.from_torch(torch.tensor(full_advantages, dtype=torch.float32)),
                },
            )
            training_datums.append(datum)
    # Train
    if training_datums:
        rl_training_client.forward_backward(
            training_datums,
            loss_fn="importance_sampling"
        ).result()
        rl_training_client.optim_step(
            types.AdamParams(learning_rate=RL_LEARNING_RATE)
        ).result()
Step 11: Test the Final Model
final_path = rl_training_client.save_weights_for_sampler(name="arithmetic-rl-final").result().path
final_client = service_client.create_sampling_client(
model_path=final_path,
base_model=BASE_MODEL
)
# Test on harder problems
test_problems = [
("What is 123 * 45?", "5535"),
("What is 67 * 189?", "12663"),
]
for question, correct in test_problems:
    prompt = f"Question: {question}\nAnswer:"
    prompt_input = types.ModelInput.from_ints(tokenizer.encode(prompt))
    result = final_client.sample(
        prompt=prompt_input,
        num_samples=1,
        sampling_params=types.SamplingParams(
            max_tokens=16,
            temperature=0.0,
            stop_token_ids=[tokenizer.eos_token_id]
        )
    ).result()
    response = tokenizer.decode(result.sequences[0].tokens)
    extracted = extract_answer(response)
    mark = "✓" if extracted == correct else "✗"
    print(f"Q: {question} → A: {response.strip()} (correct: {correct}) {mark}")
Summary
Training Pipeline
| Stage | Method | Loss Function | Purpose |
|---|---|---|---|
| 1 | SFT | cross_entropy | Teach multiplication with labeled examples |
| 2 | RL | importance_sampling | Refine with reward signals |
Key API Methods
# Setup
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(base_model=...)
# Training
training_client.forward_backward(data, loss_fn) # Compute gradients
training_client.optim_step(adam_params) # Update weights
# Checkpointing
checkpoint = training_client.save_state(name).result()
resumed = service_client.create_training_client_from_state_with_optimizer(checkpoint.path)
# Inference
sampling_client = training_client.save_weights_and_get_sampling_client(name)
sampling_client.sample(prompt, num_samples, sampling_params)