
L03 - Self-Attention: The Search Engine of Language

Building the “brain” of the Transformer: How words talk to each other.


In L02 - Embeddings, we turned words into vectors and gave them positions. But there is still a fatal flaw: each word gets the same vector every time, regardless of context.

Consider the word “Bank”.

  1. “The bank of the river.” (Nature)

  2. “The bank approved the loan.” (Finance)

In a static embedding layer (a simple lookup table), the vector for “bank” is identical in both sentences. The model sees the same token ID → same vector, regardless of whether “bank” appears near “river” or “loan”.
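We can see this directly in code. Here's a minimal sketch using a toy vocabulary (the words and IDs are made up for this demo):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy vocabulary and IDs, invented purely for this demo
vocab = {"the": 0, "bank": 1, "of": 2, "river": 3, "approved": 4, "loan": 5}
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=4)

nature = torch.tensor([vocab[w] for w in ["the", "bank", "of", "the", "river"]])
finance = torch.tensor([vocab[w] for w in ["the", "bank", "approved", "the", "loan"]])

bank_in_nature = embedding(nature)[1]    # vector for "bank" near "river"
bank_in_finance = embedding(finance)[1]  # vector for "bank" near "loan"

# Same token ID -> identical vector, regardless of context
print(torch.equal(bank_in_nature, bank_in_finance))  # True
```

The lookup table has no way to know what the neighboring tokens are; only a mechanism that mixes information across positions can fix this.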

[Figure: “bank (static)” plotted at the identical position in a river-context plot and a loan-context plot]

The Problem Visualized:

Notice the red square labeled “bank (static)” is at the EXACT SAME POSITION in both plots. Whether “bank” appears in a sentence about rivers or loans, the static embedding lookup table returns the identical vector.

What We Need (shown by dashed arrows):

Self-Attention is the mechanism that enables this context-dependent shift—it allows “bank” to dynamically adjust its representation based on surrounding words.

By the end of this post, you’ll understand:

  • Why a static embedding cannot disambiguate a word like “bank”

  • How the Query/Key/Value “soft database lookup” works

  • Why Q, K, and V are learned projections, not the raw embeddings

  • Why attention runs in parallel while RNNs cannot

  • The scaled dot-product formula, and how to implement it in PyTorch


Part 1: The Intuition

Solving the “Bank” Problem:

We’ve seen that static embeddings give “bank” the same vector whether it appears with “river” or “loan”. How does attention fix this? It allows each word to look at its neighbors and adjust its meaning based on what it finds.

The math of Attention can look scary, but the concept is simple. It is a Soft Database Lookup.

How Attention Works: The Search Process

Let’s trace how “bank” would use attention to shift its meaning in “The bank of the river”:

  1. “bank” generates its Query: “What context am I in?”

  2. It compares this Query against every other word’s Key:

    • “river” Key: “I’m a geographic feature” → High match!

    • “of” Key: “I’m a preposition” → Low match

    • “the” Key: “I’m an article” → Low match

  3. “bank” finds the highest match with “river”

  4. It extracts the Value from “river” (its semantic meaning) and adds it to its own representation

Now, the vector for “bank” is no longer just the static embedding; it is “bank + a lot of ‘river’ + a little bit of ‘the’ and ‘of’”. The representation has shifted toward the nature/geography meaning!
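The search process above can be sketched with a few hand-picked 2-D vectors (the numbers are illustrative, not learned):

```python
import torch
import torch.nn.functional as F

# Hand-picked 2-D vectors, purely for illustration (not learned values)
Q_bank = torch.tensor([3.0, 1.0])            # "bank" asks: what context am I in?
keys = {"the":   torch.tensor([0.5, 0.2]),
        "of":    torch.tensor([1.0, 0.3]),
        "river": torch.tensor([3.0, 1.2])}   # aligned with the query
values = {"the":   torch.tensor([0.1, 0.1]),
          "of":    torch.tensor([0.2, 0.1]),
          "river": torch.tensor([2.0, 1.5])} # carries the nature meaning

# Step 2: compare the Query against every Key
scores = torch.stack([torch.dot(Q_bank, k) for k in keys.values()])

# Steps 3-4: normalize scores into weights, then mix the Values
weights = F.softmax(scores, dim=0)
context = sum(w * v for w, v in zip(weights, values.values()))

for name, w in zip(keys, weights):
    print(f"{name:6s}: {w:.3f}")
print("new 'bank' context vector:", context)
```

With these vectors, “river” wins almost all the weight, so the context vector for “bank” is pulled strongly toward the “river” value.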

Here’s this process visualized. The diagram shows the complete attention computation: Q × K^T → Softmax → Weighted Sum of V. Notice how “bank” attends most strongly to “river” (0.50 weight), which disambiguates it toward the geographical meaning!

[Figure: the full attention computation for “The bank of the river”: Q × K^T → Softmax → Weighted Sum of V]

Key Insights from the Visualization:

  1. Q, K, V are Tables: Each has one row per token, with columns representing dimensions (typically 512 in real transformers, though we show just a few here for clarity).

  2. Query-Key Matching (Step 1): When “bank” (the query token) wants to understand its context, it compares its Query vector Q[“bank”] against ALL Key vectors. The comparison produces attention scores:

    • Q[“bank”] · K[“river”] = 0.50 (highest: disambiguates to geographical meaning!)

    • Q[“bank”] · K[“of”] = 0.30 (preposition - contextual glue)

    • Q[“bank”] · K[“The”] = 0.15 (less relevant determiner)

  3. Retrieve Values (Step 2): These scores become weights that determine how much of each Value vector to retrieve. The final output for “bank” is a weighted combination: primarily V[“river”] (50%), with contributions from V[“of”] (30%) and other tokens. This shifts “bank” toward its geographical meaning!

  4. Real Dimensions: While we show 3-4 dimensions for clarity, real transformers use 512 or more dimensions, allowing for much richer semantic relationships.


Part 2: Q, K, V are Learned Projections

In the visualization above, we showed Q, K, V as three separate tables. But where do these tables come from? A critical point that’s often misunderstood: Q, K, and V are not the raw embeddings. Each one is produced by multiplying the embeddings by its own learned weight matrix (W_Q, W_K, or W_V), giving the model three specialized “views” of the same input.

Let’s visualize how these projections work:

[Figure: the embedding matrix multiplied by W_Q, W_K, and W_V to produce the Q, K, and V tables]

Now let’s see this in code:

# REALITY: Q, K, V come from learned projections of the embeddings
import torch

# Input: embeddings for "bank", "river", "loan"
# Static embeddings (in reality, these would come from the embedding layer)
bank_embedding = torch.tensor([0.5, 0.3, 0.8, 0.2])
river_embedding = torch.tensor([0.2, 0.9, 0.1, 0.3])
loan_embedding = torch.tensor([0.8, 0.1, 0.4, 0.7])

embeddings = torch.stack([bank_embedding, river_embedding, loan_embedding])  # [3, 4]
print("Input embeddings [3 tokens, 4 dims]:")
print(embeddings)
print()

# Learned projection matrices (in real models, these are trained)
# For this demo, we'll create random matrices
torch.manual_seed(42)
d_model = 4  # embedding dimension
W_Q = torch.randn(d_model, d_model) * 0.5  # [4, 4]
W_K = torch.randn(d_model, d_model) * 0.5  # [4, 4]
W_V = torch.randn(d_model, d_model) * 0.5  # [4, 4]

# Project embeddings to Q, K, V
Q = embeddings @ W_Q  # [3, 4] @ [4, 4] = [3, 4]
K = embeddings @ W_K  # [3, 4]
V = embeddings @ W_V  # [3, 4]

print("After learned projections:")
print(f"  Q shape: {Q.shape}  (each embedding transformed by W_Q)")
print(f"  K shape: {K.shape}  (each embedding transformed by W_K)")
print(f"  V shape: {V.shape}  (each embedding transformed by W_V)")
print()
print("Q (queries):")
print(Q)
print()
print("Notice: Q is DIFFERENT from the input embeddings!")
print("The projection matrices mixed and transformed the original features.")
Input embeddings [3 tokens, 4 dims]:
tensor([[0.5000, 0.3000, 0.8000, 0.2000],
        [0.2000, 0.9000, 0.1000, 0.3000],
        [0.8000, 0.1000, 0.4000, 0.7000]])

After learned projections:
  Q shape: torch.Size([3, 4])  (each embedding transformed by W_Q)
  K shape: torch.Size([3, 4])  (each embedding transformed by W_K)
  V shape: torch.Size([3, 4])  (each embedding transformed by W_V)

Q (queries):
tensor([[ 0.2098,  0.7902, -0.0152, -1.2523],
        [ 0.3512, -0.4083, -0.0643, -0.8885],
        [ 0.3995,  0.6671,  0.0105, -0.9363]])

Notice: Q is DIFFERENT from the input embeddings!
The projection matrices mixed and transformed the original features.

Part 3: The Key Advantage - Parallelism

Everything Happens at Once

Remember from L02: “The attention mechanism is parallel. It looks at every word in a sentence at the exact same time.”

This is the breakthrough that makes Transformers faster AND better at understanding language than older architectures like RNNs (Recurrent Neural Networks).

How RNNs Process a Sentence (Sequential):

RNNs maintain a “hidden state”—a vector that accumulates information from all previous words. At each step, the hidden state combines the current word with everything seen so far.

Let’s trace a concrete example: pronoun resolution. When the model processes “it”, how does it figure out that “it” refers to “bank”? We’ll follow how information flows through the hidden states.

Input: "The bank approved the loan because it was well-capitalized"
        ↑    ↑                              ↑
      word 1  word 2                       word 7

Step 1: "The"      → hidden_state_1 = f(embedding("The"))
                      ↓ Contains: [info about "The"]

Step 2: "bank"     → hidden_state_2 = f(embedding("bank"), hidden_state_1)
                      ↓ Contains: [info about "The", "bank"] compressed into one vector

Step 3: "approved" → hidden_state_3 = f(embedding("approved"), hidden_state_2)
                      ↓ Contains: [info about "The", "bank", "approved"] compressed

...

Step 7: "it"       → hidden_state_7 = f(embedding("it"), hidden_state_6)
                      ↓ Contains: [ALL 7 words] compressed into a fixed-size vector

Problem: To understand what "it" refers to, information about "bank" (word 2)
must pass through a chain of compressions before reaching "it" (word 7):

  hidden_state_2 (contains "bank" info)
    ↓ compressed with "approved"
  hidden_state_3
    ↓ compressed with "the"
  hidden_state_4
    ↓ compressed with "loan"
  hidden_state_5
    ↓ compressed with "because"
  hidden_state_6 (used by "it" at step 7)

Information about "bank" has been compressed through 4 intermediate mixing steps.
The more steps between "bank" and "it", the more diluted the information becomes.
This is closely related to the "vanishing gradient" problem, where gradient signals become too small to effectively update earlier layers during backpropagation.

Total: 7 sequential steps (MUST run one-by-one)

The Hidden State Bottleneck:

The hidden state is accumulative (it tries to remember everything), but it achieves this through compression into a fixed-size vector (typically 512 or 1024 dimensions).

Think of it like this: After reading “The bank approved the loan because it”, the RNN must squeeze all understanding of these 7 words—their meanings, relationships, syntactic roles—into a single vector of fixed size. Then it must use this compressed summary to process “was” and “well-capitalized”.

It’s like trying to fit an entire Wikipedia article into a tweet, then using only that tweet to write the next paragraph. Some information inevitably gets lost or diluted.
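Here's a minimal sketch of that compression, using a plain tanh update as a stand-in for a real RNN cell (dimensions and weights are arbitrary demo values):

```python
import torch

torch.manual_seed(0)
d = 8                                   # the fixed hidden-state size
words = 7                               # "The bank approved the loan because it"
x = torch.randn(words, d)               # stand-in embeddings

# A plain tanh RNN cell: h_t = tanh(W_h @ h_{t-1} + W_x @ x_t)
W_h = torch.randn(d, d) * 0.3
W_x = torch.randn(d, d) * 0.3

h = torch.zeros(d)
for t in range(words):                  # must run strictly one step at a time
    h = torch.tanh(W_h @ h + W_x @ x[t])

# Whatever "it" learns about "bank" must survive inside these d numbers
print(h.shape)  # torch.Size([8])
```

No matter how long the sentence gets, the loop's output is the same fixed-size vector, and each step mixes in new content that dilutes the old.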

How Attention Processes the Same Sentence (Parallel):

Now let’s see how attention solves the same pronoun resolution task: when “it” needs to figure out what it refers to.

Instead of passing information through a chain of hidden states, attention allows direct connections between any two words.

Input: "The bank approved the loan because it was well-capitalized"

Single Step: ALL words computed simultaneously via matrix operations:

"it" (word 7) compares its Query directly against ALL Keys:
  Q("it") · K("The")      = 0.05  → Low attention weight
  Q("it") · K("bank")     = 0.82  → HIGH attention weight! ✓
  Q("it") · K("approved") = 0.08  → Low attention weight
  Q("it") · K("the")      = 0.02  → Low attention weight
  Q("it") · K("loan")     = 0.15  → Medium attention weight
  ...

Result: "it" can look DIRECTLY at "bank" (word 2) without any intermediate recurrent steps.
No repeated hidden-state bottleneck — just one parallel compare-and-mix step per layer.

Every other word does the same computation simultaneously:
  - "The" attends to all words
  - "bank" attends to all words
  - "approved" attends to all words
  ... (all computed in parallel)

Total: 1 parallel step (all comparisons at once, then weighted sum)

Why Direct Access Matters:

  1. No information loss: “it” doesn’t need to hope that information about “bank” survived 4 intermediate compressions (hidden states 3→4→5→6)

  2. Long-range dependencies: Works just as well for word 100 referring to word 1 as for adjacent words

  3. Symmetry: “bank” can attend to “loan” just as easily as “loan” attends to “bank”

  4. Multiple relationships: Each word can attend strongly to multiple other words simultaneously (through the weighted sum)

Concrete Example – Pronoun Resolution (soft, not magic):

Consider: “The artist gave the musician a score because she loved her composition.”

This sentence is genuinely ambiguous: humans can reasonably map “she/her” to either person depending on how they interpret “score/composition”. Attention handles this softly: rather than committing to a single antecedent, it spreads weight across both candidates, letting the surrounding context tip the balance.

Why This Works:

The attention mechanism achieves parallelism through matrix multiplication. When we compute QK^T (which you’ll see shortly), we’re not looping through words one-by-one. Instead:

  1. Every word generates its Query, Key, and Value simultaneously (one matrix operation)

  2. Every Query compares against every Key simultaneously (another matrix operation)

  3. Every word gets its context-aware representation simultaneously (final matrix operation)

Modern GPUs are optimized for matrix operations, so computing attention for 100 words in parallel is barely slower than computing it for 10 words. This is why Transformers can handle such long contexts efficiently.
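Those three matrix operations can be sketched for a whole sequence at once (random tensors stand in for real embeddings and learned weights):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 100, 64
X = torch.randn(seq_len, d_model)          # embeddings for 100 tokens

# Random stand-ins for the learned projection matrices
W_Q = torch.randn(d_model, d_model) * 0.1
W_K = torch.randn(d_model, d_model) * 0.1
W_V = torch.randn(d_model, d_model) * 0.1

# 1. Every word gets its Q, K, V at once
Q, K, V = X @ W_Q, X @ W_K, X @ W_V

# 2. Every Query compares against every Key at once
scores = Q @ K.T / (d_model ** 0.5)        # [100, 100] in one matmul

# 3. Every word gets its context-aware representation at once
out = F.softmax(scores, dim=-1) @ V        # [100, 64]

print(scores.shape, out.shape)
```

There is no Python loop over positions: all 100 × 100 comparisons happen inside a single matrix multiply.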

Now let’s see the math that makes this parallelism possible.


Part 4: The Math of Similarity

In Part 1, we saw the intuition behind Q/K/V. In Part 2, we learned these are learned projections. In Part 3, we explored parallelism. Now let’s dive into the math: How do we “compare” two vectors? How do we measure if a Query is similar to a Key?

The answer: the Dot Product.

The dot product is a mathematical operation that measures alignment between two vectors. If two vectors point in the same direction, the result is large and positive. If they point in opposite directions, it is negative. If they’re perpendicular, the result is zero.

This is how attention computes “relevance”: Q(“it”) · K(“bank”) gives us a score measuring how much “it” should attend to “bank”.

Visualizing the “Magnitude Problem”

However, there is a catch. The dot product captures both alignment and magnitude. Before we look at the formula, let’s visualize why this is dangerous for neural networks.

In the plot below, we compare a Query (Blue) against three different Keys.

First, let’s compute the scores numerically to see the problem:

# Computing dot products to measure similarity
import torch

Q = torch.tensor([3.0, 1.0])
K1_short = torch.tensor([1.5, 0.5])
K2_long = torch.tensor([4.5, 1.5])
K3_misaligned = torch.tensor([1.0, 4.0])

print("🎯 Dot Product Scores:")
print(f"  K1 (short, aligned):     {torch.dot(Q, K1_short):.1f}")
print(f"  K2 (long, aligned):      {torch.dot(Q, K2_long):.1f}")
print(f"  K3 (misaligned):         {torch.dot(Q, K3_misaligned):.1f}")
print(f"\n⚠️  K3 (misaligned) scores higher than K1 (aligned)!")
print(f"     This is the magnitude problem we need to fix with scaling.")
🎯 Dot Product Scores:
  K1 (short, aligned):     5.0
  K2 (long, aligned):      15.0
  K3 (misaligned):         7.0

⚠️  K3 (misaligned) scores higher than K1 (aligned)!
     This is the magnitude problem we need to fix with scaling.
[Figure: the Query (blue) compared against three Keys of different lengths and directions]

Why this is a problem: Look at the score for K3 (Misaligned). It scored 7.0. Now look at K1 (Short). It scored 5.0.

The “Misaligned” vector beat the “Perfectly Aligned” vector simply because it was longer. If we don’t fix this, our model will prioritize “loud” signals (large numbers) over “correct” signals (aligned meaning).

We mitigate this with Scaling: we divide the result by the square root of the dimension (√d_k), where d_k is the dimensionality of the Q and K vectors. This keeps the scores in a controlled range so large magnitudes don’t saturate the softmax (as the next example shows, scaling preserves the ratio between scores while shrinking their absolute size).

Let’s see scaling in action with a concrete example:

# Why scaling matters: A concrete example
import torch

Q = torch.tensor([3.0, 1.0])
K_aligned = torch.tensor([3.0, 1.0])
K_misaligned = torch.tensor([1.0, 4.0])

score_aligned = torch.dot(Q, K_aligned)
score_misaligned = torch.dot(Q, K_misaligned)

print("Without Scaling:")
print(f"  Aligned score:     {score_aligned:.1f}")
print(f"  Misaligned score:  {score_misaligned:.1f}")
print(f"  Ratio:             {score_aligned/score_misaligned:.2f}x\n")

# Apply scaling
d_k = 2  # Dimensionality of Q and K vectors (both are 2D in this example)
scaled_aligned = score_aligned / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
scaled_misaligned = score_misaligned / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

print("With Scaling (÷√2 ≈ ÷1.41):")
print(f"  Aligned score:     {scaled_aligned:.2f}")
print(f"  Misaligned score:  {scaled_misaligned:.2f}")
print(f"  Ratio:             {scaled_aligned/scaled_misaligned:.2f}x")
print(f"\n✓ The ratio stays the same, but values are controlled")
Without Scaling:
  Aligned score:     10.0
  Misaligned score:  7.0
  Ratio:             1.43x

With Scaling (÷√2 ≈ ÷1.41):
  Aligned score:     7.07
  Misaligned score:  4.95
  Ratio:             1.43x

✓ The ratio stays the same, but values are controlled

The Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Let’s break down the equation step-by-step:

  1. The Scores (QK^T): We multiply the Query of the current word by the Keys of all words. (The T superscript means “transpose”—we flip rows and columns of the K matrix so the dimensions align for multiplication.)

  2. The Scaling (√d_k): We shrink the scores by dividing by √d_k (where d_k is the dimension of the Q and K vectors) to prevent exploding values.

  3. Softmax (The Probability): We convert scores into probabilities that sum to 1.0.

  4. The Weighted Sum (V): We multiply the probabilities by the Values to get the final context vector.

Example Walkthrough: Crunching the Numbers

Let’s trace the math using vectors from the plot above. We’ll use the pronoun resolution example: when “it” (query) attends to “animal”, “street”, and “because” (keys). Note that we’re reusing the same Q=[3,1] vector from the magnitude visualization—now applying it to a concrete language example.

Recall that Q = [3, 1] is the Query for “it” (the same vector from the magnitude plot).

Inputs: the Keys are K_animal = [3, 1], K_street = [1, 4], and K_because = [1.5, 0.5], with corresponding Values V_animal = [2.0, 1.5], V_street = [0.5, 0.3], and V_because = [-0.5, 1.2].

Let’s compute all four steps in executable code:

import torch
import torch.nn.functional as F

# Inputs
Q = torch.tensor([3.0, 1.0])
K = {
    'animal': torch.tensor([3.0, 1.0]),
    'street': torch.tensor([1.0, 4.0]),
    'because': torch.tensor([1.5, 0.5])
}
V = {
    'animal': torch.tensor([2.0, 1.5]),
    'street': torch.tensor([0.5, 0.3]),
    'because': torch.tensor([-0.5, 1.2])
}

print("Step 1: Compute Dot Products (Raw Scores)")
scores = {name: torch.dot(Q, k).item() for name, k in K.items()}
for name, score in scores.items():
    print(f"  {name:8s}: {score:.1f}")

print("\nStep 2: Scale by √d_k")
d_k = 2  # Dimensionality of Q and K (both are 2D in this example)
scaled = {name: score / torch.sqrt(torch.tensor(d_k, dtype=torch.float32)).item() for name, score in scores.items()}
for name, score in scaled.items():
    print(f"  {name:8s}: {score:.2f}")

print("\nStep 3: Softmax (Convert to Probabilities)")
scaled_tensor = torch.tensor(list(scaled.values()))
weights_tensor = F.softmax(scaled_tensor, dim=0)

for name, weight in zip(K.keys(), weights_tensor):
    print(f"  {name:8s}: {weight:.2f} ({weight*100:.0f}%)")

print("\nStep 4: Weighted Sum (Combine Values)")
V_stacked = torch.stack([V[name] for name in K.keys()])
context = torch.sum(weights_tensor.unsqueeze(1) * V_stacked, dim=0)
print(f"  Final context vector: [{context[0]:.2f}, {context[1]:.2f}]")
print(f"\n✓ Context is dominated by 'animal' (87% weight)")
Step 1: Compute Dot Products (Raw Scores)
  animal  : 10.0
  street  : 7.0
  because : 5.0

Step 2: Scale by √d_k
  animal  : 7.07
  street  : 4.95
  because : 3.54

Step 3: Softmax (Convert to Probabilities)
  animal  : 0.87 (87%)
  street  : 0.10 (10%)
  because : 0.03 (3%)

Step 4: Weighted Sum (Combine Values)
  Final context vector: [1.78, 1.37]

✓ Context is dominated by 'animal' (87% weight)

Now let’s walk through each step in detail:

Step 1: The Dot Product (QK^T) - Computing Raw Scores: Q·K_animal = 10.0, Q·K_street = 7.0, and Q·K_because = 5.0.

Step 2: Scaling (√d_k): We divide each score by √2 ≈ 1.41, giving 7.07, 4.95, and 3.54.

Step 3: Softmax We exponentiate and normalize to get percentages using the Softmax formula:

$$P(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

$$\begin{aligned}
P_1 &= \frac{e^{7.07}}{e^{7.07} + e^{4.95} + e^{3.54}} \approx \frac{1176}{1176 + 141 + 34} \approx \frac{1176}{1351} \approx \mathbf{87\%} \quad \text{(animal)} \\
P_2 &= \frac{e^{4.95}}{e^{7.07} + e^{4.95} + e^{3.54}} \approx \frac{141}{1351} \approx \mathbf{10\%} \quad \text{(street)} \\
P_3 &= \frac{e^{3.54}}{e^{7.07} + e^{4.95} + e^{3.54}} \approx \frac{34}{1351} \approx \mathbf{3\%} \quad \text{(because)}
\end{aligned}$$

📘 Want to learn more about Softmax?

For a deeper dive into how softmax works, why we use exponentials, and numerical stability techniques, see our dedicated tutorial: Softmax: From Scores to Probabilities

Step 4: Weighted Sum (Combining Values) Now we multiply each attention weight by its corresponding Value vector and sum them:

$$\begin{aligned}
\text{Context} &= 0.87 \times V_{\text{animal}} + 0.10 \times V_{\text{street}} + 0.03 \times V_{\text{because}} \\
&= 0.87 \times [2.0, 1.5] + 0.10 \times [0.5, 0.3] + 0.03 \times [-0.5, 1.2] \\
&= [1.74, 1.31] + [0.05, 0.03] + [-0.01, 0.04] \\
&\approx [1.78, 1.37]
\end{aligned}$$

Notice how the mechanism successfully identified the aligned vector (“animal”) as the important one, giving it 87% of the attention weights (recall: weights are the post-softmax probabilities, while scores are the pre-softmax logits 7.07, 4.95, and 3.54). The final context vector [1.78, 1.37] is dominated by V_animal’s contribution, as we’ll see visualized below.


Part 5: Visualizing the Attention Map

We’ve explored attention through different lenses—from the “bank” disambiguation problem to the worked example with pronoun resolution above. Now let’s see the full attention pattern for our pronoun resolution example as a heatmap.

In trained models, attention patterns emerge that capture semantic relationships. The heatmap below shows a simplified example to illustrate what we might expect: brighter colors represent higher attention weights (post-softmax probabilities).

[Figure: attention heatmap for the pronoun-resolution sentence; brighter cells mark higher attention weights]

Before we implement the full attention class, let’s see how attention works with real tensors:

# Test attention mechanism step-by-step with PyTorch
import torch
import torch.nn.functional as F

# Create simple 3-token sequence with 4-dimensional embeddings
Q = torch.tensor([[3.0, 1.0, 0.0, 0.0],
                  [1.0, 4.0, 0.0, 0.0],
                  [2.0, 2.0, 0.0, 0.0]])  # [3 tokens, 4 dims]
K = Q.clone()  # For simplicity, keys = queries
V = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0],
                  [0.5, 0.5, 0.0, 0.0]])  # [3 tokens, 4 dims]

d_k = Q.size(-1)  # Dimensionality of Q, K, and V (all are 4D here)

# Step by step
scores = torch.matmul(Q, K.T) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
print("Scaled Scores (each token attending to all tokens):")
print(scores)
print()

weights = F.softmax(scores, dim=-1)
print("Attention Weights (after softmax):")
print(weights)
print(f"Row sums (should be 1.0): {weights.sum(dim=-1)}")
print()

output = torch.matmul(weights, V)
print("Output (context-aware representations):")
print(output)
Scaled Scores (each token attending to all tokens):
tensor([[5.0000, 3.5000, 4.0000],
        [3.5000, 8.5000, 5.0000],
        [4.0000, 5.0000, 4.0000]])

Attention Weights (after softmax):
tensor([[0.6285, 0.1402, 0.2312],
        [0.0065, 0.9644, 0.0291],
        [0.2119, 0.5761, 0.2119]])
Row sums (should be 1.0): tensor([1.0000, 1.0000, 1.0000])

Output (context-aware representations):
tensor([[0.7441, 0.2559, 0.0000, 0.0000],
        [0.0211, 0.9789, 0.0000, 0.0000],
        [0.3179, 0.6821, 0.0000, 0.0000]])

Part 6: Batching for GPU Efficiency

So far, we’ve worked with single sequences shaped [S, D] - one sentence at a time. In practice, we process multiple sequences simultaneously in a batch.

Why Batch?

GPUs are designed for parallel computation. Processing 32 sequences takes roughly the same time as processing 1 sequence! This is why deep learning libraries operate on batches: [B, S, D], where:

  • B = batch size (number of sequences processed together)

  • S = sequence length (number of tokens per sequence)

  • D = embedding dimension (size of each token’s vector)

Here’s what “token”, “sequence”, and “batch” mean:

[Figure: tokens grouped into sequences, and sequences stacked into a batch of shape [B, S, D]]

Now let’s see how batching works in our PyTorch implementation...
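As a quick preview, the attention arithmetic extends to a batch just by adding a leading dimension (random data; the sizes here are arbitrary demo values):

```python
import torch
import torch.nn.functional as F

B, S, D = 32, 10, 64                       # batch, sequence length, dims (demo sizes)
Q = torch.randn(B, S, D)
K = torch.randn(B, S, D)
V = torch.randn(B, S, D)

# matmul broadcasts over the batch dimension automatically:
# [B, S, D] @ [B, D, S] -> [B, S, S]
scores = torch.matmul(Q, K.transpose(-2, -1)) / (D ** 0.5)
weights = F.softmax(scores, dim=-1)        # each row still sums to 1.0
out = torch.matmul(weights, V)             # [B, S, D]

print(weights.shape, out.shape)
```

Nothing about the math changes; `torch.matmul` simply repeats the same [S, D] computation for all B sequences in parallel.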


Part 7: Implementation in PyTorch

We’ve seen the intuition (Q/K/V database lookup), learned projections, parallelism, the math (dot products and scaling), visualization (attention heatmaps), and batching. Now let’s see how remarkably simple the actual code is.

We can implement this entire mechanism in fewer than 20 lines of code.

The Attention Formula (reminder):

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Note the use of masked_fill, which we will use in L06 - The Causal Mask to prevent the model from “cheating” by looking at future words.

import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k  # Dimensionality of Q, K, and V vectors

    def forward(self, q, k, v, mask=None):
        # 1. Calculate the Dot Product (Scores)
        # q: [batch, seq, d_k]
        # k: [batch, seq, d_k]
        # v: [batch, seq, d_k]
        # scores shape: [batch, seq, seq]
        # Scale by sqrt(d_k) where d_k is the dimensionality of Q, K, and V
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 2. Apply Mask (Optional - vital for GPT!)
        if mask is not None:
            # We use a very large negative number so Softmax turns it to zero
            scores = scores.masked_fill(mask == 0, -1e9)

        # 3. Softmax to get probabilities (0.0 to 1.0)
        attn_weights = torch.softmax(scores, dim=-1)  # [batch, seq, seq]

        # 4. Multiply by Values to get the weighted context
        output = torch.matmul(attn_weights, v)  # [batch, seq, d_k]
        
        return output, attn_weights
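To see what that masked_fill line does, here's a standalone sketch applying a causal (lower-triangular) mask to random scores — a preview of what L06 covers:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq = 4
scores = torch.randn(seq, seq)                 # stand-in for QK^T / sqrt(d_k)

# Causal mask: position i may attend only to positions <= i
mask = torch.tril(torch.ones(seq, seq))

masked = scores.masked_fill(mask == 0, -1e9)   # future positions get -1e9
weights = F.softmax(masked, dim=-1)            # softmax turns -1e9 into ~0

print(weights)
```

Every row still sums to 1.0, but the upper triangle is zero: no token can attend to a token that comes after it.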

Using the Attention Layer

Now let’s see how to use this class in practice. We need to create the projection layers and show the complete flow from embeddings to attention output:

import torch
import torch.nn as nn

# Example: Process a batch of 2 sequences, each with 10 tokens
# These values are intentionally small for demonstration purposes
batch_size = 2     # Small for demo (production: 8-512 depending on GPU memory)
seq_len = 10       # Short for demo (production: 128-2048+ depending on model)
d_model = 512      # Standard embedding dimension (used in BERT-base, GPT-2)
d_k = 64           # Dimensionality of Q, K, and V (= d_model / num_heads in multi-head)

# In production, hyperparameters are chosen based on:
# - batch_size: GPU memory constraints (larger = faster training but needs more VRAM)
# - seq_len: Task requirements (longer = more context but quadratic memory cost)
# - d_model: Model capacity (512 for base models, 768-1024 for large, 2048+ for XL)
# - d_k: Usually d_model / num_heads (e.g., 512 / 8 = 64 for 8-head attention)

# Step 1: Start with embeddings (normally from an embedding layer)
# Shape: [batch, seq, d_model]
embeddings = torch.randn(batch_size, seq_len, d_model)

# Step 2: Create the projection layers (learned during training)
# These transform embeddings into Q, K, V
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_k, bias=False)

# Step 3: Project embeddings to create Q, K, V
# This is the X × W_Q, X × W_K, X × W_V we discussed earlier
# All three have the same dimensionality d_k
q = W_q(embeddings)  # [batch, seq, d_k]
k = W_k(embeddings)  # [batch, seq, d_k]
v = W_v(embeddings)  # [batch, seq, d_k]

# Step 4: Run attention
attention = ScaledDotProductAttention(d_k=d_k)
output, attn_weights = attention(q, k, v)

print("=" * 70)
print("ATTENTION COMPUTATION")
print("=" * 70)
print("Formula: Attention(Q,K,V) = softmax(QK^T/√d_k) V")
print()
print(f"Input embeddings shape: {embeddings.shape}")  # [2, 10, 512]
print(f"Q, K, V shapes: {q.shape}")                   # [2, 10, 64]
print()
print(f"Attention weights shape: {attn_weights.shape}") # [2, 10, 10] = [B, S, S]
print("  ↳ Result of: softmax(QK^T/√d_k)")
print("  ↳ Shape: [batch, query_pos, key_pos]")
print("  ↳ Each query attends to all keys (sums to 1.0)")
print()
print("  Example: For batch 0, position 3:")
print("    attn_weights[0, 3, 0] = how much pos 3 attends to pos 0")
print("    attn_weights[0, 3, 1] = how much pos 3 attends to pos 1")
print("    ...")
print("    attn_weights[0, 3, 9] = how much pos 3 attends to pos 9")
print()
print(f"Attention output shape: {output.shape}")       # [2, 10, 64]
print("  ↳ Result of: weights @ V")
print("  ↳ Final context-aware token representations")
======================================================================
ATTENTION COMPUTATION
======================================================================
Formula: Attention(Q,K,V) = softmax(QK^T/√d_k) V

Input embeddings shape: torch.Size([2, 10, 512])
Q, K, V shapes: torch.Size([2, 10, 64])

Attention weights shape: torch.Size([2, 10, 10])
  ↳ Result of: softmax(QK^T/√d_k)
  ↳ Shape: [batch, query_pos, key_pos]
  ↳ Each query attends to all keys (sums to 1.0)

  Example: For batch 0, position 3:
    attn_weights[0, 3, 0] = how much pos 3 attends to pos 0
    attn_weights[0, 3, 1] = how much pos 3 attends to pos 1
    ...
    attn_weights[0, 3, 9] = how much pos 3 attends to pos 9

Attention output shape: torch.Size([2, 10, 64])
  ↳ Result of: weights @ V
  ↳ Final context-aware token representations

Key Points:

  1. Embeddings (512D) are the starting point from L02

  2. Projection layers (W_q, W_k, W_v) transform embeddings into smaller Q, K, V vectors (64D in this example)

  3. Attention operates on these projected vectors

  4. The attention weights show how much each token attends to every other token


Summary

  1. Context Matters: Standard embeddings are static. Attention makes them dynamic by allowing words to look at each other.

  2. Q, K, V are Learned Projections (NOT raw inputs!):

    • Critical misconception to avoid: Q, K, V are NOT the input embeddings

    • They are computed via learned weight matrices: Q = X·W^Q, K = X·W^K, V = X·W^V

    • These projections transform the input into three specialized “views”:

      • Q (Query): What I’m searching for

      • K (Key): What I advertise about myself

      • V (Value): Content to extract and pass along

    • The projection matrices (W^Q, W^K, W^V) are learned during training

  3. The Attention Formula:

    $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
    • Compute similarity scores via dot product (QK^T)

    • Scale by √d_k to prevent gradient instability

    • Softmax converts scores to probabilities (attention weights)

    • Weighted sum of values produces context-aware representations

  4. Parallelism: Unlike RNNs, attention processes all positions simultaneously via matrix operations, enabling both speed and long-range dependencies.

Next Up: L04 – Multi-Head Attention. One attention head is good, but it can only focus on one relationship at a time (e.g., “it” → “animal”). What if we also need to know that “animal” is the subject of the sentence? We need more heads!


Appendix A: Geometric View of Attention

This appendix provides a visual, step-by-step exploration of the attention mechanism using the same example from Part 4. If you prefer mathematical explanations, you can skip this section—it’s an alternative perspective on concepts already covered.

In the worked example in Part 4, we calculated that Q=[3, 1] attending to K_animal=[3, 1], K_street=[1, 4], and K_because=[1.5, 0.5] yields attention weights of 87%, 10%, and 3%. Let’s visualize this step-by-step using those exact vectors to see how “it” computes its final context vector.

We’ll break attention into its four steps:

  1. Similarity: Compute dot products Q·K (scores)

  2. Scaling: Divide by √d_k

  3. Softmax: Convert to probabilities (weights)

  4. Weighted Sum: Combine values using those weights

Step 1: Similarity in Q-K Space (Computing Scores)

[Figure: Q and the three Keys in Q-K space; circle size reflects each raw dot-product score]

What this shows: The query “it” Q=[3, 1] (blue dashed arrow) compares against each key using the dot product. Notice that K_animal=[3, 1] perfectly aligns with Q (score=10), while K_street=[1, 4] points in a different direction (score=7). Circle size reflects the raw dot product scores—before any normalization.

Steps 2-3: Scaling and Softmax (Scores → Weights)

[Figure: scores divided by √d_k, then converted into probabilities by softmax]

What this shows: We divide each score by √d_k=√2≈1.41 to get “logits”, then apply softmax to convert them into probabilities that sum to 1.0. The result: “animal” gets 87% of the attention (from score=10), “street” gets 10% (from score=7), and “because” gets 3% (from score=5). These are the attention weights.

Step 4: Copy from Values (Weighted Sum in V-Space)

[Figure: weighted sum of Values in V-space; the thick line to V_animal marks its dominant contribution]

What this shows: Values live in a different space than keys. Using the attention weights from Step 3, we compute a weighted average: context = 0.87×V_animal + 0.10×V_street + 0.03×V_because ≈ [1.78, 1.37]. The thick line to V_animal shows it dominates the contribution. This final context vector becomes the new representation for “it”—enriched with semantic content from “animal”.

Key Insight: Attention is a weighted average in value space, where the weights come from measuring similarity in query-key space. This is fundamentally different from a traditional database lookup: a database returns exactly the one record whose key matches, while attention returns a blend of every value, weighted by how well each key matches.

This is why we need separate Q, K, V projections: K and V are different learned transformations of the same input. Keys determine how much to attend (semantic matching via geometric alignment), while Values determine what information to extract (the content to mix into the output). K[“river”] might encode “geographic term, noun, concrete” (optimized for matching), while V[“river”] encodes “flowing water, nature, geography” (semantic content to contribute).
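The hard-vs-soft contrast can be checked in code, reusing the numbers from Part 4:

```python
import torch
import torch.nn.functional as F

Q = torch.tensor([3.0, 1.0])                       # query for "it"
K = torch.tensor([[3.0, 1.0],                      # animal
                  [1.0, 4.0],                      # street
                  [1.5, 0.5]])                     # because
V = torch.tensor([[2.0, 1.5],
                  [0.5, 0.3],
                  [-0.5, 1.2]])

scores = K @ Q                                     # [10.0, 7.0, 5.0]
weights = F.softmax(scores / (2 ** 0.5), dim=0)    # ~ [0.87, 0.10, 0.03]

hard = V[scores.argmax()]                          # database-style: one record
soft = weights @ V                                 # attention: a weighted blend

print("hard lookup:", hard)                        # exactly V_animal
print("soft lookup:", soft)                        # ~ [1.78, 1.37]
```

The hard lookup discards everything except the best match; the soft lookup keeps a small contribution from every token, which is what lets gradients flow to all of them during training.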