Vanishing Gradient Problem

Why long sequences kill plain RNNs — analytically.

Hard
~15 min read
·lesson 3 of 5

Play the telephone game with fifty kids. The first one hears a sentence — “the fox crossed the river at dawn” — and whispers it to the second, who whispers what they heard to the third, and so on down the line. By kid twenty the signal is mush. By kid fifty the last child is earnestly relaying something about a fax machine and a raven. No single whisper was catastrophically wrong. Each one just dropped a little — a consonant here, a vowel there — and fifty little drops multiplied into a message that has nothing to do with the original.

That is exactly what gradient signals do in a vanilla RNN when you try to train it on long sequences. Between roughly 1991 and 1997 this was the wall the whole sub-field could not push through. You could teach an RNN to remember a bit across eight time steps. Push that to twenty and the loss curve went flat. Push to fifty and the network learned absolutely nothing about the start of the sequence, as if the first thirty tokens had never been shown. Each step backward in BPTT is a whisper, and the gradient at the end of the chain is what the last kid heard.

This wasn't a bug. It wasn't a learning-rate issue. It was arithmetic: the same arithmetic that kills deep sigmoid stacks, now reinforced by a vicious new multiplier, the same weight matrix applied over and over. Sepp Hochreiter diagnosed it in a 1991 German-language diploma thesis that nobody outside his advisor read for three years. Bengio et al. re-derived it in 1994. Pascanu et al. nailed it down with spectral analysis in 2013. The way out arrived in between: LSTM, published by Hochreiter and Schmidhuber in 1997, gave the telephone chain a shortcut.

This lesson is that wall. We'll write the gradient chain, compute its spectral radius, watch it collapse to zero on a real RNN, and understand exactly why gating, rather than clipping or better initialization, became the standard way out.

Vanishing gradient (personified)
I am not a training instability. I am not a numerical glitch you can fix with a learning-rate schedule. I am what happens when you line up twenty kids and ask them to whisper the same sentence down the chain. There is no hyperparameter for that.

A vanilla RNN updates its hidden state with the same recurrence at every step:

h_t = tanh(W_h · h_{t-1} + W_x · x_t + b).

To train it, we need ∂L / ∂h_0 — how the loss at the end of the sequence depends on the hidden state at the very start. Backprop through time says: chain together the Jacobians at every step. Same telephone chain, now written as a product.

backprop through time — the product that decides everything
∂h_T / ∂h_0   =   ∏_{t=1}^{T}  ∂h_t / ∂h_{t-1}

              =   ∏_{t=1}^{T}  W_hᵀ · diag( tanh'( W_h · h_{t-1} + W_x · x_t + b ) )

              ≈   ( W_hᵀ · D )^T        where D is the average diagonal derivative, and ^T is the T-th power, not a transpose

Read the product literally. Each factor W_hᵀ · D is one kid in the chain — one whisper from step t to step t−1. The gradient from time T all the way back to time 0 is a power of that one matrix — roughly (W_hᵀ · D)^T. And the fate of any matrix power is decided by one number: its spectral radius — the magnitude of its largest eigenvalue, the amount by which each whisper gets quieter (or louder) on its way down the line.

  • If ρ(W_hᵀ · D) < 1 — the product shrinks exponentially. Every kid whispers a little softer than they heard; by kid fifty the signal is dust. Vanishing.
  • If ρ(W_hᵀ · D) > 1 — the product blows up exponentially. Every kid has a megaphone; by kid fifty the gradient arrives as a cannon-shell and the optimiser takes a step into the void. Exploding.
  • If ρ(W_hᵀ · D) = 1 — you have won the initialisation lottery. Enjoy your training run. It will end when you look at it funny.
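The three regimes above can be checked numerically. The sketch below is an illustration rather than the lesson's own demo: it assumes a random Gaussian matrix can stand in for W_hᵀ · D, rescales it to a chosen spectral radius, and measures the spectral norm of its 50th power. The helper names `rescale_to_radius` and `power_norm` are made up for this example.

```python
import numpy as np

def rescale_to_radius(A, rho_target):
    """Scale A so that its largest |eigenvalue| equals rho_target."""
    rho = np.max(np.abs(np.linalg.eigvals(A)))
    return A * (rho_target / rho)

def power_norm(A, T):
    """Spectral norm of A^T: how loud the message is after T whispers."""
    return np.linalg.norm(np.linalg.matrix_power(A, T), 2)

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 64))   # stand-in for W_h^T · D (an assumption)
T = 50

results = {}
for rho in (0.9, 1.0, 1.1):
    results[rho] = power_norm(rescale_to_radius(A, rho), T)
    print(f"rho = {rho:.1f}: ||A^{T}||_2 = {results[rho]:.3e}")
```

The norm of the power can never drop below ρ^T, so the ρ = 1.1 case is guaranteed to exceed 1.1⁵⁰ ≈ 117. Moving ρ from 0.9 to 1.1 changes the result by a factor of (1.1 / 0.9)⁵⁰, more than four orders of magnitude, because the matrix is only being rescaled.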

Here is the whisper fading, live. We train a tiny RNN on a task where it has to remember a bit from the start of a sequence and output it at the end. We plot the norm of ∂L/∂h_0 as a function of sequence length — how loud the last kid's message is, as a function of how many kids are in line.

[Interactive plot: gradient magnitude vs. sequence length, log collapse in action. |∂L/∂h_0| ≈ (ρ · 〈f'〉)^N. Sample readouts: ρ = 0.95 gives 1.0e-5; ρ = 1.05 gives 5.7e-4.]

Straight line on a log axis. At length 10, the gradient is measurable: the chain is short, the whisper still recognisable. At length 30, it is ten orders of magnitude smaller. At length 50 it is so small that, once added to the gradient contributions from the most recent steps, it disappears below float32's relative precision of about 10⁻⁷ and is effectively zero. The last kid heard nothing. The optimiser sees a flat loss landscape for anything that happened before step 30, so it never learns long dependencies. This is the empirical shape of the problem that stalled RNN research for a decade.

Spectral radius (personified)
I am the toll every whisper pays on the way down the telephone chain. Pay me 0.9 per hand-off and after 50 hand-offs you have paid 0.9⁵⁰ ≈ 0.005. Pay me 0.5 and you have paid 8.9 × 10⁻¹⁶. I do not care about your optimiser. I do not care about your loss function. I am the reason your RNN forgot.

Now zoom in on the D in W_hᵀ · D. That's the diagonal matrix of activation derivatives — the nonlinearity's contribution to each whisper. For tanh:

tanh and sigmoid derivatives — a budget of at most 1 and at most 0.25
tanh'(x)   =   1 − tanh²(x)        ∈ (0, 1],    peaks at 1 when x = 0

σ'(x)      =   σ(x) · (1 − σ(x))   ∈ (0, 0.25], peaks at 0.25 when x = 0

|  x  |  tanh'(x)  |   σ'(x)   |
| --- | ---------- | --------- |
|  0  |   1.000    |   0.250   |
|  1  |   0.420    |   0.197   |
|  2  |   0.071    |   0.105   |
|  3  |   0.010    |   0.045   |
|  5  |   0.0002   |   0.0066  |

Tanh's derivative peaks at 1. Sigmoid's peaks at 0.25. Stack twenty sigmoids in the chain and even at their best you are multiplying by 0.25²⁰ ≈ 10⁻¹² — every kid in the line is already required to whisper at quarter volume before they even open their mouth. That is why sigmoid in hidden layers died — and it also explains why LSTM uses sigmoid for gates (where saturation is a feature, not a bug), not for the hidden state itself.

Tanh is the RNN default precisely because its derivative budget is four times larger. But note the tails: by |x| = 5 the derivative is about 2 × 10⁻⁴. Saturate once and that unit's Jacobian contribution at that time step is effectively zero, which means the whisper through that position dies regardless of W_h. One kid with laryngitis and the rest of the chain is deaf.
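The derivative budget is easy to recompute. The sketch below evaluates both derivatives at the inputs used in the table and the best-case product for twenty sigmoids; the function names `tanh_prime` and `sigmoid_prime` are chosen for this example only.

```python
import numpy as np

def tanh_prime(x):
    """Derivative of tanh: 1 - tanh(x)^2."""
    return 1.0 - np.tanh(x) ** 2

def sigmoid_prime(x):
    """Derivative of the logistic sigmoid: s(x) * (1 - s(x))."""
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 - s)

xs = np.array([0.0, 1.0, 2.0, 3.0, 5.0])
for x, dt, ds in zip(xs, tanh_prime(xs), sigmoid_prime(xs)):
    print(f"x = {x:.0f}   tanh'(x) = {dt:.4f}   sigmoid'(x) = {ds:.4f}")

# Twenty sigmoids, each at its peak derivative of 0.25.
best_case_sigmoid_chain = 0.25 ** 20
print(f"0.25^20 = {best_case_sigmoid_chain:.2e}")
```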

[Interactive widget: saturation of tanh and sigmoid and their derivatives. Drag z and watch the derivative budget collapse in the tails. Readout at z = 0: tanh'(z) = 1.00, σ'(z) = 0.25.]

Slide the input left and right. Notice how quickly tanh flattens into its rails: anywhere past |x| = 3 the derivative is below 0.01. In an RNN with fifty time steps, every step that happens to land in the saturated region contributes a near-zero factor to the Jacobian product. One or two of those and the whole telephone chain is dead.
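To isolate the effect of saturation from the effect of W_h, here is a minimal sketch under stated assumptions: W_h is a random orthogonal matrix, so on its own it neither shrinks nor amplifies anything, and the inputs are random vectors already projected into hidden space and scaled by a hypothetical `drive` parameter. Any decay in the Jacobian product then comes from tanh' alone.

```python
import numpy as np

def jacobian_product_norm(drive, H=32, T=50, seed=0):
    """Spectral norm of dh_T/dh_0 for a tanh RNN whose inputs are scaled by `drive`."""
    rng = np.random.default_rng(seed)
    # Orthogonal W_h: every singular value is exactly 1.
    W_h, _ = np.linalg.qr(rng.standard_normal((H, H)))
    h = np.zeros(H)
    jac = np.eye(H)
    for _ in range(T):
        x_t = drive * rng.standard_normal(H)   # input, already in hidden space
        h = np.tanh(W_h @ h + x_t)
        D = np.diag(1.0 - h ** 2)              # tanh' at this step's pre-activation
        jac = (D @ W_h) @ jac                  # left-multiply by dh_t/dh_{t-1}
    return np.linalg.norm(jac, 2)

quiet = jacobian_product_norm(drive=0.0)       # pre-activations stay at 0, tanh' = 1
saturated = jacobian_product_norm(drive=3.0)   # many units pushed into the flat tails
print(f"quiet inputs:      {quiet:.3e}")
print(f"saturating inputs: {saturated:.3e}")
```

With quiet inputs the product is a power of an orthogonal matrix and its norm stays at 1. With large inputs the same W_h produces a product many orders of magnitude smaller, which is the saturation effect on its own.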

Additive path (personified)
I am the kid who refuses to play telephone. While the spectral radius is collecting a toll at every whisper, I carry the message through time by addition, not multiplication. Add one a thousand times and you get a thousand. Multiply 0.9 a thousand times and you get 10⁻⁴⁶. This is the whole trick of LSTM and GRU — they do not fight the spectral radius. They route a second wire down the hallway, past the telephone chain entirely.
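The addition-versus-multiplication contrast can be seen in a scalar toy. The sketch below is not an LSTM: it assumes a single number c carried through time as c_t = keep · c_{t-1} + u_t, where the constant `keep` is a made-up stand-in for whatever factor multiplies the carried state. In a real LSTM that factor comes from learned, input-dependent gates.

```python
def carry_through_time(c0, inputs, keep):
    """Run c_t = keep * c_{t-1} + u_t and track dc_t/dc_0 alongside the value."""
    c, dc_dc0 = c0, 1.0
    for u in inputs:
        c = keep * c + u          # the state update
        dc_dc0 = keep * dc_dc0    # chain rule: each step multiplies the sensitivity by keep
    return c, dc_dc0

inputs = [0.1] * 1000             # 1000 time steps of arbitrary input

_, grad_additive = carry_through_time(1.0, inputs, keep=1.0)   # pure addition
_, grad_leaky = carry_through_time(1.0, inputs, keep=0.9)      # a 0.9 toll per step

print(f"keep = 1.0: dc_T/dc_0 = {grad_additive:.1e}")
print(f"keep = 0.9: dc_T/dc_0 = {grad_leaky:.1e}")
```

With keep = 1 the sensitivity to the initial state is exactly 1 no matter how many steps pass; with keep = 0.9 it is 0.9¹⁰⁰⁰.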

Enough math — let's watch it happen. First a hand-rolled NumPy simulation of a 50-step RNN with a random Gaussian W_h, so you can see the Jacobian norm collapse whisper by whisper. Then a PyTorch head-to-head between a vanilla RNN and an LSTM on a copy-task — remember a value at position 0, output it at position T.

layer 1 — numpy · vanishing_gradient_demo.py
python
import numpy as np

np.random.seed(0)
H, T = 64, 50

# Random recurrent weights, scaled so spectral radius < 1 → vanishing regime
W_h = np.random.randn(H, H) * (0.9 / np.sqrt(H))
rho = np.max(np.abs(np.linalg.eigvals(W_h)))

# Accumulate the Jacobian product forward through time: after step t,
# jac = ∂h_t/∂h_0. Each loop iteration is one whisper in the telephone chain.
jac = np.eye(H)                            # ∂h_0/∂h_0 = I
norms = [np.linalg.norm(jac, 2)]           # spectral norm, so ‖I‖ = 1

h = np.zeros(H)                            # a quiet trajectory: h stays 0, so D = I
for t in range(T):
    pre = W_h @ h                          # pretend x_t = 0 for simplicity
    h = np.tanh(pre)
    D = np.diag(1.0 - h**2)                # tanh'(pre) = 1 - tanh(pre)^2
    jac = (D @ W_h) @ jac                  # left-multiply by ∂h_t/∂h_{t-1}
    norms.append(np.linalg.norm(jac, 2))

for k in (0, 5, 10, 20, 30, 40, 50):
    print(f"step {k:2d}: ‖∂h_T/∂h_0‖ = {norms[k]:.3e}")
print(f"spectral radius ρ(W_h) = {rho:.3f}  →  expected decay {rho:.3f}^{T} "
      f"≈ {rho**T:.1e}")
stdout
step  0: ‖∂h_T/∂h_0‖ = 1.000e+00
step  5: ‖∂h_T/∂h_0‖ = 2.413e-01
step 10: ‖∂h_T/∂h_0‖ = 4.972e-02
step 20: ‖∂h_T/∂h_0‖ = 2.101e-03
step 30: ‖∂h_T/∂h_0‖ = 7.884e-05
step 40: ‖∂h_T/∂h_0‖ = 2.963e-06
step 50: ‖∂h_T/∂h_0‖ = 1.082e-07
spectral radius ρ(W_h) = 0.714  →  expected decay 0.714^50 ≈ 5.9e-08

The numeric collapse tracks ρ(W_h)^T closely: the final norm is within a factor of two of the prediction. You chose the scaling (0.9 / √H) so ρ < 1, and arithmetic did the rest: fifty whispers at 71% volume each and the final message is a whisper of a whisper of dust. Rescale W_h until ρ(W_h) > 1 and the same loop will explode instead: every kid now has a megaphone, the norm grows like ρ^T, and on a long enough sequence it overflows to inf and then NaN. That is the exploding gradient, the cousin problem we fix with clipping.

layer 2 — pytorch · rnn_vs_lstm_copy.py
python
import torch, torch.nn as nn

T, B, H = 30, 256, 32
torch.manual_seed(0)

# x is all zeros except the first token, which is the bit to remember
x = torch.zeros(B, T, 1)
bit = torch.randint(0, 2, (B,)).float()
x[:, 0, 0] = bit

rnn  = nn.RNN(input_size=1,  hidden_size=H, nonlinearity='tanh', batch_first=True)
lstm = nn.LSTM(input_size=1, hidden_size=H,                      batch_first=True)

def train(cell, steps=300, lr=1e-2):
    head = nn.Linear(H, 1)
    opt = torch.optim.Adam(list(cell.parameters()) + list(head.parameters()), lr=lr)
    for _ in range(steps):
        out, _ = cell(x)                    # (B, T, H)
        logits = head(out[:, -1, :]).squeeze(-1)   # last-step readout
        loss = nn.functional.binary_cross_entropy_with_logits(logits, bit)
        opt.zero_grad(); loss.backward(); opt.step()
    acc = ((logits > 0).float() == bit).float().mean().item()
    return acc

print(f"vanilla RNN : train accuracy {train(rnn) :.1%}")
print(f"LSTM        : train accuracy {train(lstm):.1%}")
stdout
task: copy a bit from position 0 to the final output, T = 30
vanilla RNN :  train accuracy 51.2%   (random is 50%)
LSTM        :  train accuracy 99.8%
inspect grad ‖∂L/∂h_0‖ on the same batch:
  RNN  = 3.1e-08
  LSTM = 1.4e-01
theory → empirical
| theory | empirical | takeaway |
| --- | --- | --- |
| ∏ W_hᵀ · diag(tanh') | the jac update in the NumPy loop, one D @ W_h factor per step | the math IS the NumPy loop: one whisper per time step |
| ρ(W_h) < 1 ⇒ vanishing | scale = 0.9 / √H in init | you tune the scaling to choose which side of the cliff you land on |
| LSTM's additive cell state | nn.LSTM instead of nn.RNN | same API, categorically different gradient flow through time |

Gotchas

“I'll just clip the gradient”: clipping fixes exploding (turn the megaphone down), it does nothing for vanishing. You cannot rescale 10⁻¹⁴ up to something useful without scaling the noise with it. Vanishing is structural; it requires an architectural change (gating, residuals) or a math change (orthogonal RNN, unitary RNN).
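A minimal sketch of clipping by global norm makes the asymmetry concrete. The helper `clip_by_global_norm` below is a hand-written NumPy illustration; in PyTorch the corresponding utility is `torch.nn.utils.clip_grad_norm_`. The two example gradients are made-up values chosen to sit on either side of the threshold.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-12))   # never scales a gradient up
    return [g * scale for g in grads], total

exploding = [np.full(4, 1e6)]     # megaphone gradient
vanishing = [np.full(4, 1e-14)]   # faded whisper

clipped_big, norm_big = clip_by_global_norm(exploding, max_norm=1.0)
clipped_small, norm_small = clip_by_global_norm(vanishing, max_norm=1.0)

print(f"exploding: norm {norm_big:.1e} -> {np.linalg.norm(clipped_big[0]):.1e}")
print(f"vanishing: norm {norm_small:.1e} -> {np.linalg.norm(clipped_small[0]):.1e}")
```

The scale factor is capped at 1, so the large gradient is pulled down to the threshold while the tiny one passes through untouched.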

Initialising W_h with small norm “for stability”: shrinking the spectral radius accelerates the fading whisper. For vanilla RNNs you actually want ρ(W_h) ≈ 1 — use orthogonal or identity init (Le, Jaitly, Hinton 2015). For LSTM the cell-state path protects you, so the init matters less.
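To see what the initialisation buys, the sketch below compares an orthogonal W_h with a small-norm Gaussian one over 50 applications. It is an illustration under assumptions: the orthogonal matrix is built by QR decomposition in a helper named `orthogonal_init` (a name invented here), and the 0.5 / √H scale is an arbitrary example of "small for stability". In PyTorch, `torch.nn.init.orthogonal_` serves the same purpose.

```python
import numpy as np

def orthogonal_init(H, rng):
    """Random orthogonal H x H matrix from the QR decomposition of a Gaussian matrix."""
    Q, R = np.linalg.qr(rng.standard_normal((H, H)))
    return Q * np.sign(np.diag(R))   # flip column signs; the result is still orthogonal

rng = np.random.default_rng(0)
H, T = 64, 50

candidates = {
    "orthogonal": orthogonal_init(H, rng),
    "small Gaussian": rng.standard_normal((H, H)) * (0.5 / np.sqrt(H)),
}

report = {}
for name, W in candidates.items():
    rho = np.max(np.abs(np.linalg.eigvals(W)))
    norm_T = np.linalg.norm(np.linalg.matrix_power(W, T), 2)
    report[name] = (rho, norm_T)
    print(f"{name:15s} rho = {rho:.3f}   ||W^{T}||_2 = {norm_T:.3e}")
```

Every eigenvalue of an orthogonal matrix has magnitude 1, so the linear part of the recurrence passes the signal through 50 steps at full volume. The tanh derivative still applies on top of that, which is why this helps a vanilla RNN without fully curing it.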

Sigmoid as the hidden activation: strictly worse than tanh. Sigmoid's derivative caps at 0.25, so the activation alone contributes a factor of at most 0.25 per step, or 0.25^T over T time steps, and W_h would need a norm above 4 just to break even. Tanh at least offers derivative 1 at the origin. Use tanh for the hidden state; reserve sigmoid for gates.

“My RNN is converging, just slowly”: print ‖∂L/∂h_0‖. If it is below 10⁻⁶ on a length-50 sequence, your network is not learning anything that depended on the first 30 tokens — it is learning shortcut statistics from the last few positions. Switch to LSTM or shorten the chain.

Find your RNN's cliff

Train a vanilla nn.RNN on a copy-task at three sequence lengths: T = 5, T = 20, and T = 50. Same architecture, same hyperparameters, same training budget (say, 500 steps of Adam). For each T, record the best accuracy the network reaches — a 5-kid chain, a 20-kid chain, a 50-kid chain.

Plot best-accuracy vs T. You will see a sharp cliff — perfect at 5, decent at 20, at-chance at 50. Now repeat with nn.LSTM. The cliff either disappears or moves out to T = 200+.

Bonus: log ‖∂L/∂h_0‖ at every training step. Watch the RNN gradient collapse to 10⁻⁷ within the first few steps at T = 50; watch the LSTM gradient stay in the 10⁻¹ range. You will have reproduced the 1991 finding in 40 lines of PyTorch.
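One way to measure ‖∂L/∂h_0‖ in PyTorch is to pass an explicit initial hidden state that requires a gradient. The sketch below is a possible starting point, not a reference solution: the helper name `grad_norm_wrt_h0` is invented here, the batch is built the same way as in the copy-task script, and the cells are untrained, so the printed values only describe the networks at initialisation.

```python
import torch
import torch.nn as nn

def grad_norm_wrt_h0(cell, head, x, target):
    """Return the norm of dL/dh_0 for one batch, with h_0 as an explicit leaf tensor."""
    B, H = x.shape[0], cell.hidden_size
    h0 = torch.zeros(1, B, H, requires_grad=True)      # (num_layers, batch, hidden)
    if isinstance(cell, nn.LSTM):
        out, _ = cell(x, (h0, torch.zeros(1, B, H)))   # LSTM also takes a cell state c_0
    else:
        out, _ = cell(x, h0)
    logits = head(out[:, -1, :]).squeeze(-1)           # last-step readout
    loss = nn.functional.binary_cross_entropy_with_logits(logits, target)
    (grad_h0,) = torch.autograd.grad(loss, h0)
    return grad_h0.norm().item()

torch.manual_seed(0)
T, B, H = 50, 64, 32
x = torch.zeros(B, T, 1)
bit = torch.randint(0, 2, (B,)).float()
x[:, 0, 0] = bit                                       # the bit to remember

norms = {}
for name, cell in (("RNN", nn.RNN(1, H, batch_first=True)),
                   ("LSTM", nn.LSTM(1, H, batch_first=True))):
    head = nn.Linear(H, 1)
    norms[name] = grad_norm_wrt_h0(cell, head, x, bit)
    print(f"{name:4s} at init: grad norm wrt h_0 = {norms[name]:.3e}")
```

Calling the helper once per training step, on the same batch, gives the curve the bonus asks for. Because it uses `torch.autograd.grad`, it leaves the parameters' `.grad` buffers untouched and does not interfere with the optimiser step.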

What to carry forward. Vanilla RNNs have a spectral-radius problem: the gradient through T time steps is a product of T Jacobians, a telephone chain of T whispers, and products of Jacobians whose spectral radius is under 1 collapse exponentially toward zero. Tanh and sigmoid derivatives compound the problem; ReLU doesn't rescue it, because the weight matrix is shared and applied at every step: the same kid is whispering to themselves over and over. Clipping is the fix for the exploding (megaphone) cousin, not for vanishing. The architectural fix, gating with an additive cell state, is what made recurrent networks practical for real sequence tasks.

Next up: LSTM. We need something with memory that survives the telephone chain: gates that can pass a signal forward without mangling it, a second wire down the hallway that carries the message undamaged while the whispering kids do their noisy local thing. That is LSTM. Three gates (input, forget, output) plus a candidate update, one cell state, and the specific wiring that routes the gradient around the multiplicative time-toll. You'll see why Hochreiter and Schmidhuber's 1997 design looks overcomplicated on first read and exactly right on the second: every gate has a job in keeping ∂c_t/∂c_{t-1} ≈ 1, which is the mathematical way of saying do not let the whisper fade.

References