
L05 - Normalization & Residuals: The Plumbing of Deep Networks [DRAFT]

How to stop gradients from vanishing and signals from exploding


We have built the Multi-Head Attention engine, but there is a problem. In a deep LLM, we stack these layers dozens of times. As the data passes through these transformations, the numbers can drift: they might become tiny (vanishing) or massive (exploding).

If the numbers get weird, the model stops learning. To fix this, we use two critical “plumbing” techniques:

  1. Residual (Skip) Connections: “Don’t forget what you just learned.”

  2. Layer Normalization: “Keep the numbers in a healthy range.”

By the end of this post, you’ll understand:

  1. Why residual (skip) connections keep gradients from vanishing in deep stacks.

  2. How Layer Normalization works and how it differs from BatchNorm.

  3. Why modern LLMs place LayerNorm before the sub-layer (Pre-Norm) rather than after it (Post-Norm).

Part 1: Residual Connections (The Skip)

In a standard network, data flows like this: $y = \text{Layer}(x)$. In a Transformer, we do this: $y = x + \text{Layer}(x)$.

We literally add the input back to the output.

Why?

Imagine you are trying to describe a complex concept. If you only give the “transformed” explanation, you might lose the original context. By adding the input back, we provide a “highway” for the original signal to travel through.
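
A minimal sketch of that “highway” in code (assuming PyTorch; the wrapped sub-layer can be any transformation, such as attention or a feed-forward network):

import torch
import torch.nn as nn

class Residual(nn.Module):
    """Wraps any sub-layer with a skip connection."""
    def __init__(self, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x):
        # The original signal x travels through untouched;
        # the sub-layer only has to learn a correction on top of it.
        return x + self.sublayer(x)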

Visualizing Gradient Flow: With vs. Without Residuals

Let’s see the dramatic difference residual connections make in a deep network:

[Figure: gradient magnitude per layer in a deep stack, with vs. without residual connections (two panels).]
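
Here is a simplified sketch of the kind of experiment behind such a plot (the depth, width, and tanh activation are illustrative choices, not the exact setup): it measures how much gradient reaches the input of a deep stack with and without skip connections.

import torch
import torch.nn as nn

def input_grad_norm(depth=50, d=64, residual=True, seed=0):
    torch.manual_seed(seed)
    layers = nn.ModuleList([nn.Linear(d, d) for _ in range(depth)])
    x = torch.randn(8, d, requires_grad=True)
    h = x
    for layer in layers:
        out = torch.tanh(layer(h))
        h = h + out if residual else out   # the only difference
    h.sum().backward()
    # How much gradient made it all the way back to the input?
    return x.grad.norm().item()

print("with residuals:   ", input_grad_norm(residual=True))
print("without residuals:", input_grad_norm(residual=False))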

Key observations:

  1. Without residuals, the gradient shrinks layer after layer and is nearly zero by the time it reaches the early layers.

  2. With residuals, the gradient stays at a healthy magnitude all the way down, because the skip connection contributes a “+1” term to the gradient at every layer.

Part 2: Layer Normalization (The Leveler)

LayerNorm ensures that for every token, the mean of its features is 0 and the standard deviation is 1. This prevents any single feature or layer from dominating the calculation.

Unlike BatchNorm (common in CNNs), LayerNorm calculates statistics across the features of a single token. This makes it perfect for sequences of varying lengths.

LayerNorm vs. BatchNorm: What’s the Difference?

Both normalization techniques aim to stabilize training, but they compute statistics over different dimensions:

| Aspect | BatchNorm | LayerNorm |
| --- | --- | --- |
| Normalizes across | Batch dimension (across examples) | Feature dimension (within each example) |
| Input shape | [Batch, Features] or [Batch, Channels, Height, Width] | [Batch, Seq_Len, Features] |
| Statistics computed | Mean/std for each feature, across all examples in the batch | Mean/std for each example, across all of its features |
| Dependencies | Requires large batches to get good statistics | Works with batch size = 1 |
| Typical use | CNNs (image tasks) | Transformers, RNNs (sequence tasks) |

Why LayerNorm for Transformers?

Sequences vary in length and batches are often small or padded, so per-feature batch statistics are noisy and differ between training and inference. LayerNorm only looks at one token’s features at a time, so it behaves identically at any batch size (including batch size 1 at generation time) and for any sequence length.

Visual Comparison:

BatchNorm (shape [4, 512]):          LayerNorm (shape [4, 512]):
┌─────────────────┐                  ┌─────────────────┐
│ Example 1       │                  │ Example 1       │ ← Normalize these 512 values
│ Example 2       │ ↑                │ Example 2       │ ← Normalize these 512 values
│ Example 3       │ │ Normalize      │ Example 3       │ ← Normalize these 512 values
│ Example 4       │ │ each column    │ Example 4       │ ← Normalize these 512 values
└─────────────────┘ ↓                └─────────────────┘
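
To see the difference numerically, here is a small check (the shape [4, 512] mirrors the diagram; BatchNorm is left in training mode so it uses batch statistics):

import torch
import torch.nn as nn

x = torch.randn(4, 512) * 3 + 7  # 4 examples, 512 features each

bn = nn.BatchNorm1d(512, affine=False)            # normalizes each feature (column) across the batch
ln = nn.LayerNorm(512, elementwise_affine=False)  # normalizes each example (row) across its features

x_bn = bn(x)
x_ln = ln(x)

print("BatchNorm: per-feature means ~0:", x_bn.mean(dim=0)[:3])
print("LayerNorm: per-example means ~0:", x_ln.mean(dim=1))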

The Formula

For a vector $x$ of features:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$

  1. Calculate the mean ($\mu$) and variance ($\sigma^2$) of the features.

  2. Subtract mean and divide by standard deviation.

  3. Learnable Parameters: $\gamma$ (scale) and $\beta$ (shift) allow the model to “undo” the normalization if it decides that a different range is better for learning.
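
As a tiny worked example (the numbers are chosen for easy arithmetic, with $\gamma = 1$ and $\beta = 0$):

import math

x = [2.0, 4.0, 6.0, 8.0]                      # one token's features
mu = sum(x) / len(x)                          # mean = 5.0
var = sum((v - mu) ** 2 for v in x) / len(x)  # variance = 5.0
eps = 1e-6

x_hat = [(v - mu) / math.sqrt(var + eps) for v in x]
print(x_hat)  # ≈ [-1.342, -0.447, 0.447, 1.342] -> mean 0, std 1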


Part 3: Visualizing the “Add & Norm”

Let’s see how LayerNorm tames wild values.

[Figure: activations before vs. after LayerNorm (two panels).]
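
A quick numerical version of the same idea (the off-scale input here is artificial): feed LayerNorm a tensor whose activations are far from zero and compare the statistics before and after.

import torch
import torch.nn as nn

torch.manual_seed(0)
wild = torch.randn(2, 4, 512) * 50 + 100   # [batch, seq_len, d_model], way off-scale

tamed = nn.LayerNorm(512)(wild)            # gamma=1, beta=0 at initialization

print("before: mean %.1f, std %.1f" % (wild.mean().item(), wild.std().item()))
print("after:  mean %.3f, std %.3f" % (tamed.mean().item(), tamed.std().item()))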

Part 4: Building LayerNorm from Scratch

While PyTorch has nn.LayerNorm, building it yourself helps you understand exactly where those learnable parameters ($\gamma$ and $\beta$) live.

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # scale, initialized to 1
        self.beta = nn.Parameter(torch.zeros(d_model))   # shift, initialized to 0
        self.eps = eps

    def forward(self, x):
        # x shape: [batch, seq_len, d_model]
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)  # biased variance, as in the formula

        # Normalize: (x - mu) / sqrt(sigma^2 + eps)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift with the learnable parameters
        return self.gamma * x_norm + self.beta
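
Continuing with the class above, a quick sanity check against PyTorch’s built-in implementation (eps is matched so the two line up):

torch.manual_seed(0)
x = torch.randn(2, 10, 512)

ours = LayerNorm(512, eps=1e-5)   # the class defined above
theirs = nn.LayerNorm(512)        # PyTorch's default eps is also 1e-5

print(torch.allclose(ours(x), theirs(x), atol=1e-5))  # True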


Part 5: Pre-Norm vs. Post-Norm - Where to Place LayerNorm?

There are two ways to combine residuals and normalization in a Transformer block, and the choice has significant training implications:

Post-Norm (Original Transformer)

x = LayerNorm(x + Attention(x))
x = LayerNorm(x + FeedForward(x))

How it works:

  1. Apply the transformation (Attention or FFN)

  2. Add the residual

  3. Normalize the result

Characteristics:

  1. Used in the original “Attention Is All You Need” Transformer.

  2. Every gradient on the residual path must pass through a LayerNorm, which can destabilize training in very deep stacks.

  3. Typically needs a careful learning-rate warmup to converge.

Pre-Norm (Modern Standard)

x = x + Attention(LayerNorm(x))
x = x + FeedForward(LayerNorm(x))

How it works:

  1. Normalize the input first

  2. Apply the transformation

  3. Add the residual

Characteristics:

  1. The residual path is a clean identity, so gradients flow back unimpeded.

  2. Trains stably even for very deep models, with little or no warmup.

  3. Used by most modern LLMs (e.g., the GPT family and LLaMA).

Visual Comparison

[Diagrams: Post-Norm architecture vs. Pre-Norm architecture, showing where LayerNorm sits relative to the residual connection.]
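
To make the two orderings concrete, here is a minimal sketch of both block variants (the feed-forward sub-layer at the bottom is just a placeholder; a real block would plug in the multi-head attention built earlier):

import torch
import torch.nn as nn

class PostNormBlock(nn.Module):
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Post-Norm: transform -> add residual -> normalize
        return self.norm(x + self.sublayer(x))

class PreNormBlock(nn.Module):
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-Norm: normalize -> transform -> add residual.
        # The "x +" path is a clean identity, so gradients skip the LayerNorm entirely.
        return x + self.sublayer(self.norm(x))

d_model = 512
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
x = torch.randn(2, 8, d_model)
print(PreNormBlock(d_model, ffn)(x).shape)   # torch.Size([2, 8, 512])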

Why Pre-Norm Won:

  1. The skip path never passes through a normalization layer, so gradients reach the early layers at full strength.

  2. Training is far less sensitive to the learning rate and warmup schedule.

  3. It scales to the very deep stacks used in modern LLMs without diverging.


Summary

  1. Residual Connections create a “high-speed rail” for the signal, preventing the vanishing gradient problem through the “+1” term in gradients.

  2. LayerNorm re-centers the data at every step, keeping the optimization process stable by normalizing across features rather than across batches.

  3. Pre-Norm vs. Post-Norm: Most modern LLMs use Pre-Norm (normalize before the sub-layer) because it’s more stable to train and less sensitive to hyperparameters.

Next Up: L06 – The Causal Mask. When training a model to predict the next word, how do we stop it from “cheating” by looking at the answer? We’ll build the triangular mask that hides the future.