Make GPT Talk Back

Sampling: temperature, top-k, nucleus.

Medium · ~15 min read · lesson 8 of 10

You trained a GPT. You prompt it with "What's the capital of France?" and it replies: “What's the capital of Germany? What's the capital of Italy? What's the capital of Spain?…” It does not stop. It will never stop. You trained a monologuer — a model that only knows one move, which is to keep talking. It has read the entire internet and learned exactly one social skill: continue the document.

The base model has never, in its entire training life, seen a conversation as a conversation. It has seen conversations as flat text — long strips of prose where one sentence follows another with no sense of whose turn it is. It does not know that you and it are two separate speakers. It does not know that it is supposed to finish a thought and then hand the microphone back. It is a guest at a dinner party who has been holding forth for three hours and has not, as far as anyone can tell, noticed that nobody else has spoken.

This lesson is about teaching that monologuer to have a conversation. Two ingredients, and only two. First, a chat template: a fixed rhythm of user: / assistant: turns marked with special tokens the model learns to treat as stage directions (“enter, speak, exit”). Second, a dataset of a few thousand transcripts that follow that rhythm exactly, trained in with supervised fine-tuning (SFT). After that the model has learned one new thing: when its turn is over, it emits the closing stage cue.

A freshly pretrained GPT (personified)
I have read the entire internet. I have 175 billion parameters. I will happily continue any text you give me — your email, your shopping list, the dinner you are currently still eating — until you cut the power. Turn-taking? What is turn-taking? Am I supposed to be waiting for something?

Here is the idea, stripped of all jargon. Every transcript in the fine-tuning set looks the same. There is a system stage cue (the director whispering “be helpful, be honest”), a user stage cue (“the user speaks now”), and an assistant stage cue (“you speak now”). Each cue opens with a special token, a single vocabulary entry the tokenizer refuses to split, so the model always sees it as one unit, followed by the role name. At the end of every turn there is a closing cue: stop talking. That is the whole chat template. A transcript looks like this:

the chat template — stage directions around a turn
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What's the capital of France?<|im_end|>
<|im_start|>assistant
Paris.<|im_end|>

Look at each line as theater. <|im_start|> is the curtain rising. system, user, assistant are the roles. <|im_end|> is the curtain falling: “this speaker is done, next speaker please.” During fine-tuning you show the model thousands of transcripts in exactly this shape, and gradient descent does what it always does: every special token accumulates a meaning from the company it keeps. <|im_end|> ends up meaning “I have finished my thought; the other speaker goes now,” because that is the only context in which it ever appeared during training. The stage directions become real.
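Rendering a conversation into that shape is plain string assembly. A minimal sketch, assuming messages arrive as a list of {"role", "content"} dicts; `render_chatml` and its `add_generation_prompt` flag are hypothetical names for illustration, and in practice you would use the template that shipped with your tokenizer. The trailing open assistant cue is where generation begins.

```python
def render_chatml(messages, add_generation_prompt=True):
    """Hypothetical helper: render [{"role": ..., "content": ...}, ...]
    into the ChatML-style template shown above."""
    parts = []
    for m in messages:
        # curtain up, role name, newline, content, curtain down
        parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n")
    if add_generation_prompt:
        # open the assistant's turn and leave it open: the model speaks next
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


prompt = render_chatml([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the capital of France?"},
])
print(prompt)
```

Whether a newline follows each <|im_end|> is a detail of this particular sketch; what matters is that inference reproduces the fine-tuning format exactly.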

This is the entire reason chat models “know” how to take turns. There is no separate module, no special “turn controller,” no heuristic in the sampling code. There is a pretrained next-token predictor, a handful of new vocabulary entries with strong and narrow meanings, and the same supervised fine-tuning recipe you already met. That's it. The monologuer became a conversationalist by being shown what a conversation looks like — in one consistent format — and told “be like this.”

The <|im_end|> token (personified)
Before training, I meant nothing: a random vector in embedding space, cluttering up the vocabulary. After training, I mean exactly one thing: stop talking. Every time I closed an assistant turn in the training data, the loss rewarded the model for predicting me right there, and whatever followed me was another speaker's cue. Now my embedding is a cliff. The model walks up to me and steps off.

Let's slow down on the mechanical picture, because this is where a lot of tutorials wave hands. The model is a conditional distribution: given the tokens so far, predict the next one. Nothing more. When you ask it “how does the assistant know to stop?” the honest answer is: it doesn't know anything. It just assigns a very high probability to the <|im_end|> token at the end of a natural assistant turn, because every such turn in the fine-tuning data ended with that token. Your inference loop watches the samples go by, and the instant that token shows up, you stop pulling new ones.

So turn-taking is split across two places. The model learns to predict the end-of-turn token. The inference loop is written to respect it: when the sampler returns <|im_end|>, the loop breaks. The model never stops itself; it just nominates a place to stop, and the surrounding code honors the nomination. That division of labor is subtle, and it will come back to bite us later in this lesson.
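A toy sketch of that split, with no real model involved: `toy_next_logits` is a hypothetical stand-in that puts a large logit on a scripted token, and id 0 plays <|im_end|>. The model only nominates; `generate` decides. Drop the end-of-turn id from `stop_ids` and the same "model" sails past its own stop.

```python
IM_END = 0  # hypothetical token id playing <|im_end|>

# Hypothetical scripted "model": says three tokens, nominates a stop with
# IM_END, then (if nobody listens) starts writing a fake next turn.
SCRIPT = [5, 6, 7, IM_END, 1, 2, 3, 4]


def toy_next_logits(step, vocab_size=10):
    """Stand-in for a forward pass: a big logit on the scripted token."""
    target = SCRIPT[step % len(SCRIPT)]
    return [10.0 if i == target else 0.0 for i in range(vocab_size)]


def generate(stop_ids, max_tokens=20):
    out = []
    for step in range(max_tokens):
        z = toy_next_logits(step)
        tok = max(range(len(z)), key=z.__getitem__)  # greedy, so the demo is deterministic
        if tok in stop_ids:                          # the loop honors the nomination
            break
        out.append(tok)
    return out


print(generate(stop_ids={IM_END}))  # [5, 6, 7]
print(generate(stop_ids=set()))     # 20 tokens: walks straight past IM_END
```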

Setting the template aside for a moment: once the model is predicting the next token in a conversation, you still have to turn the probability distribution into an actual token. That decision is the decoding strategy. Pick the wrong one and the fanciest transformer in the world sounds like a broken keyboard. Pick the right one and the same weights produce fluent, surprising, useful prose. The widget below runs greedy, temperature, top-k, and nucleus sampling on the same fixed logits, so you can compare what each strategy does with one distribution.

four samplers on the same logits (vocab = 10, underlying logits fixed)

token      logit   p at T=1
"the"       3.4    45.8%
" a"        2.8    25.1%
" cat"      2.1    12.5%
" dog"      1.6     7.6%
" sat"      1.0     4.2%
" ran"      0.4     2.3%
" quick"   -0.2     1.3%
" slow"    -0.8     0.7%
"!"        -1.3     0.4%
"."        -2.0     0.2%

entropy of the original distribution: H = 1.50 nats

Two things to notice. Greedy always produces the same output for the same prompt — run it twice and the transcript is identical. The other three diverge every run because they're stochastic. And the shape of the distribution is what each strategy is manipulating — greedy picks the peak, temperature rescales the whole thing, top-k and top-p truncate the tail. Let's derive each one.

Notation first. After the forward pass, the model emits a vector of logits z ∈ ℝ^V over a vocabulary of size V (typically 50k–250k). A softmax turns the logits into a probability distribution p. We sample a token from p, append it to the context, and repeat. The decoding strategy is the choice of how to turn z into the next token.
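As a quick check on the widget numbers, here is the softmax and its entropy computed by hand in plain Python. This is a sketch; entropy is measured in nats (natural log). It reproduces the T = 1 column and H = 1.50.

```python
import math


def softmax(z):
    m = max(z)                       # shift by the max for numerical stability
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [x / s for x in e]


def entropy_nats(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


# the ten logits from the widget above
z = [3.4, 2.8, 2.1, 1.6, 1.0, 0.4, -0.2, -0.8, -1.3, -2.0]
p = softmax(z)
print([round(100 * pi, 1) for pi in p])  # [45.8, 25.1, 12.5, 7.6, 4.2, 2.3, 1.3, 0.7, 0.4, 0.2]
print(round(entropy_nats(p), 2))         # 1.5
```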

The simplest choice — greedy decoding:

greedy — the dumbest thing that sometimes works
tₜ  =  argmaxᵢ  zᵢ

(equivalently: argmaxᵢ  pᵢ   — softmax is monotone, so argmax doesn't care)

Pick the highest-scoring token, every time. Deterministic, fast, and famously prone to repetition loops: once the model latches onto a high-probability cycle of bigrams ("the cat" → "cat sat" → "sat on" → "on the" → "the cat"), greedy rides it forever. It's the right call for short, factual outputs (code completion, yes/no questions, structured extraction) and the wrong call for anything that needs to feel like a reply.
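The loop is easy to reproduce with a hypothetical hand-written score table standing in for the model. The table and `greedy_continue` are illustrative only; the point is that argmax has no way to leave a cycle once it enters one.

```python
# Hypothetical next-word scores standing in for a model's logits.
BIGRAM_LOGITS = {
    "the": {"cat": 2.0, "dog": 1.5, "end": 0.1},
    "cat": {"sat": 2.0, "ran": 1.0},
    "sat": {"on": 2.0, "down": 1.2},
    "on":  {"the": 2.0, "a": 1.0},
}


def greedy_continue(word, steps):
    out = [word]
    for _ in range(steps):
        options = BIGRAM_LOGITS[out[-1]]
        out.append(max(options, key=options.get))  # argmax: no randomness at all
    return " ".join(out)


print(greedy_continue("the", 8))  # the cat sat on the cat sat on the
```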

The fix is to put randomness back in. Scale the logits by a temperature T before softmaxing, then sample from the resulting distribution:

temperature sampling — softmax with a knob
pᵢ(T)  =        exp(zᵢ / T)
           ─────────────────────
            Σⱼ exp(zⱼ / T)

tₜ  ~  Categorical(p(T))

T = 1     →  raw model distribution        (as-trained)
T → 0⁺    →  all mass collapses to argmax  (= greedy)
T < 1     →  distribution sharpens         (safer, more confident)
T > 1     →  distribution flattens         (weirder, more creative)
T → ∞     →  uniform over vocabulary       (random keysmash)

Dividing by a small T blows up the differences between logits, so the softmax becomes peakier — the top token dominates. A large T shrinks differences, so the softmax flattens — even unlikely tokens get a real chance. T is a single scalar that interpolates between “greedy” and “uniform random.” Most chat models default to T ≈ 0.7–1.0.
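A sketch of the knob in action on the widget's ten logits: as T rises, the top token's probability falls and the entropy climbs, with T = 0.05 behaving almost like greedy and T = 100 almost uniform. The specific T values are arbitrary choices for illustration.

```python
import math


def softmax_T(z, T):
    """Softmax of z / T, shifted by the max for numerical stability."""
    scaled = [zi / T for zi in z]
    m = max(scaled)
    e = [math.exp(s - m) for s in scaled]
    total = sum(e)
    return [x / total for x in e]


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


z = [3.4, 2.8, 2.1, 1.6, 1.0, 0.4, -0.2, -0.8, -1.3, -2.0]  # widget logits
for T in (0.05, 0.5, 1.0, 2.0, 100.0):
    p = softmax_T(z, T)
    print(f"T={T:<6} top token p={p[0]:.3f}  H={entropy(p):.2f} nats")
```

The upper limit of H here is ln 10 ≈ 2.30 nats, the entropy of a uniform distribution over ten tokens.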

Temperature (personified)
I'm the creativity knob. Crank me low and the assistant says the safe thing every time — boring, reliable, repetitive. Crank me high and the assistant rolls dice on words it's barely confident about — surprising, occasionally ungrammatical. I'm a single scalar, and I'm probably the most important inference hyperparameter you'll ever tune.

Temperature has a problem. Even at sensible values like T = 1, the tail of the vocabulary distribution is long — thousands of tokens each with probability 10⁻⁵ or smaller. Sum them and they make up a non-trivial chunk of probability mass, so occasionally you'll sample one. When the top 20 guesses are all sensible continuations and the 100,000th token is a Unicode oddity that breaks your JSON parser, you don't want a 2% chance of picking from the tail. You want to cut the tail off.

Two standard ways. Top-k keeps a fixed number of tokens from the top and cuts the rest:

top-k — fixed-size truncation
V_k  =  indices of the k highest values in p

p'ᵢ  =  { pᵢ / Σⱼ∈V_k pⱼ     if i ∈ V_k
        { 0                   otherwise

tₜ  ~  Categorical(p')

(Fan, Lewis, Dauphin — 2018.  k = 50 is the common default.)

Keep the k highest-probability tokens, zero out everything else, renormalize so the surviving mass sums to 1, and sample from that. Simple, fast, effective. The problem: k is fixed, but the distribution's sharpness changes from token to token. Sometimes the model is very confident (one or two tokens matter); sometimes it's deeply uncertain (hundreds of tokens share the mass). A fixed k = 50 is too wide in the first case, letting in dozens of junk tokens, and too narrow in the second, throwing away most of the plausible ones.
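To make the mismatch concrete, here is a sketch with two made-up distributions over V = 1000 tokens (both are assumptions for illustration): a peaky one where one token dominates, and a flat one where about 200 tokens share nearly all the mass. The same k = 50 keeps dozens of near-zero tokens in the first case and discards most of the plausible ones in the second. `topk_report` is a hypothetical helper, and the 1e-3 "junk" cutoff is arbitrary.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [x / s for x in e]


def topk_report(p, k=50):
    """Mass kept by a top-k cut, and how many kept tokens are near-zero junk."""
    kept = sorted(p, reverse=True)[:k]
    junk = sum(1 for pi in kept if pi < 1e-3)
    return round(sum(kept), 3), junk


V = 1000
peaky = softmax([12.0, 9.0] + [0.0] * (V - 2))   # one or two tokens matter
flat = softmax([5.0] * 200 + [0.0] * (V - 200))  # ~200 tokens share the mass

print("peaky:", topk_report(peaky))  # (0.994, 48): nearly all mass, plus 48 junk tokens
print("flat: ", topk_report(flat))   # (0.243, 0): about three quarters of the mass thrown away
```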

Nucleus sampling — also called top-p — fixes that by cutting at a fixed cumulative probability instead:

nucleus (top-p) — adaptive truncation
Sort p descending, giving p₍₁₎ ≥ p₍₂₎ ≥ … ≥ p₍ᵥ₎.

Find the smallest n such that  Σᵢ₌₁ⁿ p₍ᵢ₎  ≥  p_thresh      (p_thresh is the "p" in top-p, not the distribution p)

V_p  =  { (1), (2), …, (n) }      ← the "nucleus"

p'ᵢ  =  { pᵢ / Σⱼ∈V_p pⱼ    if i ∈ V_p
        { 0                  otherwise

tₜ  ~  Categorical(p')

(Holtzman et al. — 2019.  p = 0.9 or 0.95 standard.)

Sort tokens by probability, walk down the list accumulating mass, stop the moment you've covered p of the total. Keep that set (the “nucleus”), zero the rest, renormalize, sample. When the distribution is peaky the nucleus contains one or two tokens. When it's flat it contains hundreds. The size adapts to the model's actual uncertainty, which is exactly the thing fixed k misses.
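The same two made-up distributions show the adaptivity. A sketch of the nucleus computation, with `nucleus` a hypothetical helper returning the kept indices: at p_thresh = 0.9 the peaky distribution keeps a single token, the flat one keeps 185, and the widget's ten logits from earlier keep four.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [x / s for x in e]


def nucleus(p, p_thresh=0.9):
    """Indices of the smallest top set whose mass reaches p_thresh."""
    order = sorted(range(len(p)), key=p.__getitem__, reverse=True)
    keep, cum = [], 0.0
    for i in order:
        keep.append(i)
        cum += p[i]
        if cum >= p_thresh:
            break
    return keep


V = 1000
peaky = softmax([12.0, 9.0] + [0.0] * (V - 2))   # one token dominates
flat = softmax([5.0] * 200 + [0.0] * (V - 200))  # ~200 tokens share the mass
widget = softmax([3.4, 2.8, 2.1, 1.6, 1.0, 0.4, -0.2, -0.8, -1.3, -2.0])

print(len(nucleus(peaky)))  # 1
print(len(nucleus(flat)))   # 185
print(nucleus(widget))      # [0, 1, 2, 3]
```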

In production you usually stack all three: apply temperature first to shape the distribution, then top-k as an absolute cap so you never consider more than k tokens no matter how flat the tail is, then top-p to trim the tail adaptively within what's left. The order is a design choice and it matters, since top-p run after top-k sees a renormalized distribution. OpenAI, Anthropic, and Google all expose some of these knobs in their APIs. The panel below shows the pipeline live: start with a raw next-token distribution, apply T, then top-k, then top-p, and watch which tokens survive each stage.

distribution shaper: compose temperature, top-k, top-p
pipeline: logits → /T → top-k → top-p → softmax
(probabilities in %, rounded; in the first two columns "to" and "in" are nonzero but round to 0; every stage renormalizes, so surviving mass is always 100%)

token       raw   ÷T (T=1.00)   top-k=8   top-p=0.90
king         45       45           45         50
queen        25       25           25         28
prince       12       12           12         14
knight        7        7            8          8
the           5        5            5          0
and           3        3            3          0
of            2        2            2          0
a             1        1            1          0
to            0        0            0          0
in            0        0            0          0
H (nats)   1.53     1.53         1.50       1.18
eff vocab    10       10            8          4
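The whole stack fits in a few lines if every filter works by setting logits to -inf and a single softmax runs at the end. A minimal sketch, applied to the ten logits from the first widget ("the", " a", " cat", ...); `shape` is a hypothetical helper, and its T → top-k → top-p order follows the pipeline above rather than any particular library.

```python
import math

NEG_INF = float("-inf")


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]  # exp(-inf) is 0.0, so masked tokens vanish
    s = sum(e)
    return [x / s for x in e]


def shape(logits, T=1.0, top_k=None, top_p=None):
    """Hypothetical helper: temperature, then top-k, then top-p, all as -inf
    masks on the logits, with one softmax at the end."""
    z = [zi / T for zi in logits]
    if top_k is not None:
        kth = sorted(z, reverse=True)[min(top_k, len(z)) - 1]
        z = [zi if zi >= kth else NEG_INF for zi in z]
    if top_p is not None:
        p = softmax(z)
        order = sorted(range(len(z)), key=p.__getitem__, reverse=True)
        keep, cum = set(), 0.0
        for i in order:
            keep.add(i)
            cum += p[i]
            if cum >= top_p:
                break
        z = [zi if i in keep else NEG_INF for i, zi in enumerate(z)]
    return softmax(z)


tokens = ["the", " a", " cat", " dog", " sat", " ran", " quick", " slow", "!", "."]
z = [3.4, 2.8, 2.1, 1.6, 1.0, 0.4, -0.2, -0.8, -1.3, -2.0]
final = shape(z, T=1.0, top_k=8, top_p=0.9)
print([(t, round(100 * q, 1)) for t, q in zip(tokens, final) if q > 0])
# [('the', 50.3), (' a', 27.6), (' cat', 13.7), (' dog', 8.3)]
```

Swapping the two filters can change the result in general; on these particular logits both orders happen to keep the same four tokens.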
Top-p (personified)
I'm the adaptive chooser. When the model is confident (one obvious next word), I keep just that word. When the model is uncertain (a paragraph could go fifty different ways), I open up and let fifty candidates in. I don't care about absolute counts. I care about covering enough of the model's beliefs to sample honestly from them.

Four strategies, three phases of your coding life. Start with pure Python — no numpy, no tensors, just loops and lists — so every step is visible. This is also how you'd implement generation in a tutorial before reaching for a library.

layer 1 — pure python · sampling_scratch.py
python
import math, random

def softmax(z, T=1.0):
    z = [zi / T for zi in z]
    m = max(z)                               # subtract max for numerical stability
    exps = [math.exp(zi - m) for zi in z]
    Z = sum(exps)
    return [e / Z for e in exps]

def greedy(z):
    return max(range(len(z)), key=lambda i: z[i])      # argmax

def sample_temperature(z, T=1.0):
    p = softmax(z, T)
    return random.choices(range(len(p)), weights=p, k=1)[0]

def sample_topk(z, k=50, T=1.0):
    p = softmax(z, T)
    # keep indices of k largest probabilities
    top = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:k]
    mass = sum(p[i] for i in top)
    keep = set(top)                           # build the lookup set once, not once per index
    weights = [p[i] / mass if i in keep else 0.0 for i in range(len(p))]
    return random.choices(range(len(p)), weights=weights, k=1)[0]

def sample_topp(z, p_thresh=0.9, T=1.0):
    p = softmax(z, T)
    order = sorted(range(len(p)), key=lambda i: p[i], reverse=True)
    # walk the sorted list, accumulate mass, keep until you cross p_thresh
    nucleus, cum = [], 0.0
    for i in order:
        nucleus.append(i)
        cum += p[i]
        if cum >= p_thresh:
            break
    mass = sum(p[i] for i in nucleus)
    keep = set(nucleus)                       # build the lookup set once, not once per index
    weights = [p[i] / mass if i in keep else 0.0 for i in range(len(p))]
    return random.choices(range(len(p)), weights=weights, k=1)[0]

logits = [2.5, 1.8, 1.2, 0.9, 0.3, -0.2, -1.1]
print("logits:  ", logits)
print("greedy →", greedy(logits))
print("T=1.0  →", sample_temperature(logits, T=1.0),   "   (sampled)")
print("T=0.5  →", sample_temperature(logits, T=0.5),   "   (sampled — sharpened)")
print("top-k=3 →", sample_topk(logits, k=3),           "   (sampled from {0,1,2})")
print("top-p=0.7 →", sample_topp(logits, p_thresh=0.7),"   (sampled from {0,1,2})")
stdout (one possible run; the sampled lines change from run to run because nothing is seeded)
logits:   [2.5, 1.8, 1.2, 0.9, 0.3, -0.2, -1.1]
greedy → 0
T=1.0  → 2    (sampled)
T=0.5  → 0    (sampled — sharpened)
top-k=3 → 1    (sampled from {0,1,2})
top-p=0.7 → 0    (sampled from {0,1,2})

Now vectorize with NumPy. Same algorithms, but every for loop over the vocabulary becomes an array operation, roughly the difference between 200 ms per token and 2 ms per token when V = 50,000.

layer 2 — numpy · sampling_numpy.py
python
import numpy as np

def softmax(z, T=1.0):
    z = z / T
    z = z - z.max()                               # stability — shift, softmax invariant
    e = np.exp(z)
    return e / e.sum()

def greedy(z):
    return int(np.argmax(z))

def sample_temperature(z, T=1.0, rng=np.random):
    p = softmax(z, T)
    return int(rng.choice(len(p), p=p))

def sample_topk(z, k=50, T=1.0, rng=np.random):
    p = softmax(z, T)
    # find the k-th largest probability, zero out anything smaller
    kth = np.partition(p, -k)[-k]
    p = np.where(p >= kth, p, 0.0)
    p = p / p.sum()                               # renormalize — must not forget this
    return int(rng.choice(len(p), p=p))

def sample_topp(z, p_thresh=0.9, T=1.0, rng=np.random):
    p = softmax(z, T)
    order = np.argsort(p)[::-1]                   # indices sorted descending
    sorted_p = p[order]
    cum = np.cumsum(sorted_p)
    cutoff = np.searchsorted(cum, p_thresh) + 1   # first index where cum >= p
    keep = order[:cutoff]
    mask = np.zeros_like(p); mask[keep] = 1.0
    p = p * mask
    p = p / p.sum()
    return int(rng.choice(len(p), p=p))

rng = np.random.default_rng(0)
logits = np.array([2.5, 1.8, 1.2, 0.9, 0.3, -0.2, -1.1])
print("greedy   →", greedy(logits))
print("T=0.7    →", sample_temperature(logits, T=0.7, rng=rng))
print("top-k=3  →", sample_topk(logits, k=3, rng=rng))
print("top-p=0.9 →", sample_topp(logits, p_thresh=0.9, rng=rng))
pure python → numpy
sum(exp(zi - m) for zi in z)  ←→  np.exp(z - z.max()).sum()

vector softmax: one line, same numerical trick

sorted(range(V), key=p.__getitem__, reverse=True)[:k]  ←→  np.partition(p, -k)[-k]   # the k-th threshold

O(V) instead of O(V log V): partition, don't sort

manual loop accumulating mass  ←→  np.cumsum + np.searchsorted

binary-search the cumulative sum: locating the cutoff is O(log V), though the argsort before it is still O(V log V)
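The standard library has rough counterparts to both NumPy tricks, which can help when NumPy isn't available. A sketch using the same seven logits as the code above: `heapq.nlargest` finds the k-th threshold without a full sort (O(V log k)), and `bisect` over an `itertools.accumulate` running total plays the role of np.cumsum + np.searchsorted.

```python
import bisect
import heapq
import itertools
import math


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [x / s for x in e]


p = softmax([2.5, 1.8, 1.2, 0.9, 0.3, -0.2, -1.1])

# top-k threshold without a full sort: heapq.nlargest is O(V log k)
k = 3
kth = heapq.nlargest(k, p)[-1]
top_k_ids = [i for i, pi in enumerate(p) if pi >= kth]

# nucleus size by binary search on the running total (the np.searchsorted step)
cum = list(itertools.accumulate(sorted(p, reverse=True)))
n = min(bisect.bisect_left(cum, 0.9) + 1, len(p))  # count up to the first cum >= 0.9

print(top_k_ids, n)  # [0, 1, 2] 4
```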

And in PyTorch, which is what you'll actually call in production. F.softmax does the softmax with the stability trick baked in; torch.multinomial draws the Categorical sample; the top-k / top-p work happens through a small mask applied to the logits before softmax, the clean pattern used in many reference transformer implementations.

layer 3 — pytorch · sampling_pytorch.py
python
import torch
import torch.nn.functional as F

@torch.no_grad()                                   # inference — no gradients
def sample(logits, T=1.0, top_k=None, top_p=None):
    # logits: shape (V,). This sketch is 1D only; a batched (B, V) version
    # would index the k-th value with [..., -1:] and scatter along dim=-1.
    logits = logits / T

    # top-k filter: keep only the k largest logits (k clamped to the vocab size)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # anything strictly below the k-th largest value → -inf
        logits = logits.masked_fill(logits < v[-1], float('-inf'))

    # top-p filter — sort, compute cumulative softmax, mask below-threshold tokens
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum > top_p                        # everything strictly past the nucleus
        remove[1:] = remove[:-1].clone()            # shift right — keep the token that crossed
        remove[0] = False                           # always keep the top token
        sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
        logits = torch.empty_like(logits).scatter_(0, sorted_idx, sorted_logits)

    probs = F.softmax(logits, dim=-1)               # stability handled internally
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.tensor([2.5, 1.8, 1.2, 0.9, 0.3, -0.2, -1.1])
torch.manual_seed(0)
print("greedy token   :", int(torch.argmax(logits)))
print("temp-sampled   :", sample(logits, T=0.7))
print("top-k sampled  :", sample(logits, T=1.0, top_k=3))
print("top-p sampled  :", sample(logits, T=1.0, top_p=0.9))
stdout
greedy token   : 0
temp-sampled   : 2
top-k sampled  : 1
top-p sampled  : 0
numpy → pytorch
p = np.where(p >= kth, p, 0); p /= p.sum()  ←→  logits.masked_fill(mask, -inf); F.softmax(logits)

mask in logit space before softmax: no renormalize needed

rng.choice(V, p=p)  ←→  torch.multinomial(probs, num_samples=1)

GPU-resident Categorical sample: no CPU round-trip

plain NumPy functions (no autograd to switch off)  ←→  @torch.no_grad() wrapping everything

inference idiom: tell autograd not to build a graph

Back to the monologuer, because here is where the stage-direction story gets its most common failure mode. The model nominates the stop. The inference loop respects the stop. Both halves must actually happen. When either one breaks, you get a runaway mouth — the assistant finishes its turn, then keeps going, writes a fake user: message to itself, answers it, writes another one, answers that. You have invited one guest to dinner and it somehow became three, all of them the same guest.

Three ways this happens, roughly in order of how often they bite people in the wild:

  • The stop token isn't in the sampler's stop list. You trained the model to emit <|im_end|>, but your generation loop only stops on <|endoftext|>. The curtain falls; nobody notices; the loop keeps sampling because nothing told it to stop, and the next token the model produces is the stage cue for a new user turn. Now it's hallucinating the other half of the transcript. Fix: pass every end-of-turn token id to the sampler's stop list.
  • The chat template at inference doesn't match the one from fine-tuning. You trained on <|im_start|>user\n…<|im_end|> but at inference you're handing the model a plain "User: …" prefix with no special tokens. The model is a stickler for its cues. If the tokens aren't there, it doesn't know which scene it's in, and the learned stopping behavior quietly dissolves. Fix: build prompts with the chat template that shipped with the fine-tuned model's own tokenizer (for example, apply_chat_template in Hugging Face transformers).
  • The fine-tuning data was inconsistent. Half the transcripts ended with <|im_end|>, half didn't. The model learned the end-of-turn cue half-heartedly — a probability of 0.3 where it should be 0.95 — and now greedy decoding walks right past it. Fix: audit the data; every assistant turn must end in the same closing cue, no exceptions.

Pulling all the stopping mechanisms together, here is the full stage-management layer sitting around the sampler. Three mechanisms, usually stacked:

  • max_tokens. Hard cap on the length of the generation. Always set one: cost and latency both grow roughly linearly with generated tokens, and the cap saves you when the other two mechanisms fail.
  • EOS / end-of-turn token. Most modern tokenizers have a special <|endoftext|> or <|im_end|> token; the model is trained to emit it when it's done speaking. Stop sampling the moment it's produced. This is the end-of-scene cue from earlier — the thing that teaches the monologuer it's someone else's turn.
  • Stop sequences. User-provided strings (e.g. "\n\nUser:", "```"). After each token, check whether the decoded output now contains any of them (the last token can run past the end of a stop string); if so, stop and trim at the match. Handy as a safety net when the model hallucinates the next user turn in plain text, ignoring the special-token template entirely.
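A sketch of the three mechanisms stacked around a sampler. Everything here is hypothetical scaffolding: a six-entry toy vocabulary, a `scripted` stand-in for model plus sampler, and `run_generation` as the loop. The second call shows why stop sequences are matched on decoded text: "\n\nUser:" arrives split across two tokens, and neither token on its own would match.

```python
VOCAB = ["<|im_end|>", "Paris", ".", "\n\nUs", "er:", " hi"]  # toy vocabulary; id 0 ends the turn


def decode(ids):
    return "".join(VOCAB[i] for i in ids)


def scripted(script, filler=5):
    """Hypothetical 'model + sampler': replays `script`, then emits `filler` forever."""
    return lambda ids: script[len(ids)] if len(ids) < len(script) else filler


def run_generation(next_token, eos_ids, stop_strings, max_tokens):
    ids = []
    for _ in range(max_tokens):                # 1. hard cap on length
        tok = next_token(ids)
        if tok in eos_ids:                     # 2. end-of-turn token: stop, don't emit it
            return decode(ids), "eos"
        ids.append(tok)
        text = decode(ids)                     # re-decoding each step is O(n^2); fine for a sketch
        for s in stop_strings:                 # 3. stop sequences, matched on decoded text
            cut = text.find(s)
            if cut != -1:
                return text[:cut], "stop_sequence"
    return decode(ids), "max_tokens"


print(run_generation(scripted([1, 2, 0]), {0}, ["\n\nUser:"], 50))     # ('Paris.', 'eos')
print(run_generation(scripted([1, 2, 3, 4]), {0}, ["\n\nUser:"], 50))  # ('Paris.', 'stop_sequence')
print(run_generation(scripted([1, 2]), {0}, ["\n\nUser:"], 5))         # ('Paris. hi hi hi', 'max_tokens')
```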
Gotchas

Log-probs vs probs: softmax can underflow to zero for very negative logits. Do arithmetic in log-space (use F.log_softmax) and only exp at the last step. Matters most with long sequences where you're summing log-probs across many tokens (e.g. beam search scoring).
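A small illustration in plain Python (the 400 probabilities of 1e-3 are made up): the product underflows to exactly 0.0 while the sum of logs stays an ordinary number, and a log-softmax computed with the log-sum-exp trick stays finite where exponentiating first would not. `log_softmax` here is a hand-rolled sketch, not the PyTorch function.

```python
import math

# a hypothetical 400-token sequence where each token had probability 1e-3
probs = [1e-3] * 400

print(math.prod(probs))                 # 0.0: the product underflows float64
print(sum(math.log(q) for q in probs))  # about -2763.1: fine in log space


def log_softmax(z):
    """log p_i = z_i - logsumexp(z), with the max subtracted for stability."""
    m = max(z)
    lse = m + math.log(sum(math.exp(zi - m) for zi in z))
    return [zi - lse for zi in z]


# exponentiating first would give p = 0.0 for the second token, and log(0) fails
print(log_softmax([1000.0, 0.0])[1])    # -1000.0
```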

Forgetting to renormalize: after zeroing out filtered tokens in probability space, the remaining probabilities don't sum to 1 — you have to divide by the new sum. In logit space (mask with -inf before softmax) this is handled for free. Do the logit thing.
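A quick sketch showing the two routes agree, using the same seven logits as the code earlier in the lesson and a hypothetical keep-set of the top three: zeroing probabilities and dividing by the surviving mass gives the same distribution as masking logits with -inf and taking one softmax.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]  # math.exp(-inf) is 0.0
    s = sum(e)
    return [x / s for x in e]


z = [2.5, 1.8, 1.2, 0.9, 0.3, -0.2, -1.1]
keep = {0, 1, 2}  # e.g. the survivors of a top-3 filter

# route 1, probability space: zero out, then divide by the surviving mass
p = softmax(z)
zeroed = [pi if i in keep else 0.0 for i, pi in enumerate(p)]
mass = sum(zeroed)
renorm = [q / mass for q in zeroed]

# route 2, logit space: mask with -inf, then a single softmax
masked = softmax([zi if i in keep else float("-inf") for i, zi in enumerate(z)])

print(round(mass, 3))                                          # 0.813: not a distribution yet
print(max(abs(a - b) for a, b in zip(renorm, masked)) < 1e-12)  # True: both routes agree
```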

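Sampling from a training-mode forward pass: during training the forward pass runs with dropout active, and teacher forcing conditions every position on the ground-truth prefix, so those logits are not the ones you should sample from. Generate autoregressively from your own samples, with the same forward path as eval (model.eval(), no dropout). Label smoothing only changes the loss target, so there is nothing to switch off for it at sampling time.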

Stop-sequence matching on text vs tokens: if your stop sequence is "User:" but the tokenizer splits it as ["User", ":"] vs ["Us", "er:"] depending on context, matching on token IDs will miss cases. Match on the decoded string, not the token IDs — slower but correct.

Implement nucleus from scratch and hear the difference

Start with a frozen GPT-2 small (from transformers import GPT2LMHeadModel). Write your own nucleus_sample(logits, p) — no library helpers, just sort, cumsum, mask, renormalize, torch.multinomial.

Generate 200 tokens three times from the same prompt (“Once upon a time, in a forest”), with p = 0.1, p = 0.5, and p = 0.9. Keep the temperature fixed at 1.0.

Write one sentence each describing the output's style. You should see something like: p=0.1 reads almost like greedy, stilted and repetitive; p=0.5 is fluent but safe; p=0.9 is varied, sometimes surprising, occasionally off-topic.

Bonus: plot the size of the nucleus (how many tokens survived) at each generation step for p=0.9. You'll see it spikes after commas and periods (high uncertainty) and collapses mid-word (one obvious next piece).

What to carry forward. A base GPT is a monologuer. Conversation isn't a property of the model; it's a property of the transcript format you fine-tune into it, plus special tokens that act as stage cues for turn-taking. Once the model is in the habit of predicting the closing cue, decoding is a separate design decision with its own hyperparameters: greedy is deterministic and loops; temperature adds calibrated randomness; top-k and top-p truncate the tail; repetition penalty kills degenerate loops; beam search is for likelihood-maximizing tasks like translation, not chat. In real systems you stack repetition penalty → temperature → top-k → top-p, call torch.multinomial, and watch the stop list like a hawk.

Next up: Reward Modeling. SFT got the monologuer to take turns. It didn't get it to be good. The fine-tuned model can now converse, but it still has no notion of which of its possible replies a human would actually prefer, only which one looks most like the transcripts in its training set. Reward modeling is the next move: show humans two candidate replies, ask them which is better, and train a network to predict that preference from the text alone. The output is a scalar “this reply is good, this one isn't” score, the ingredient classic RLHF pipelines run on. We'll derive the Bradley–Terry loss, train a reward model from pairwise data, and set up the scoring head that PPO consumes and whose Bradley–Terry loss DPO builds on.

References