# Attention Is All You Need: A Walkthrough

Breaking down the core ideas behind the transformer architecture — self-attention, positional encoding, and multi-head attention — with equations and implementation snippets.


## Self-Attention Mechanism

The key innovation of transformers is the self-attention mechanism. Given an input sequence, self-attention computes a weighted sum of all positions, where the weights are determined by the compatibility of each pair of positions.

For an input matrix $X \in \mathbb{R}^{n \times d}$, we compute queries, keys, and values:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$
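As a quick shape check, the three projections above can be written directly as matrix products (sizes here are illustrative, not from the paper):

```python
import torch

# Illustrative sizes: n = sequence length, d = model width, d_k = head width.
n, d, d_k = 6, 16, 8
X = torch.randn(n, d)
W_Q = torch.randn(d, d_k)
W_K = torch.randn(d, d_k)
W_V = torch.randn(d, d_k)

# Each projection maps the n x d input to an n x d_k matrix.
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)  # torch.Size([6, 8]) three times
```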

The attention output is then:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

The $\sqrt{d_k}$ scaling factor prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients.
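A small numerical sketch of why the scaling matters (this demo is mine, not from the paper): for query and key components with unit variance, the dot product has variance $d_k$, so its typical magnitude grows like $\sqrt{d_k}$, while the scaled version stays roughly constant.

```python
import math
import random

random.seed(0)

def avg_abs_dot(d_k, trials=2000):
    """Average |q . k| for random unit-variance Gaussian vectors."""
    total = 0.0
    for _ in range(trials):
        q = [random.gauss(0, 1) for _ in range(d_k)]
        k = [random.gauss(0, 1) for _ in range(d_k)]
        total += abs(sum(a * b for a, b in zip(q, k)))
    return total / trials

for d_k in (16, 64, 256):
    raw = avg_abs_dot(d_k)
    # Raw magnitude grows roughly like sqrt(d_k); the scaled value stays O(1).
    print(f"d_k={d_k:4d}  raw={raw:6.2f}  scaled={raw / math.sqrt(d_k):.2f}")
```

Unscaled scores of this magnitude would land the softmax in its saturated region, where gradients are near zero.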

## Multi-Head Attention

Rather than performing a single attention function, transformers use multi-head attention to jointly attend to information from different representation subspaces:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where each head is computed as:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
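A compact PyTorch sketch of multi-head self-attention, assuming `d_model` is divisible by `num_heads` (class and parameter names here are illustrative, not the paper's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # One full-width projection per role; heads are split afterwards.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        def split(t):
            # (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)
            return t.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5
        weights = F.softmax(scores, dim=-1)
        out = weights @ V                                      # per-head outputs
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)  # concat heads
        return self.W_o(out)

x = torch.randn(2, 5, 32)
mha = MultiHeadAttention(d_model=32, num_heads=4)
print(mha(x).shape)  # torch.Size([2, 5, 32])
```

Splitting one `d_model`-wide projection into `h` slices is equivalent to the per-head $W_i$ matrices in the equation above, just batched into a single matrix multiply.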

## Implementation

Here’s a minimal PyTorch implementation of scaled dot-product attention:

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    # Compatibility scores for every query/key pair, scaled by sqrt(d_k).
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5

    # Positions where mask == 0 are blocked from being attended to.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax over the key dimension yields the attention weights.
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)
```
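To see the mask convention in action (where `mask == 0` marks blocked positions), here is a small self-contained causal-masking check; the `-inf` scores become exactly zero attention weight after the softmax:

```python
import torch
import torch.nn.functional as F

seq_len, d_k = 4, 8
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
# Lower-triangular causal mask: position i may attend only to j <= i.
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(weights[0].triu(1).abs().sum().item())                 # 0.0: no attention to the future
print(torch.allclose(weights.sum(-1), torch.ones(1, seq_len)))  # True: rows still sum to 1
```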

## Architecture Overview

```mermaid
flowchart TB
    IN["Input Tokens"] --> EMB["Embedding"]
    EMB --> PE["+ Positional Encoding"]

    PE --> Q["Q"] & K["K"] & V["V"]

    Q & K & V --> H1["Head 1"]
    Q & K & V --> H2["Head 2"]
    Q & K & V --> Hh["Head h"]

    H1 & H2 & Hh --> CAT["Concat + Project"]

    CAT --> ADD1(("+"))
    PE -.->|residual| ADD1
    ADD1 --> LN1["Layer Norm"]

    LN1 --> FF["Feed-Forward Network"]
    FF --> ADD2(("+"))
    LN1 -.->|residual| ADD2
    ADD2 --> LN2["Layer Norm"]

    LN2 --> OUT["Output"]
```

## Positional Encoding

Since transformers have no inherent notion of position, we add positional encodings to the input embeddings. The original paper uses sinusoidal functions:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$$

The beauty of this encoding is that it allows the model to learn relative positions, since $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.
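The sinusoidal encoding above can be sketched in a few lines (function name and sizes are illustrative): even dimensions get the sine, odd dimensions the cosine, with wavelengths spaced geometrically by the $10000^{2i/d}$ term.

```python
import torch

def sinusoidal_encoding(max_len, d_model):
    pos = torch.arange(max_len).unsqueeze(1).float()  # (max_len, 1)
    i = torch.arange(0, d_model, 2).float()           # even dimension indices
    angles = pos / (10000 ** (i / d_model))           # (max_len, d_model/2)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angles)                   # even dims: sine
    pe[:, 1::2] = torch.cos(angles)                   # odd dims: cosine
    return pe

pe = sinusoidal_encoding(max_len=50, d_model=16)
print(pe.shape)  # torch.Size([50, 16])
# At position 0 all angles are 0, so sine dims are 0 and cosine dims are 1.
print(pe[0, 0::2].abs().max().item(), pe[0, 1::2].min().item())  # 0.0 1.0
```

The resulting matrix is simply added to the token embeddings before the first attention layer.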

## Key Takeaways

  1. Self-attention has $O(n^2)$ time and memory complexity in the sequence length $n$
  2. Multi-head attention allows the model to focus on different aspects simultaneously
  3. Positional encodings inject sequence order information
  4. Layer normalization and residual connections are critical for training stability

This post is a simplified walkthrough. For the full details, see Vaswani et al., 2017.