Supervised Fine-Tuning
Turn a base model into an instruction-follower.
Picture a fluent polyglot with no manners. They speak every language on the internet — pretraining on a trillion tokens will do that — and yet you cannot get through dinner with them. Ask a question and they finish the question. Say hello and they recite an email template. The grammar is flawless. The vocabulary is encyclopedic. What's missing is school — the other kind of school. The kind that teaches which fork to use.
Feed a base model the first half of a Wikipedia article and it will continue with plausible Wikipedia. Feed it "Write me a haiku about Kubernetes" and you will get… another instruction. Because on the internet, that's what usually follows a line like that: a forum post asking the same question, a tutorial title, another prompt. Our polyglot has read every English sentence ever written and learned English, not assistance.
Supervised fine-tuning — SFT — is etiquette school for the fluent. You take a few thousand to a million (prompt, response) pairs of polite exchanges written by humans (or by a strong model), format them with role markers, and train the model to produce the response token-by-token given the prompt. The loss looks almost exactly like pretraining's next-token prediction cross-entropy. Almost.
The one difference is the thing that matters. You don't train it to predict the prompt. You only train it to predict the reply. You're not teaching the polyglot new words — you're teaching them what to say versus what to listen to. Everything interesting in this lesson lives inside that sentence.
I was trained on a trillion tokens of internet. I can do your taxes, roast your haiku, write a Python REPL. But I will not, unprompted, know to answer a question. Someone has to show me the shape of a reply. Someone has to school me in when the human has stopped talking.
Here's what a single SFT training example looks like in its rawest form — one polite exchange, two turns, a question and the shape of a good answer. This is an etiquette flashcard.
Human: What's the capital of New Zealand?
Assistant: Wellington — it's on the southern tip of the North Island.

◆ prompt tokens ..... 13 ← the model sees these, does NOT learn them
◆ response tokens ... 18 ← the model sees these AND trains on them
◆ total sequence .... 31 ← what goes into the transformer in one forward pass
That's the whole object. Two roles, a question, an answer, and a mental note about which tokens the model is supposed to learn vs. which ones it's only supposed to read. Every real SFT dataset is a pile of these polite exchanges, formatted consistently. The formatting is the interesting bit, because the polyglot has to learn where its own turn begins — which means there must be an unambiguous marker. Think of it as the napkin on the lap: a small ritual that tells everyone dinner has started.
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>
The same exchange can be rendered three ways. The raw view shows the data as-is. The instruction-format view wraps it in Alpaca-style headers (### Instruction: / ### Response:). The chat-template view uses role tokens like <|im_start|>user and <|im_end|> — special single-token markers the tokenizer emits once per turn. Different model families use different conventions. Llama-2 has its own [INST] tags; ChatML uses <|im_start|>; Vicuna uses bare USER:/ASSISTANT:. None of them are “right” — they're just the table manners this particular finishing school decided to enforce. Pick a convention and stick to it; the polyglot only knows the etiquette you teach them.
I am a naming convention dressed up as infrastructure. I decide where your turn ends and the model's begins. Train the model with me one way and serve it another, and you will get the most confused-sounding assistant of your career. Consistency is the whole job.
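To make the conventions concrete, here is a minimal sketch of all three as plain string templates. The helper names are made up for illustration, and real chat templates also handle system turns, BOS/EOS tokens, and whitespace quirks that matter in practice:

```python
# Sketch: one exchange rendered in three common conventions.
# The marker strings are the well-known ones; everything else
# (function names, exact whitespace) is illustrative.

def render_chatml(user: str, assistant: str) -> str:
    # ChatML-style role tokens, one <|im_start|>/<|im_end|> pair per turn
    return (f"<|im_start|>user\n{user}<|im_end|>\n"
            f"<|im_start|>assistant\n{assistant}<|im_end|>")

def render_alpaca(user: str, assistant: str) -> str:
    # Alpaca-style instruction headers
    return f"### Instruction:\n{user}\n\n### Response:\n{assistant}"

def render_vicuna(user: str, assistant: str) -> str:
    # Vicuna-style bare role prefixes
    return f"USER: {user}\nASSISTANT: {assistant}"

q, a = "What is the capital of France?", "The capital of France is Paris."
for render in (render_chatml, render_alpaca, render_vicuna):
    print(render(q, a), end="\n\n")
```

Whichever template you pick, the model only ever sees the rendered string — the "roles" exist solely as these surface markers.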
Now the critical piece. The sequence has 31 tokens. A naive next-token loss would train the model to predict all of them — including the prompt. That would mean optimizing the polyglot to sound like a human asking questions, which is the opposite of the goal. You are not running a school for interrogators. So we mask.
Assign a label vector y the same length as the input. For positions inside the prompt, write -100 — PyTorch's CrossEntropyLoss treats -100 as “ignore this position” and contributes nothing to the gradient. For positions inside the response, write the true next-token id. The loss averages only over the response positions:
                    1         N
ℒ_SFT = − ───────  ∑  m_t · log P(x_t | x_<t ; θ)
                 ∑_t m_t  t=1

where m_t = 1 if token t is part of the response
      m_t = 0 if token t is part of the prompt (label = −100)

It's the usual causal-LM cross-entropy with one extra indicator m_t. The denominator — the count of unmasked tokens — keeps the loss at the same scale whether your prompts are long or short. That matters because if you average over all tokens, a conversation with a three-paragraph prompt and a one-sentence reply would contribute almost nothing to the gradient. Only the polite reply gets graded by the etiquette teacher; the guest's question is read aloud for context and then ignored.
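The masked loss is simple enough to compute by hand. A numeric sketch in NumPy — masked_ce is a hypothetical helper standing in for what PyTorch's CrossEntropyLoss with ignore_index=-100 does internally:

```python
import numpy as np

def masked_ce(logits: np.ndarray, labels: np.ndarray) -> float:
    """Masked next-token cross-entropy: positions labeled -100 are
    ignored; the loss averages only over the remaining (response)
    positions — the denominator is the count of unmasked tokens."""
    # numerically stable log-softmax over the vocab axis
    z = logits - logits.max(axis=-1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    mask = labels != -100
    picked = logp[mask, labels[mask]]       # log P of each response token
    return float(-picked.sum() / mask.sum())

rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 10))                # 6 tokens, vocab of 10
labels = np.array([-100, -100, -100, 4, 7, 2])   # prompt masked, reply graded
print(masked_ce(logits, labels))
```

Sanity check: with uniform logits the loss is exactly log(vocab_size), regardless of how many positions the mask removes — which is the point of normalizing by the unmasked count.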
Picture every token colored by whether it contributes to the loss: moving across the sequence, the prompt greys out and the response lights up. That mask is the thing separating SFT from raw continued pretraining. Take it away and you're teaching the polyglot to sound like whoever wrote the prompt — you are fine-tuning on the wrong half of the conversation. Etiquette school where the student practices the teacher's lines.
I am the difference between “train on all of it” and “train on the reply.” I am a boolean vector the same length as your tokens. I cost nothing to compute. I am why your fine-tune sounds like an assistant instead of an echo chamber. Forget me and you will not know anything is wrong until the evaluations come back strange.
The optimization is less heroic than you'd think. SFT is a weekend finishing course, not four years of language immersion. Typical hyperparameters:
- Dataset: 10k–1M examples. Pretraining used trillions of tokens — SFT is three to six orders of magnitude smaller. You already spoke the language; you just need manners.
- Epochs: 1–3. Go past 3 and the polyglot memorizes your exact dinner party lines instead of generalizing the etiquette; politeness collapses into parroting.
- Learning rate: 2e-5 is the standard, roughly 10× lower than the end-of-pretraining LR. You're nudging a finished model, not building one.
- Optimizer: AdamW. Warmup ratio ~3%, cosine decay to zero.
- Batch size: effective batch ~128 (gradient accumulation does most of the work on a single-node setup).
- Context length: long enough to hold the longest conversation in your dataset. Padding is wasteful; packing multiple short examples into one sequence helps.
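The packing trick from the last bullet can be sketched as a greedy concatenation. pack is an illustrative helper, not a library API — real implementations (e.g. TRL with packing=True) also handle truncation and cross-example attention masking:

```python
def pack(examples: list[list[int]], max_len: int) -> list[list[int]]:
    """Greedily concatenate tokenized examples into sequences of at most
    max_len tokens, so fewer positions are wasted on padding. Examples
    longer than max_len pass through uncut here — a real implementation
    would truncate or split them."""
    sequences: list[list[int]] = []
    current: list[int] = []
    for ids in examples:
        if current and len(current) + len(ids) > max_len:
            sequences.append(current)   # flush the filled sequence
            current = []
        current = current + ids
    if current:
        sequences.append(current)
    return sequences

# five short tokenized examples → two dense training sequences
tokenized = [[1] * 30, [2] * 50, [3] * 40, [4] * 90, [5] * 10]
packed = pack(tokenized, max_len=128)
print([len(s) for s in packed])  # → [120, 100]
```

Two sequences of 120 and 100 tokens instead of five padded ones of 128 each — that's the entire economic argument for packing.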
Compared to pretraining a 7B model (hundreds of thousands of GPU-hours, trillions of tokens of text) a reasonable SFT run is hours on a single 8×A100 node. That's most of why SFT caught on as a lab technique — anyone with a node can send their fluent polyglot to etiquette school over a long weekend, and the impact on output quality is dramatic.
SFT has one persistent failure mode. You're updating every weight in the network — the same weights that encode everything the base model learned during pretraining. If you overfit, or train for too long, or train on a narrow dataset, you can erase pretrained capabilities. Your chat-tuned model forgets how to do arithmetic. Your code-tuned model forgets French. Etiquette school, pushed too hard, can make the polyglot forget how to speak. This is catastrophic forgetting and it is embarrassingly easy to induce.
The three standard defenses:
- Mix in pretraining data. Every batch: some SFT, some unchanged pretraining text. Keeps the distribution anchored.
- Low learning rate, few epochs. The numbers above (2e-5, 1–3 epochs) exist to prevent this.
- Parameter-efficient fine-tuning. Freeze the base weights entirely; only train a small set of adapters (LoRA, QLoRA). This is the next lesson and it is the standard modern move.
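The first defense — mixing in pretraining data — can be sketched as a simple interleaved stream. The 4:1 ratio and the function name are arbitrary illustrative choices; real pipelines tune the mix and sample rather than cycle:

```python
import itertools

def mixed_stream(sft_examples, pretrain_examples, sft_per_pretrain=4):
    """Yield an endless stream that interleaves SFT examples with raw
    pretraining text at a fixed ratio, keeping each batch anchored to
    the original distribution."""
    sft = itertools.cycle(sft_examples)
    pre = itertools.cycle(pretrain_examples)
    while True:
        for _ in range(sft_per_pretrain):
            yield ("sft", next(sft))
        yield ("pretrain", next(pre))

stream = mixed_stream(["sft-a", "sft-b"], ["web-1", "web-2"])
first10 = [kind for kind, _ in itertools.islice(stream, 10)]
print(first10)
# → ['sft', 'sft', 'sft', 'sft', 'pretrain',
#    'sft', 'sft', 'sft', 'sft', 'pretrain']
```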
Three layers, one job: take a (prompt, response) pair, turn it into a training batch with the right mask, and get a loss. Pure Python does the formatting. NumPy does the tokenization and the mask. PyTorch + HuggingFace + TRL does the full training loop — because in production you will not write any of this by hand.
# One example. No tokenizer, no tensors — just the text contract.
def format_alpaca(prompt: str, response: str) -> dict:
formatted = (
f"### Instruction:\n{prompt}\n\n"
f"### Response:\n{response}"
)
# Track where the response starts, so downstream we know what to mask.
prompt_part = f"### Instruction:\n{prompt}\n\n### Response:\n"
return {
"text": formatted,
"prompt_len": len(prompt_part),
"response_len": len(response),
}
ex = format_alpaca(
"What's the capital of New Zealand?",
"Wellington — it's on the southern tip of the North Island.",
)
print("--- formatted ---")
print(ex["text"])
print(f"\nprompt_len = {ex['prompt_len']} | response_len = {ex['response_len']}")

--- formatted ---
### Instruction:
What's the capital of New Zealand?

### Response:
Wellington — it's on the southern tip of the North Island.

prompt_len = 67 | response_len = 58
Move to NumPy. Tokenize once, build the labels vector with -100 in the prompt region, and that is the object your trainer wants.
import numpy as np
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
prompt = "Human: What's the capital of New Zealand?\nAssistant: "
response = "Wellington — it's on the southern tip of the North Island."
# Tokenize the two halves separately so we know their lengths exactly.
prompt_ids = tok(prompt, add_special_tokens=False)["input_ids"]
response_ids = tok(response + tok.eos_token, add_special_tokens=False)["input_ids"]
input_ids = np.array(prompt_ids + response_ids, dtype=np.int64)
# Labels are a copy of input_ids with prompt positions zeroed out via -100.
labels = input_ids.copy()
labels[: len(prompt_ids)] = -100 # ignore prompt in the loss
print("input_ids shape:", input_ids.shape)
print("labels shape:", labels.shape)
print(f"prompt mask positions: {len(prompt_ids)} "
f"response mask positions: {len(response_ids)}")
print("first 5 labels:", labels[:5])
print("last 5 labels:", labels[-5:])

input_ids shape: (31,)
labels shape: (31,)
prompt mask positions: 13 response mask positions: 18
first 5 labels: [-100 -100 -100 -100 -100]
last 5 labels: [ 286 6255 5373 29889    2]
- formatted = f"...{prompt}...{response}" ←→ input_ids = tok(prompt_part) + tok(response_part) — two tokenizer calls, so you know the prompt boundary
- prompt_len (chars) ←→ len(prompt_ids) — tokens, not characters; that is the unit the loss operates on
- conceptually: "ignore the prompt" ←→ labels[:len(prompt_ids)] = -100 — the entire trick in one line
Layer 3 — the thing you actually run. HuggingFace transformers ships the model, the tokenizer, and the chat template. TRL's SFTTrainer wraps the whole masking / packing / training loop. A full SFT run in about 30 lines.
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
MODEL = "meta-llama/Llama-2-7b-hf"
tok = AutoTokenizer.from_pretrained(MODEL)
tok.pad_token = tok.eos_token # Llama-2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained(
MODEL,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Alpaca — 52k (prompt, response) pairs. Clean, small, battle-tested.
ds = load_dataset("tatsu-lab/alpaca", split="train")
def formatting_func(row):
# TRL uses the chat template registered on the tokenizer.
msgs = [
{"role": "user", "content": row["instruction"]},
{"role": "assistant", "content": row["output"]},
]
return tok.apply_chat_template(msgs, tokenize=False)
cfg = SFTConfig(
output_dir = "sft-llama2-alpaca",
num_train_epochs = 3,
per_device_train_batch_size = 4,
gradient_accumulation_steps = 32, # → effective batch 128
learning_rate = 2e-5,
lr_scheduler_type = "cosine",
warmup_ratio = 0.03,
bf16 = True,
logging_steps = 20,
save_steps = 500,
max_seq_length = 2048,
packing = True, # concatenate short examples → fewer wasted tokens
)
trainer = SFTTrainer(
model = model,
tokenizer = tok,
train_dataset = ds,
formatting_func = formatting_func,
args = cfg,
)
trainer.train()  # loss masking, chat template, everything — handled

- labels[:prompt_len] = -100 ←→ SFTTrainer(formatting_func=...) — TRL computes the response-only mask for you from the chat template
- for loop over (ids, labels) ←→ trainer.train() — packing, padding, collation, AdamW, cosine LR, logging: one call
- raw text concat ←→ tok.apply_chat_template(msgs, tokenize=False) — the canonical way to emit the exact format the model was trained on
Template mismatch train → infer: the most common SFT bug. You train with ChatML (<|im_start|>user) and serve with bare USER: prefixes, or vice-versa. The model never sees the start-of-turn marker it was trained on; it treats your prompt as mid-conversation text; the output is weird. Always use tokenizer.apply_chat_template on both sides.
Forgetting the mask: if you just feed input_ids as both input and labels, you're training the model to predict its own prompt. It will still “work” — loss goes down, eval plausibly improves — but the gradient signal is diluted and the model learns to imitate users as much as assist them.
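To see the dilution concretely, a bit of toy arithmetic on the flashcard example from earlier — without the mask, every token is averaged into the loss, so the share of gradient signal that actually grades the reply is just response/total:

```python
# Flashcard example: 13 prompt tokens, 18 response tokens.
prompt_tokens, response_tokens = 13, 18
share = response_tokens / (prompt_tokens + response_tokens)
print(f"{share:.0%} of the loss grades the reply")  # → 58%

# With a three-paragraph prompt and a one-line answer, the dilution bites:
prompt_tokens, response_tokens = 600, 25
share = response_tokens / (prompt_tokens + response_tokens)
print(f"{share:.0%} of the loss grades the reply")  # → 4%
```

At 4%, nearly all of your optimization budget goes to imitating users — exactly the failure the label mask exists to prevent.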
Overtraining: 10 epochs on your SFT set will give you a model that quotes its training data verbatim, has lost arithmetic, and speaks only in the style of your annotators. 1–3 epochs, low LR, cosine decay. Resist the instinct to train until loss stops going down — on SFT, you want to stop well before it plateaus.
Tokenizing the response without EOS: if you don't append tok.eos_token to the response before tokenizing, the model never learns when to stop generating. At inference it will keep going past the answer, hallucinate a new question, and answer that too. One extra token in the dataset saves a thousand confused user-reports.
Sample 1,000 rows from tatsu-lab/alpaca (random seed, stratified by instruction length if you want to be fancy). Run the layer-3 script above with num_train_epochs=3 and learning_rate=2e-5. On a single 8×A100 node this takes under two hours with QLoRA, roughly a day with full fine-tuning.
Hold out 20 instructions your training data never touched. Generate a response from both the base Llama-2-7B and your SFT model with the same decoding settings (temperature=0.7, top_p=0.9, max_new_tokens=256). Read them side by side.
What you should see: the base model completes the instruction as if it were forum text — often echoing the question, often trailing into tangents. The SFT model answers directly. Neither is smarter than the other in any deep sense; they've just been pointed at different distributions. That's the whole thing SFT does.
What to carry forward. SFT is etiquette school for a fluent polyglot — the first step that turns a base model into something you can actually talk to. The mechanics are a standard causal-LM loss with one indicator variable (the loss mask) that says “only grade the reply.” The dataset is tiny compared to pretraining, the training run is short, and the leverage is enormous. The failure modes are mostly about consistency: chat templates have to match between train and serve, EOS has to be where it belongs, and you have to resist training too long. And the hard ceiling is still the one from the opening — you cannot etiquette-school your way into facts the polyglot never learned.
Next up — LoRA. The 30-line script above updates all 7 billion weights of Llama-2. That's expensive to train, expensive to store (one 14GB checkpoint per task), and the easiest way in the world to induce catastrophic forgetting. LoRA — Low-Rank Adaptation — changes the contract: freeze the base polyglot, train 0.1% of new parameters as a stack of sticky notes over the weights, keep one base model and a dozen tiny adapters around. It is the single most important modern fine-tuning technique and it's a surprisingly small amount of linear algebra.
- [01] Ouyang et al., “Training language models to follow instructions with human feedback” · NeurIPS 2022
- [02] Zhou et al., “LIMA: Less Is More for Alignment” · NeurIPS 2023
- [03] Touvron et al., “Llama 2: Open Foundation and Fine-Tuned Chat Models” · Meta AI, 2023
- [04] HuggingFace TRL · library, SFTTrainer reference
- [05] Taori et al., “Alpaca: A Strong, Replicable Instruction-Following Model” · 2023 — the 52k-example dataset most SFT tutorials use