Why one brain is good, but eight brains are better.
In L03 - Self-Attention, we built the “Search Engine” of the Transformer. We learned how the word “it” can look up the word “animal” to resolve ambiguity.
But there is a limitation. A single self-attention layer acts like a single pair of eyes. It can focus on one aspect of the sentence at a time.
Consider the sentence:
“The chicken didn’t cross the road because it was too wide.”
To understand this fully, the model needs to do two things simultaneously:
Coreference Resolution: Link “it” to “the road” (because roads, not chickens, are wide).
Semantic Analysis: Understand that “wide” is a physical property that prevents crossing.
If we only have one attention head, the model has to average these different relationships into a single vector. It muddies the waters.
Multi-Head Attention solves this by giving the model multiple “heads” (independent attention mechanisms) that run in parallel.
By the end of this post, you’ll understand:
The intuition of the “Committee of Experts.”
Why we project vectors into different Subspaces.
How to implement the tensor reshaping magic (`view` and `transpose`) in PyTorch.
Part 1: The Intuition (The Committee)¶
Think of the embedding dimension (`d_model`, 512 in our running example) as a massive report containing everything we know about a word.
If we ask a single person to read that report and summarize “grammar,” “tone,” “tense,” and “meaning” all at once, they might miss details.
Instead, we hire a Committee of 8 Experts:
Head 1 (The Linguist): Only looks for Subject-Verb agreement.
Head 2 (The Historian): Looks for past/present tense consistency.
Head 3 (The Translator): Looks for definitions and synonyms.
...
In the Transformer, we don’t just copy the input 8 times. We project the input into 8 different lower-dimensional spaces. This allows each head to specialize.
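To make this concrete, here is a minimal sketch of that projection, using the 512-dimension / 8-head sizes we will use throughout this post (the tensor and layer names are just illustrative):

```python
import torch
import torch.nn as nn

d_model, num_heads = 512, 8
d_head = d_model // num_heads  # 64 dimensions per head

x = torch.randn(1, d_model)        # one word embedding: the full "report"
W_q = nn.Linear(d_model, d_model)  # one learned projection, shared across heads

q = W_q(x)                              # [1, 512] - still one flat vector
q_heads = q.view(1, num_heads, d_head)  # [1, 8, 64] - eight 64-dim "expert views"

print(q_heads.shape)        # torch.Size([1, 8, 64])
print(q_heads[0, 0].shape)  # torch.Size([64]) - what Head 1 actually sees
```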
Let’s visualize this “filtering” process. In the plot below:
The Input (Mixed Info): The large multi-colored bar represents the full word embedding (all 512 dimensions).
The Split (Equal Parts): We project this into 8 equal-sized subspaces (512 / 8 = 64 dimensions each).
The Result: Each head gets a vector that is 1/8th the size of the original, containing only the specific info it needs.

Part 2: The Multi-Head Pipeline¶
Now that we understand the “why” (Specialization), let’s look at the “how” (The Pipeline).
The Multi-Head Attention mechanism isn’t a single black box; it is a specific sequence of operations. It allows the model to process information in parallel and then synthesize the results.
The 4-Step Process
Linear Projections (The Split): We don’t just use the raw input. We multiply the input by dedicated weight matrices (`W_q`, `W_k`, `W_v`) for each head. This creates the specialized “subspaces” we saw in Part 1.
Independent Attention: Each head runs the standard Scaled Dot-Product Attention independently.
Concatenation: We take the output vectors from all 8 heads and glue them back together side-by-side.
Final Linear (The Mix): We pass this long concatenated vector through one last linear layer (`W_o`) to blend the insights from all the experts into a single unified vector. (A naive, loop-based sketch of these four steps follows this list.)
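Here is that naive, loop-over-heads sketch of the four steps. It is intentionally slow and only illustrative (one small `nn.Linear` triple per head stands in for the split projections), but it makes the pipeline explicit before the efficient reshape-based implementation in Part 4:

```python
import math
import torch
import torch.nn as nn

d_model, num_heads = 512, 8
d_head = d_model // num_heads
x = torch.randn(2, 10, d_model)  # [Batch, Seq_Len, D_Model]

# Step 1: one small (q, k, v) projection triple per head
heads = [
    {name: nn.Linear(d_model, d_head) for name in ("q", "k", "v")}
    for _ in range(num_heads)
]
W_o = nn.Linear(d_model, d_model)  # Step 4: the final "mix" layer

head_outputs = []
for h in heads:
    Q, K, V = h["q"](x), h["k"](x), h["v"](x)             # [2, 10, 64] each
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)  # Step 2: scaled dot-product
    weights = torch.softmax(scores, dim=-1)               # [2, 10, 10]
    head_outputs.append(weights @ V)                      # [2, 10, 64]

concat = torch.cat(head_outputs, dim=-1)  # Step 3: concatenate -> [2, 10, 512]
out = W_o(concat)                         # Step 4: mix -> [2, 10, 512]
print(out.shape)
```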
Let’s visualize this flow:
Part 3: Visualizing Multiple Perspectives¶
Let’s visualize how two different heads might analyze the same sentence.
Sentence: “The cat sat on the mat because it was soft.”
Head 1 focuses on the physical relationship (connecting “it” to “mat”).
Head 2 focuses on the actor (connecting “sat” to “cat”).
Notice how they highlight completely different parts of the matrix.
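If you want to produce this kind of heatmap yourself, a minimal sketch looks like the following. The `attn_weights` tensor here is just random data in the `[Batch, Heads, Seq, Seq]` shape derived in Part 4, standing in for the weights of a trained model:

```python
import torch
import matplotlib.pyplot as plt

tokens = ["The", "cat", "sat", "on", "the", "mat", "because", "it", "was", "soft"]
seq_len, num_heads = len(tokens), 8

# Stand-in attention weights: [Batch, Heads, Seq, Seq], each row sums to 1
attn_weights = torch.softmax(torch.randn(1, num_heads, seq_len, seq_len), dim=-1)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, head in zip(axes, (0, 1)):
    ax.imshow(attn_weights[0, head].numpy(), cmap="viridis")
    ax.set_title(f"Head {head + 1}")
    ax.set_xticks(range(seq_len)); ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(seq_len)); ax.set_yticklabels(tokens)
plt.tight_layout()
plt.show()
```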

Part 4: Implementation in PyTorch¶
Implementing this efficiently requires some tensor gymnastics. We don’t actually run a for loop over the 8 heads. That would be too slow.
Instead, we use a single large matrix multiply and then reshape (view/transpose) the tensor to create a “heads” dimension.
The shape transformation looks like this:
Input: `[Batch, Seq_Len, D_Model]`
Linear & Reshape: `[Batch, Seq_Len, Heads, D_Head]`
Transpose: `[Batch, Heads, Seq_Len, D_Head]`
By swapping axes 1 and 2, we group the “Heads” dimension with the “Batch” dimension. PyTorch then processes all heads in parallel as if they were just extra items in the batch.
Shape Transformation Table¶
Let’s trace the exact tensor shapes through a concrete example with batch=2, seq=10, d_model=512, heads=8:
| Operation | Shape | Description |
|---|---|---|
| Input `x` | `[2, 10, 512]` | Raw input: 2 sequences, each with 10 tokens, 512-dim embeddings |
| After `W_q(x)` | `[2, 10, 512]` | Linear projection (still flat) |
| After `.view(2, 10, 8, 64)` | `[2, 10, 8, 64]` | Reshape: split 512 dims into 8 heads × 64 dims each |
| After `.transpose(1, 2)` | `[2, 8, 10, 64]` | Swap seq and heads: now we have 8 “parallel attention mechanisms” |
| Attention computation | `[2, 8, 10, 64]` | Each head computes attention independently |
| After `.transpose(1, 2)` | `[2, 10, 8, 64]` | Swap back: prepare for concatenation |
| After `.contiguous().view(2, 10, 512)` | `[2, 10, 512]` | Flatten: merge 8 heads back into a single 512-dim vector |
| After `W_o(...)` | `[2, 10, 512]` | Final projection |
The key insight: dimensions 1 and 2 get swapped twice—once to parallelize the heads, and once to merge them back together.
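You can verify every row of the table with a quick shape trace; this is just random tensors and the sizes from the table, no training involved:

```python
import torch
import torch.nn as nn

batch, seq, d_model, heads = 2, 10, 512, 8
d_head = d_model // heads  # 64

x = torch.randn(batch, seq, d_model)   # [2, 10, 512]
W_q = nn.Linear(d_model, d_model)

q = W_q(x)                             # [2, 10, 512]  linear projection, still flat
q = q.view(batch, seq, heads, d_head)  # [2, 10, 8, 64] split into heads
q = q.transpose(1, 2)                  # [2, 8, 10, 64] heads become a batch-like dim
print(q.shape)

# ... attention runs here, shape stays [2, 8, 10, 64] ...

q = q.transpose(1, 2).contiguous()     # [2, 10, 8, 64] swap back
q = q.view(batch, seq, d_model)        # [2, 10, 512]   heads merged again
print(q.shape)
```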
```python
import math
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # We define 4 linear layers: Q, K, V projections and the final Output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # 1. Project and Split
        # We transform [Batch, Seq, Model] -> [Batch, Seq, Heads, d_k]
        # Then we transpose to [Batch, Heads, Seq, d_k] for matrix multiplication
        Q = self.W_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 2. Scaled Dot-Product Attention (re-using logic from L03)
        # Scores shape: [Batch, Heads, Seq, Seq]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)

        # Apply weights to Values
        # Shape: [Batch, Heads, Seq, d_k]
        attn_output = torch.matmul(attn_weights, V)

        # 3. Concatenate
        # Transpose back: [Batch, Seq, Heads, d_k]
        # Flatten: [Batch, Seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 4. Final Projection (The "Mix")
        return self.W_o(attn_output)
```
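A quick sanity check of the module (reusing the imports above; because this is self-attention, the same tensor is passed as the query, key, and value inputs):

```python
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # [Batch, Seq_Len, D_Model]

out = mha(x, x, x)           # self-attention: q = k = v = x
print(out.shape)             # torch.Size([2, 10, 512]) - same shape in, same shape out
```

Summary¶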
Multiple Heads: We split our embedding into smaller chunks to allow the model to focus on different linguistic features simultaneously.
Projection: We use learned linear layers (`W_q`, `W_k`, `W_v`) to project the input into these specialized subspaces.
Parallelism: We use tensor reshaping (`view` and `transpose`) to compute attention for all heads at once, rather than looping through them.
Next Up: L05 – Layer Norm & Residuals. We have built the engine (Attention), but if we stack 100 of these layers on top of each other, the gradients will vanish or explode. In L05, we will add the “plumbing” (Normalization and Skip Connections) that allows Deep Learning to actually get deep.