
L03 - Self-Attention: The Search Engine of Language

Building the “brain” of the Transformer: How words talk to each other.


In L02 - Embeddings, we turned words into vectors and gave them positions. But there is still a fatal flaw: each word gets the same vector every time, regardless of context.

Consider the word “Bank”.

  1. “The bank of the river.” (Nature)

  2. “The bank approved the loan.” (Finance)

In a static embedding layer (a simple lookup table), the vector for “bank” is identical in both sentences. The model sees the same token ID → same vector, regardless of whether “bank” appears near “river” or “loan”.

[Figure: the static embedding of “bank” plotted for both sentences, with dashed arrows showing the desired context-dependent shift.]

The Problem Visualized:

Notice the red square labeled “bank (static)” is at the EXACT SAME POSITION in both plots. Whether “bank” appears in a sentence about rivers or loans, the static embedding lookup table returns the identical vector.

What We Need (shown by dashed arrows):

The vector for “bank” should shift toward “river” in the first sentence and toward “loan” in the second. Self-Attention is the mechanism that enables this context-dependent shift: it allows “bank” to dynamically adjust its representation based on its surrounding words.

By the end of this post, you’ll understand:

  1. Why static embeddings fail to capture context, and how attention fixes this

  2. The Query/Key/Value intuition behind attention (the “filing cabinet”)

  3. The scaled dot-product formula: scores, scaling by √d_k, softmax, and the weighted sum

  4. How to read an attention map and implement the mechanism in PyTorch


Part 1: The Intuition (The Filing Cabinet)

Solving the “Bank” Problem:

We’ve seen that static embeddings give “bank” the same vector whether it appears with “river” or “loan”. How does attention fix this? It allows each word to look at its neighbors and adjust its meaning based on what it finds.

The math of Attention can look scary, but the concept is simple. It is a Soft Database Lookup.

Imagine every word in the sentence is a folder in a filing cabinet. To facilitate a search, every word produces three vectors:

| Vector | Name | Role | Analogy |
| --- | --- | --- | --- |
| Q | Query | “What am I looking for?” | A sticky note I hold up: “I am looking for adjectives describing me.” |
| K | Key | “What do I contain?” | The label on the folder: “I am an adjective.” |
| V | Value | The content | The actual document inside the folder: “Blue.” |
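In a real Transformer, these three vectors are produced by learned linear projections of each word’s embedding. A minimal sketch (the sizes and layer names here are arbitrary, not from any particular library):

```python
import torch
import torch.nn as nn

d_model, d_k = 16, 16                        # toy sizes for illustration
to_q = nn.Linear(d_model, d_k, bias=False)   # produces the Query ("what am I looking for?")
to_k = nn.Linear(d_model, d_k, bias=False)   # produces the Key   ("what do I contain?")
to_v = nn.Linear(d_model, d_k, bias=False)   # produces the Value (the content itself)

x = torch.randn(d_model)                     # embedding of one word, e.g. "bank"
q, k, v = to_q(x), to_k(x), to_v(x)          # that word's Query, Key, and Value vectors
```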

How Attention Works: The Search Process

Let’s trace how “bank” would use attention to shift its meaning in “The bank of the river”:

  1. “bank” generates its Query: “What context am I in?”

  2. It compares this Query against every other word’s Key:

    • “river” Key: “I’m a geographic feature” → High match!

    • “of” Key: “I’m a preposition” → Low match

    • “the” Key: “I’m an article” → Low match

  3. “bank” finds the highest match with “river”

  4. It extracts the Value from “river” (its semantic meaning) and adds it to its own representation

Now, the vector for “bank” is no longer just the static embedding; it is “bank + a lot of ‘river’ + a little bit of ‘the’ and ‘of’”. The representation has shifted toward the nature/geography meaning!
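To make that mixing concrete, here is a tiny sketch of the final weighted sum. The attention weights and Value vectors are made up for illustration; a trained model would compute them itself:

```python
import torch

# Toy Value vectors (2-D for readability); in a real model these come from
# a learned projection of each word's embedding.
V = {
    "the":   torch.tensor([0.10, 0.00]),
    "bank":  torch.tensor([0.50, 0.50]),
    "of":    torch.tensor([0.00, 0.10]),
    "river": torch.tensor([0.90, 0.20]),
}

# Hypothetical attention weights for the word "bank":
# a lot of "river", a little of everything else.
weights = {"the": 0.05, "bank": 0.20, "of": 0.05, "river": 0.70}

# The new representation of "bank" is the weighted sum of all Values.
bank_contextual = sum(w * V[word] for word, w in weights.items())
print(bank_contextual)  # pulled strongly toward V["river"]
```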

The Key Advantage: Everything Happens at Once

Remember from L02: “The attention mechanism is parallel. It looks at every word in a sentence at the exact same time.”

This is the breakthrough that makes Transformers faster AND better at understanding language than older architectures like RNNs (Recurrent Neural Networks).

How RNNs Process a Sentence (Sequential):

RNNs maintain a “hidden state”—a vector that accumulates information from all previous words. At each step, the hidden state combines the current word with everything seen so far.

Let’s trace a concrete example: pronoun resolution. When the model processes “it”, how does it figure out that “it” refers to “bank”? We’ll follow how information flows through the hidden states.

Input: "The bank approved the loan because it was well-capitalized"
        ↑    ↑                              ↑
      word 1  word 2                       word 7

Step 1: "The"      → hidden_state_1 = f(embedding("The"))
                      ↓ Contains: [info about "The"]

Step 2: "bank"     → hidden_state_2 = f(embedding("bank"), hidden_state_1)
                      ↓ Contains: [info about "The", "bank"] compressed into one vector

Step 3: "approved" → hidden_state_3 = f(embedding("approved"), hidden_state_2)
                      ↓ Contains: [info about "The", "bank", "approved"] compressed

...

Step 7: "it"       → hidden_state_7 = f(embedding("it"), hidden_state_6)
                      ↓ Contains: [ALL 7 words] compressed into a fixed-size vector

Problem: To understand what "it" refers to, information about "bank" (word 2)
must pass through a chain of compressions before reaching "it" (word 7):

  hidden_state_2 (contains "bank" info)
    ↓ compressed with "approved"
  hidden_state_3
    ↓ compressed with "the"
  hidden_state_4
    ↓ compressed with "loan"
  hidden_state_5
    ↓ compressed with "because"
  hidden_state_6 (used by "it" at step 7)

Information about "bank" has been compressed through 4 intermediate mixing steps.
The more steps between "bank" and "it", the more diluted the information becomes.
(The same long chain is why RNNs suffer from the "vanishing gradient" problem during training.)

Total: 7 sequential steps just to reach "it" (and they MUST run one-by-one)

The Hidden State Bottleneck:

The hidden state is accumulative (it tries to remember everything), but it achieves this through compression into a fixed-size vector (typically 512 or 1024 dimensions).

Think of it like this: After reading “The bank approved the loan because it”, the RNN must squeeze all understanding of these 7 words—their meanings, relationships, syntactic roles—into a single vector of fixed size. Then it must use this compressed summary to process “was” and “well-capitalized”.

It’s like trying to fit an entire Wikipedia article into a tweet, then using only that tweet to write the next paragraph. Some information inevitably gets lost or diluted.
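For readers who prefer code, here is a minimal sketch of that sequential loop using a single PyTorch RNN cell (toy sizes, random inputs). The point is simply that step t cannot begin until step t-1 has produced its hidden state:

```python
import torch
import torch.nn as nn

d_model, d_hidden = 8, 8              # tiny sizes for illustration
rnn_cell = nn.RNNCell(d_model, d_hidden)

# Pretend these are the embeddings of the 7 words
# "The bank approved the loan because it".
tokens = torch.randn(7, d_model)

h = torch.zeros(1, d_hidden)          # hidden_state_0
for t, x_t in enumerate(tokens, start=1):
    # hidden_state_t = f(embedding_t, hidden_state_{t-1})
    # -- each step must wait for the previous one.
    h = rnn_cell(x_t.unsqueeze(0), h)
    print(f"step {t}: one fixed-size vector now summarizes the first {t} words")
```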

How Attention Processes the Same Sentence (Parallel):

Now let’s see how attention solves the same pronoun resolution task: when “it” needs to figure out what it refers to.

Instead of passing information through a chain of hidden states, attention allows direct connections between any two words.

Input: "The bank approved the loan because it was well-capitalized"

Single Step: ALL words computed simultaneously via matrix operations:

"it" (word 7) compares its Query directly against ALL Keys:
  Q("it") · K("The")      = 0.05  → Low attention weight
  Q("it") · K("bank")     = 0.82  → HIGH attention weight! ✓
  Q("it") · K("approved") = 0.08  → Low attention weight
  Q("it") · K("the")      = 0.02  → Low attention weight
  Q("it") · K("loan")     = 0.15  → Medium attention weight
  ...

Result: "it" can look DIRECTLY at "bank" (word 2) without any intermediate steps.
No information loss. No compression. Direct access.

Every other word does the same computation simultaneously:
  - "The" attends to all words
  - "bank" attends to all words
  - "approved" attends to all words
  ... (all computed in parallel)

Total: 1 parallel step (all comparisons at once, then weighted sum)

Why Direct Access Matters:

  1. No information loss: “it” doesn’t need to hope that information about “bank” survived 4 intermediate compressions (hidden states 3→4→5→6)

  2. Long-range dependencies: Works just as well for word 100 referring to word 1 as for adjacent words

  3. Symmetry: “bank” can attend to “loan” just as easily as “loan” attends to “bank”

  4. Multiple relationships: Each word can attend strongly to multiple other words simultaneously (through the weighted sum)

Concrete Example - Pronoun Resolution:

Consider: “The artist gave the musician a score because she loved her composition.” To resolve “she” and “her”, the model must weigh two candidate antecedents (“the artist” and “the musician”); attention lets each pronoun compare its Query directly against both candidates’ Keys, no matter how far back they appear.

Why This Works:

The attention mechanism achieves parallelism through matrix multiplication. When we compute $QK^T$ (which you’ll see shortly), we’re not looping through words one-by-one. Instead:

  1. Every word generates its Query, Key, and Value simultaneously (one matrix operation)

  2. Every Query compares against every Key simultaneously (another matrix operation)

  3. Every word gets its context-aware representation simultaneously (final matrix operation)

Modern GPUs are optimized for matrix operations, so computing attention for 100 words in parallel is barely slower than computing it for 10 words. This is why Transformers can handle such long contexts efficiently.
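Here is a rough sketch of those three matrix operations for a whole sequence at once. The projection matrices are random stand-ins for the learned weights:

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_k = 10, 16, 16     # toy sizes

X = torch.randn(seq_len, d_model)      # one embedding per word

# Learned projection matrices (random here, for illustration only)
W_Q, W_K, W_V = (torch.randn(d_model, d_k) for _ in range(3))

# 1. Every word gets its Query, Key, and Value simultaneously
Q, K, V = X @ W_Q, X @ W_K, X @ W_V    # each: [seq_len, d_k]

# 2. Every Query is compared against every Key simultaneously
scores = Q @ K.T / d_k ** 0.5          # [seq_len, seq_len]

# 3. Every word gets its context-aware representation simultaneously
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
context = weights @ V                    # [seq_len, d_k]
```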

Now let’s see the math that makes this parallelism possible.


Part 2: The Math of Similarity

In Part 1, we said that each word’s Query compares against other words’ Keys to find matches. But HOW do we “compare” two vectors? How do we measure if a Query is similar to a Key?

The answer: the Dot Product.

The dot product is a mathematical operation that measures alignment between two vectors. If two vectors point in the same direction, the result is large and positive. If they point in opposite directions, it is negative. If they’re perpendicular, the result is zero.

This is how attention computes “relevance”: Q(“it”) · K(“bank”) gives us a score measuring how much “it” should attend to “bank”.
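A quick numerical check of those three cases (the vectors are arbitrary examples chosen to be aligned, opposite, and perpendicular):

```python
import torch

q = torch.tensor([2.0, 1.0])

print(torch.dot(q, torch.tensor([ 4.0,  2.0])))  # same direction  -> 10.0 (large, positive)
print(torch.dot(q, torch.tensor([-2.0, -1.0])))  # opposite        -> -5.0 (negative)
print(torch.dot(q, torch.tensor([-1.0,  2.0])))  # perpendicular   ->  0.0
```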

Visualizing the “Magnitude Problem”

However, there is a catch. The dot product captures both alignment and magnitude. Before we look at the formula, let’s visualize why this is dangerous for neural networks.

In the plot below, we compare a Query (Blue) against three different Keys.

[Figure: the Query vector Q (blue) compared against three Keys of different lengths and directions, with their dot-product scores.]

Why this is a problem: Look at the score for K3 (Misaligned): it scored 7.0. Now look at K1 (Short, but perfectly aligned with the Query): it scored only 5.0.

The misaligned vector beat the perfectly aligned one simply because it was longer. If we don’t fix this, our model will prioritize “loud” signals (large magnitudes) over “correct” signals (aligned meaning).

We fix this by Scaling: we divide the result by the square root of the dimension ($\sqrt{d_k}$). This normalizes the scores so the model focuses on alignment, not magnitude.
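The scaling also keeps the softmax (introduced below) from saturating when $d_k$ is large: raw dot products grow with the dimension, and without scaling the softmax collapses toward a near one-hot distribution. A small sketch with random vectors (exact numbers will vary):

```python
import torch

torch.manual_seed(0)

for d_k in (4, 64, 1024):
    q = torch.randn(d_k)
    keys = torch.randn(8, d_k)              # 8 random keys
    scores = keys @ q                       # raw dot products: spread grows with d_k
    scaled = scores / d_k ** 0.5            # scaled scores stay in a similar range
    print(d_k,
          round(torch.softmax(scores, dim=0).max().item(), 3),  # drifts toward 1.0
          round(torch.softmax(scaled, dim=0).max().item(), 3))  # stays softer
```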

The Formula

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

Let’s break down the equation step-by-step:

  1. The Scores ($QK^T$): We multiply the Query of the current word by the Keys of all words.

  2. The Scaling ($\sqrt{d_k}$): We shrink the scores to prevent exploding values.

  3. Softmax (The Probability): We convert scores into probabilities that sum to 1.0.

  4. The Weighted Sum (VV): We multiply the probabilities by the Values to get the final context vector.

Example Walkthrough: Crunching the Numbers

Let’s trace the math using vectors from the plot above. We’ll use the pronoun resolution example: when “it” (query) attends to “animal”, “street”, and “because” (keys).

Step 1: The Dot Product ($QK^T$) - Computing Raw Scores

Using Q(“it”) = [3, 1] against K_animal = [3, 1], K_street = [1, 4], and K_because = [1.5, 0.5], the raw scores are 3·3 + 1·1 = 10, 3·1 + 1·4 = 7, and 3·1.5 + 1·0.5 = 5.

Step 2: Scaling ($\sqrt{d_k}$)

Since our toy vectors have $d_k = 2$ dimensions, we divide each score by $\sqrt{2} \approx 1.41$, giving roughly 7.07, 4.95, and 3.54.

Step 3: Softmax

We exponentiate and normalize to get percentages using the Softmax formula:

$$
P(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

Notice how the mechanism successfully identified the aligned vector (“animal”) as the important one, giving it 87% of the attention! This matches what we’ll see in the geometric visualization below.
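The same arithmetic as a quick sanity check in PyTorch, using the vectors above:

```python
import torch

q = torch.tensor([3.0, 1.0])                  # Query for "it"
keys = torch.tensor([[3.0, 1.0],              # K_animal
                     [1.0, 4.0],              # K_street
                     [1.5, 0.5]])             # K_because

scores = keys @ q                             # tensor([10., 7., 5.])
logits = scores / keys.shape[-1] ** 0.5       # divide by sqrt(d_k) = sqrt(2)
weights = torch.softmax(logits, dim=0)
print(weights)                                # ~ [0.87, 0.10, 0.03]
```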

Geometric View: The Four Steps of Attention

In the worked example above, we calculated that Q=[3, 1] attending to K_animal=[3, 1], K_street=[1, 4], and K_because=[1.5, 0.5] yields attention weights of 87%, 10%, and 3%. Let’s visualize this step-by-step using those exact vectors to see how “it” computes its final context vector.

We’ll break attention into its four steps:

  1. Similarity: Compute dot products Q·K (scores)

  2. Scaling: Divide by √d_k

  3. Softmax: Convert to probabilities (weights)

  4. Weighted Sum: Combine values using those weights

Step 1: Similarity in Q-K Space (Computing Scores)

[Figure: Q(“it”) and the three Keys in query-key space; arrow thickness reflects the raw dot-product scores.]

What this shows: The query “it” Q=[3, 1] (blue dashed arrow) compares against each key using the dot product. Notice that K_animal=[3, 1] perfectly aligns with Q (score=10), while K_street=[1, 4] points in a different direction (score=7). Arrow thickness reflects the raw dot product scores—before any normalization.

Steps 2-3: Scaling and Softmax (Scores → Weights)

[Figure: raw scores → scaled logits → softmax weights for “animal”, “street”, and “because”.]

What this shows: We divide each score by √d_k=√2≈1.41 to get “logits”, then apply softmax to convert them into probabilities that sum to 1.0. The result: “animal” gets 87% of the attention (from score=10), “street” gets 10% (from score=7), and “because” gets 3% (from score=5). These are the attention weights.

Step 4: Copy from Values (Weighted Sum in V-Space)

[Figure: the three Value vectors and the resulting weighted-sum context vector in value space.]

What this shows: Values live in a different space than keys. Using the attention weights from Step 3, we compute a weighted average: context = 0.87×V_animal + 0.10×V_street + 0.03×V_because ≈ [1.79, 1.44]. The thick line to V_animal shows it dominates the contribution. This final context vector becomes the new representation for “it”—enriched with semantic content from “animal”.

Key Insight: Attention is a weighted average in value space, where the weights come from measuring similarity in query-key space. This is why we need separate Q, K, V projections—keys determine how much to attend (pronoun resolution via geometric alignment), but values determine what information to extract (semantic content).


Part 3: Visualizing the Attention Map

We’ve explored attention through different lenses—from the “bank” disambiguation problem to pronoun resolution with the geometric view above. Now let’s see the full attention pattern for our pronoun resolution example as a heatmap.

In trained models, attention patterns emerge that capture semantic relationships. The heatmap below shows a simplified example to illustrate what we might expect: brighter colors represent higher attention weights (post-softmax probabilities).

[Figure: attention heatmap for the pronoun-resolution sentence; brighter cells indicate higher attention weights.]

Part 4: Implementation in PyTorch

We’ve seen the intuition (Q/K/V filing cabinet), the math (dot products and scaling), and the visualization (attention heatmaps). Now let’s see how remarkably simple the actual code is.

We can implement this entire mechanism in fewer than 20 lines of code.

Note the use of masked_fill, which we will use in L06 - The Causal Mask to prevent the model from “cheating” by looking at future words.

import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, q, k, v, mask=None):
        # 1. Calculate the Dot Product (Scores)
        # q: [batch, heads, seq, d_k]
        # k.transpose: [batch, heads, d_k, seq] -> We flip the last two dimensions
        # scores shape: [batch, heads, seq, seq]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # 2. Apply Mask (Optional - vital for GPT!)
        if mask is not None:
            # We use a very large negative number so Softmax turns it to zero
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 3. Softmax to get probabilities (0.0 to 1.0)
        attn_weights = torch.softmax(scores, dim=-1)
        
        # 4. Multiply by Values to get the weighted context
        output = torch.matmul(attn_weights, v)
        
        return output, attn_weights
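A quick smoke test of the module above with random tensors (the learned projections that produce q, k, and v from the embeddings are not shown here):

```python
batch, heads, seq, d_k = 2, 4, 10, 16
q = torch.randn(batch, heads, seq, d_k)
k = torch.randn(batch, heads, seq, d_k)
v = torch.randn(batch, heads, seq, d_k)

attn = ScaledDotProductAttention(d_k)
output, weights = attn(q, k, v)

print(output.shape)               # torch.Size([2, 4, 10, 16])
print(weights.shape)              # torch.Size([2, 4, 10, 10])
print(weights.sum(dim=-1)[0, 0])  # every row of attention weights sums to 1.0
```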


Summary

  1. Context Matters: Standard embeddings are static. Attention makes them dynamic.

  2. Q, K, V: We project our input into “Queries” (Searches), “Keys” (Labels), and “Values” (Content).

  3. Scaling: We divide by $\sqrt{d_k}$ to stop the gradients from vanishing when vectors get large.

Next Up: L04 – Multi-Head Attention. One attention head is good, but it can only focus on one relationship at a time (e.g., “it” → “animal”). What if we also need to know that “animal” is the subject of the sentence? We need more heads!