
L04 - Multi-Head Attention: The Committee of Experts

Why one brain is good, but eight brains are better.


In L03 - Self-Attention, we built the “Search Engine” of the Transformer. We learned how the word “it” can look up the word “animal” to resolve ambiguity.

But there is a limitation. A single self-attention layer acts like a single pair of eyes. It can focus on one aspect of the sentence at a time.

Consider the sentence:

“The chicken didn’t cross the road because it was too wide.”

To understand this fully, the model needs to do two things simultaneously:

  1. Syntactic Analysis: Link “it” back to the noun “road” (because roads, not chickens, are wide).

  2. Semantic Analysis: Understand that “wide” is a physical property preventing crossing.

If we only have one attention head, the model has to average these different relationships into a single vector, which muddies the waters.

Multi-Head Attention solves this by giving the model multiple “heads” (independent attention mechanisms) that run in parallel.

By the end of this post, you’ll understand:

  - Why a single attention head is limited, and how multiple heads let the model specialize.
  - The four-step Multi-Head Attention pipeline: project, attend in parallel, concatenate, and mix.
  - How to implement it efficiently in PyTorch using reshapes instead of a loop over the heads.


Part 1: The Intuition (The Committee)

Think of the embedding dimension ($d_{model} = 512$) as a massive report containing everything we know about a word.

If we ask a single person to read that report and summarize “grammar,” “tone,” “tense,” and “meaning” all at once, they might miss details.

Instead, we hire a Committee of 8 Experts: each expert reads the same report, but is responsible for summarizing only one aspect of it (one focuses on grammar, another on tone, another on tense, and so on).

In the Transformer, we don’t just copy the input 8 times. We project the input into 8 different lower-dimensional spaces. This allows each head to specialize.

Let’s visualize this “filtering” process: each head takes the same full 512-dimensional embedding and projects it down into its own smaller subspace.

[Figure: the full embedding being filtered down into 8 smaller, head-specific subspaces]
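As a rough sketch of this projection idea (the tensors and variable names below are made up purely for illustration), each head gets its own small projection matrix, so the same 512-dimensional word vector yields eight different 64-dimensional views:

import torch

d_model, num_heads = 512, 8
d_head = d_model // num_heads                  # 64 dims per head

word = torch.randn(d_model)                    # one word's full 512-dim "report"

# Each head owns its own projection into a smaller subspace.
head_projections = [torch.randn(d_model, d_head) for _ in range(num_heads)]

views = [word @ W for W in head_projections]   # eight different 64-dim views of the same word
print(len(views), views[0].shape)              # 8 torch.Size([64])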

Part 2: The Multi-Head Pipeline

Now that we understand the “why” (Specialization), let’s look at the “how” (The Pipeline).

The Multi-Head Attention mechanism isn’t a single black box; it is a specific sequence of operations. It allows the model to process information in parallel and then synthesize the results.

The 4-Step Process

  1. Linear Projections (The Split): We don’t just use the raw input. We multiply the input $Q, K, V$ by specific weight matrices ($W_i^Q, W_i^K, W_i^V$) for each head. This creates the specialized “subspaces” we saw in Part 1.

  2. Independent Attention: Each head runs the standard Scaled Dot-Product Attention independently.

    $$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
  3. Concatenation: We take the output vectors from all 8 heads and glue them back together side-by-side.

  4. Final Linear (The Mix): We pass this long concatenated vector through one last linear layer ($W^O$) to blend the insights from all the experts into a single unified vector.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O$$

Let’s trace this flow concretely.
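Here is a minimal sketch of the four steps, run as a deliberately naive Python loop over the heads (the dimensions and variable names are assumptions for illustration; Part 4 shows the efficient, loop-free version):

import math
import torch
import torch.nn as nn

d_model, num_heads = 512, 8
d_k = d_model // num_heads
x = torch.randn(1, 4, d_model)                     # [Batch=1, Seq=4, d_model]

# Step 1: per-head projection matrices (a small W^Q, W^K, W^V for each head)
W_q = [nn.Linear(d_model, d_k) for _ in range(num_heads)]
W_k = [nn.Linear(d_model, d_k) for _ in range(num_heads)]
W_v = [nn.Linear(d_model, d_k) for _ in range(num_heads)]
W_o = nn.Linear(d_model, d_model)                  # Step 4: the final "mix"

heads = []
for i in range(num_heads):
    Q, K, V = W_q[i](x), W_k[i](x), W_v[i](x)      # Step 1: project into this head's subspace
    # Step 2: scaled dot-product attention, run independently per head
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # [1, 4, 4]
    weights = torch.softmax(scores, dim=-1)
    heads.append(weights @ V)                      # [1, 4, 64]

concat = torch.cat(heads, dim=-1)                  # Step 3: glue heads side by side -> [1, 4, 512]
output = W_o(concat)                               # Step 4: blend the experts -> [1, 4, 512]
print(output.shape)                                # torch.Size([1, 4, 512])

The loop makes the four steps explicit, but it launches dozens of small matrix multiplies per layer; the reshaping trick in Part 4 fuses them into a few large ones.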


Part 3: Visualizing Multiple Perspectives

Let’s visualize how two different heads might analyze the same sentence.

Sentence: “The cat sat on the mat because it was soft.”

[Figure: two attention heatmaps for the same sentence, one per head]

Notice how they highlight completely different parts of the matrix.
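If you want to produce this kind of plot yourself, here is a small sketch using entirely invented attention weights (the patterns are made up for illustration, not taken from a trained model):

import numpy as np
import matplotlib.pyplot as plt

tokens = ["The", "cat", "sat", "on", "the", "mat", "because", "it", "was", "soft"]
n = len(tokens)

# Two invented attention matrices, just to show two different "perspectives".
head_1 = np.full((n, n), 0.05)
head_1[7, 5] = 0.9                      # "it" attends strongly to "mat"
head_2 = np.full((n, n), 0.05)
head_2[2, 3] = 0.9                      # "sat" attends strongly to "on"

fig, axes = plt.subplots(1, 2, figsize=(14, 6))
for ax, attn, title in zip(axes, [head_1, head_2], ["Head 1", "Head 2"]):
    ax.imshow(attn, cmap="viridis")
    ax.set_xticks(range(n))
    ax.set_xticklabels(tokens, rotation=45)
    ax.set_yticks(range(n))
    ax.set_yticklabels(tokens)
    ax.set_title(title)
plt.tight_layout()
plt.show()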

Part 4: Implementation in PyTorch

Implementing this efficiently requires some tensor gymnastics. We don’t actually run a Python for loop over the 8 heads (like the naive sketch in Part 2); that would be too slow.

Instead, we use a single large matrix multiply and then reshape (view/transpose) the tensor to create a “heads” dimension.

The shape transformation looks like this:

  1. Input: [Batch, Seq_Len, D_Model]

  2. Linear & Reshape: [Batch, Seq_Len, Heads, D_Head]

  3. Transpose: [Batch, Heads, Seq_Len, D_Head]

By swapping axes 1 and 2, we group the “Heads” dimension with the “Batch” dimension. PyTorch then processes all heads in parallel as if they were just extra items in the batch.

Shape Transformation Table

Let’s trace the exact tensor shapes through a concrete example with batch=2, seq=10, d_model=512, heads=8:

| Operation | Shape | Description |
| --- | --- | --- |
| Input `x` | `[2, 10, 512]` | Raw input: 2 sequences, each with 10 tokens, 512-dim embeddings |
| After `W_q(x)` | `[2, 10, 512]` | Linear projection (still flat) |
| After `.view(2, 10, 8, 64)` | `[2, 10, 8, 64]` | Reshape: split 512 dims into 8 heads × 64 dims each |
| After `.transpose(1, 2)` | `[2, 8, 10, 64]` | Swap seq and heads: now we have 8 “parallel attention mechanisms” |
| Attention computation | `[2, 8, 10, 64]` | Each head computes attention independently |
| After `.transpose(1, 2)` | `[2, 10, 8, 64]` | Swap back: prepare for concatenation |
| After `.contiguous().view(2, 10, 512)` | `[2, 10, 512]` | Flatten: merge 8 heads back into a single 512-dim vector |
| After `W_o(x)` | `[2, 10, 512]` | Final projection |

The key insight: dimensions 1 and 2 get swapped twice—once to parallelize the heads, and once to merge them back together.
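To convince yourself of the table, you can trace the shapes interactively. This is a throwaway sketch using the same sizes as above (batch=2, seq=10, d_model=512, heads=8):

import torch
import torch.nn as nn

batch, seq, d_model, heads = 2, 10, 512, 8
d_head = d_model // heads                      # 64

x = torch.randn(batch, seq, d_model)           # [2, 10, 512]
W_q = nn.Linear(d_model, d_model)

q = W_q(x)                                     # [2, 10, 512]   still flat
q = q.view(batch, seq, heads, d_head)          # [2, 10, 8, 64] split into heads
q = q.transpose(1, 2)                          # [2, 8, 10, 64] heads now act like extra batch items
assert q.shape == (2, 8, 10, 64)

# ...attention would happen here...

q = q.transpose(1, 2).contiguous()             # [2, 10, 8, 64] swap back
q = q.view(batch, seq, d_model)                # [2, 10, 512]   heads merged again
assert q.shape == (2, 10, 512)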

import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # We define 4 linear layers: Q, K, V projections and the final Output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # 1. Project and Split
        # We transform [Batch, Seq, Model] -> [Batch, Seq, Heads, d_k]
        # Then we transpose to [Batch, Heads, Seq, d_k] for matrix multiplication
        Q = self.W_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 2. Scaled Dot-Product Attention (re-using logic from L03)
        # Scores shape: [Batch, Heads, Seq, Seq]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = torch.softmax(scores, dim=-1)
        
        # Apply weights to Values
        # Shape: [Batch, Heads, Seq, d_k]
        attn_output = torch.matmul(attn_weights, V)
        
        # 3. Concatenate
        # Transpose back: [Batch, Seq, Heads, d_k]
        # Flatten: [Batch, Seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 4. Final Projection (The "Mix")
        return self.W_o(attn_output)
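
As a quick sanity check, you can exercise the module with arbitrary random inputs; the printed shape should match the final row of the shape table:

mha = MultiHeadAttention(d_model=512, num_heads=8)

x = torch.randn(2, 10, 512)      # [Batch, Seq, D_Model]
out = mha(x, x, x)               # self-attention: q, k, and v are the same tensor

print(out.shape)                 # torch.Size([2, 10, 512])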

Summary

  1. Multiple Heads: We split our embedding into $h$ smaller chunks to allow the model to focus on different linguistic features simultaneously.

  2. Projection: We use learned linear layers ($W^Q, W^K, W^V$) to project the input into these specialized subspaces.

  3. Parallelism: We use tensor reshaping (view and transpose) to compute attention for all heads at once, rather than looping through them.

Next Up: L05 – Layer Norm & Residuals. We have built the engine (Attention), but if we stack 100 of these layers on top of each other, the gradients will vanish or explode. In L05, we will add the “plumbing” (Normalization and Skip Connections) that allows Deep Learning to actually get deep.