Master the Attention Mechanism by implementing Multi-Head Attention from scratch. Learn to code scaled dot-product attention and causal masks in PyTorch.
Previously in this course, we covered Normalization Techniques at Scale: Implementing RMSNorm, which laid the foundation for stable activations. In this lesson, we move to the heart of the Transformer architecture: the attention mechanism. Specifically, we will implement Multi-Head Attention (MHA) from first principles, covering the mechanics of projection, scaling, and causal masking.
At its core, the attention mechanism allows a model to dynamically focus on different parts of an input sequence. We frame this as a soft retrieval system built from queries ($Q$), keys ($K$), and values ($V$). For a given query, we compute similarity scores against all keys, normalize these scores with a softmax so that they sum to 1, and use them as weights to produce a weighted sum of the values.
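To make the retrieval analogy concrete, here is a tiny hand-built example with one query and three key/value pairs. The numbers are toy values chosen purely for illustration: the key most similar to the query receives the largest weight, so the output is pulled toward that key's value.

```python
import torch
import torch.nn.functional as F

# One query and three key/value pairs (toy numbers chosen for illustration).
query = torch.tensor([1.0, 0.0])
keys = torch.tensor([[1.0, 0.0],    # points the same way as the query
                     [0.0, 1.0],    # orthogonal to the query
                     [-1.0, 0.0]])  # points the opposite way
values = torch.tensor([[10.0], [20.0], [30.0]])

scores = keys @ query               # dot-product similarity with each key: [1, 0, -1]
weights = F.softmax(scores, dim=0)  # normalize the scores into a distribution
output = weights @ values           # weighted sum of the values

print(weights)
print(output)
```

Because the first key matches the query best, the output lands much closer to 10 than to 30.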
The mathematical formulation for Scaled Dot-Product Attention is: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
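The formula translates almost line for line into PyTorch. The function below is a minimal sketch; its name and the `mask` convention (0 means "do not attend") are our own choices for illustration. On PyTorch 2.0 or later, its output can be compared against the built-in `torch.nn.functional.scaled_dot_product_attention`.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    """softmax(Q K^T / sqrt(d_k)) V.

    q, k, v: tensors of shape [..., seq_len, d_k]
    mask: optional tensor broadcastable to [..., seq_len, seq_len];
          positions where mask == 0 receive no attention.
    """
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # normalize over the key axis
    return weights @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 16) for _ in range(3))
out = scaled_dot_product_attention(q, k, v)

# Built-in fused implementation of the same formula (PyTorch >= 2.0).
reference = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, reference, atol=1e-5))
```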
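We scale by $\sqrt{d_k}$ to keep the dot products from growing too large in magnitude. If the components of a query $q$ and a key $k$ are independent with mean 0 and variance 1, their dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has variance $d_k$, so raw scores grow with the head dimension. Large scores push the softmax into saturated regions where its gradients are extremely small, effectively stalling learning; dividing by $\sqrt{d_k}$ brings the variance back to 1.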
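A quick numerical check of this argument, assuming query and key entries are drawn independently from a standard normal distribution: the unscaled dot products have a variance close to $d_k$, the scaled ones close to 1, and the softmax over unscaled scores is far more peaked.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k, n = 512, 10_000

# Assumption: entries of q and k are independent with mean 0 and variance 1.
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)      # n independent dot products
scaled = raw / math.sqrt(d_k)  # the same scores divided by sqrt(d_k)

print(raw.var().item())     # roughly d_k
print(scaled.var().item())  # roughly 1

# One query against 10 keys: how peaked does the softmax become?
row_raw = q[0] @ k[:10].T
p_raw = F.softmax(row_raw, dim=-1).max().item()                      # often close to 1
p_scaled = F.softmax(row_raw / math.sqrt(d_k), dim=-1).max().item()  # noticeably flatter
print(p_raw, p_scaled)
```

A near one-hot softmax passes almost no gradient to the non-maximal scores, which is the saturation the scaling factor avoids.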
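Instead of computing a single attention function over the full $d_{model}$-dimensional representations, multi-head attention projects $Q$, $K$, and $V$ into $h$ lower-dimensional subspaces of size $d_k = d_{model} / h$, runs attention in each subspace (head) in parallel, and concatenates the results before a final output projection. This allows the model to attend to information from different representation subspaces simultaneously (for example, one head might focus on syntax while another tracks semantic relationships).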
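Written out, the standard multi-head formulation (with the value dimension equal to $d_k$, as in this lesson's code) is:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O, \qquad \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k}$, $W^O \in \mathbb{R}^{h d_k \times d_{model}}$, and $d_k = d_{model} / h$, so the concatenated heads have exactly $d_{model}$ features again.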
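In practice, we don't create separate projection matrices for every head, or even separate ones for $Q$, $K$, and $V$. We perform one large linear projection from $d_{model}$ to $3 \times d_{model}$ and then reshape the result to separate $Q$, $K$, $V$ and the individual heads.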
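The fused projection is mathematically the same as three separate $d_{model} \to d_{model}$ projections whose weights are stacked row-wise. The sketch below (arbitrary sizes, random input) slices a fused `nn.Linear` into its three blocks and checks that both routes give the same $Q$, $K$, and $V$. Each of those blocks, in turn, is just the $h$ per-head projections concatenated.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 64
x = torch.randn(2, 10, d_model)

# Fused projection: one [3 * d_model, d_model] weight produces Q, K and V together.
qkv_proj = nn.Linear(d_model, 3 * d_model)

with torch.no_grad():
    q_fused, k_fused, v_fused = qkv_proj(x).chunk(3, dim=-1)

    # Slice the fused parameters into three separate d_model -> d_model projections.
    w_q, w_k, w_v = qkv_proj.weight.chunk(3, dim=0)
    b_q, b_k, b_v = qkv_proj.bias.chunk(3, dim=0)
    q_sep = F.linear(x, w_q, b_q)
    k_sep = F.linear(x, w_k, b_k)
    v_sep = F.linear(x, w_v, b_v)

pairs = [(q_fused, q_sep), (k_fused, k_sep), (v_fused, v_sep)]
print(all(torch.allclose(a, b, atol=1e-5) for a, b in pairs))  # True
```

The fused form launches one large matrix multiply instead of three smaller ones, which is usually more efficient on GPUs.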
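Let's build the core component in PyTorch. We need to handle tensor shapes carefully: inside the attention computation, each of $Q$, $K$, and $V$ has shape `[batch, heads, seq_len, d_k]`, where `d_k` is the per-head dimension.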
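The `[batch, heads, seq_len, d_k]` layout comes from a reshape followed by a transpose. This sketch (sizes chosen arbitrarily) shows the round trip, and also the common mistake of reshaping directly into the target shape, which produces the right shape but scrambles which features belong to which token and head.

```python
import torch

batch_size, seq_len, num_heads, d_k = 2, 10, 8, 64
d_model = num_heads * d_k
x = torch.randn(batch_size, seq_len, d_model)

# Split the feature axis into heads, then move heads in front of the sequence axis.
heads = x.reshape(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 64])

# Correct inverse: swap the axes back first, then merge the heads.
merged = heads.transpose(1, 2).reshape(batch_size, seq_len, d_model)
print(torch.equal(merged, x))  # True

# Mistake: reshaping straight to [batch, heads, seq_len, d_k] without a transpose.
# The shape matches, but each "head" now mixes features from different tokens.
wrong = x.reshape(batch_size, num_heads, seq_len, d_k)
print(torch.equal(wrong, heads))  # False
```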
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        # Single linear layer for all projections
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()

        # 1. Project and split into heads
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, d_k]; each of q, k, v is [batch, heads, seq, d_k]

        # 2. Scaled Dot-Product Attention
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)  # [batch, heads, seq, seq]
        if mask is not None:
            # Positions where mask == 0 get -inf, i.e. zero weight after softmax
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = attn_probs @ v  # [batch, heads, seq, d_k]

        # 3. Concatenate heads and project back
        output = output.transpose(1, 2).reshape(batch_size, seq_len, d_model)
        return self.out_proj(output)
```
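One way to gain confidence in the module is to compare it against PyTorch's built-in `nn.MultiheadAttention`. The sketch below repeats the class so it runs on its own, copies our parameters into the built-in layer, and compares outputs under a lower-triangular (causal) mask like the one described next. It relies on two assumptions about the built-in layer: that, with `batch_first=True` and equal query/key/value sizes, its `in_proj_weight` stacks Q, K, V in the same order as our `qkv_proj`, and that a boolean `attn_mask` marks blocked positions with `True`. Treat it as a sanity check rather than a specification.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    # Same module as above, repeated so this snippet is self-contained.
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        qkv = self.qkv_proj(x).reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = (attn_probs @ v).transpose(1, 2).reshape(batch_size, seq_len, d_model)
        return self.out_proj(output)

torch.manual_seed(0)
d_model, num_heads, seq_len = 64, 8, 10
ours = MultiHeadAttention(d_model, num_heads).eval()
ref = nn.MultiheadAttention(d_model, num_heads, batch_first=True).eval()

# Assumption: the built-in layer packs Q, K, V weights in the same [3 * d_model, d_model]
# layout as our qkv_proj, so parameters can be copied across directly.
with torch.no_grad():
    ref.in_proj_weight.copy_(ours.qkv_proj.weight)
    ref.in_proj_bias.copy_(ours.qkv_proj.bias)
    ref.out_proj.weight.copy_(ours.out_proj.weight)
    ref.out_proj.bias.copy_(ours.out_proj.bias)

x = torch.randn(2, seq_len, d_model)
keep = torch.tril(torch.ones(seq_len, seq_len))  # our convention: 1 = attend, 0 = block
blocked = keep == 0                              # built-in convention: True = block

with torch.no_grad():
    out_ours = ours(x, mask=keep.view(1, 1, seq_len, seq_len))
    out_ref, _ = ref(x, x, x, attn_mask=blocked, need_weights=False)

print(torch.allclose(out_ours, out_ref, atol=1e-5))
```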
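In decoder-only models (like GPT), we must ensure that a token cannot "see" future tokens in the sequence. We achieve this by applying a lower-triangular mask to the attention scores before the softmax operation: every score above the diagonal (a query attending to a later position) is replaced with $-\infty$, so the softmax assigns it exactly zero weight.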
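The effect of the mask can be checked on raw tensors. In this sketch, the toy sizes, random inputs, and the helper name `causal_attention` are illustrative choices. After masking, the attention weights above the diagonal are exactly zero, and perturbing the last token's key and value leaves every earlier output unchanged.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    n = q.size(0)
    scores = q @ k.T / math.sqrt(q.size(-1))
    # Lower-triangular "keep" pattern: position i may attend to positions 0..i.
    keep = torch.tril(torch.ones(n, n, dtype=torch.bool))
    scores = scores.masked_fill(~keep, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # exp(-inf) = 0, so future positions get zero weight
    return weights @ v, weights

torch.manual_seed(0)
seq_len, d_k = 5, 8
q, k, v = (torch.randn(seq_len, d_k) for _ in range(3))
out, weights = causal_attention(q, k, v)
print(weights)  # upper triangle is exactly 0; each row sums to 1

# Changing the last token's key and value must not affect any earlier output.
k2, v2 = k.clone(), v.clone()
k2[-1] += 5.0
v2[-1] += 5.0
out2, _ = causal_attention(q, k2, v2)
print(torch.allclose(out[:-1], out2[:-1]))  # True
```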
```python
def create_causal_mask(seq_len):
    # Creates a mask where positions above the diagonal are set to 0
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.view(1, 1, seq_len, seq_len)
```
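A short sketch of how this mask interacts with the 4-D score tensor, using random scores and small sizes for illustration. The helper is repeated so the snippet runs on its own. The `(1, 1, seq_len, seq_len)` shape broadcasts over every batch element and head; for a plain causal mask, a 2-D `(seq_len, seq_len)` version happens to produce the same result, because PyTorch broadcasting aligns trailing dimensions.

```python
import torch

def create_causal_mask(seq_len):
    # Same helper as above: 1 on and below the diagonal (visible), 0 above it (future).
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.view(1, 1, seq_len, seq_len)

batch_size, num_heads, seq_len = 2, 4, 6
scores = torch.randn(batch_size, num_heads, seq_len, seq_len)

mask_4d = create_causal_mask(seq_len)     # [1, 1, seq_len, seq_len]
mask_2d = mask_4d.view(seq_len, seq_len)  # [seq_len, seq_len]

masked_4d = scores.masked_fill(mask_4d == 0, float("-inf"))
masked_2d = scores.masked_fill(mask_2d == 0, float("-inf"))

print(masked_4d.shape)                    # torch.Size([2, 4, 6, 6])
print(torch.equal(masked_4d, masked_2d))  # True
```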
Try the following exercises to consolidate the implementation:

1. Modify the `MultiHeadAttention` class to include a dropout parameter applied after the softmax layer. Attention dropout is a common regularizer against overfitting in many Transformer implementations.
2. Create a random input tensor of shape `(2, 10, 512)` and pass it through your module with a causal mask. Verify that the output shape remains `(2, 10, 512)`.
3. Build a `TransformerBlock` class that combines your `MultiHeadAttention` with the RMSNorm implementation from our previous lesson.

Watch out for two common pitfalls:

- Rearrange dimensions with `.transpose()` or `.permute()` explicitly, and verify shapes using `print(tensor.shape)` during development.
- If your mask has shape `(seq_len, seq_len)`, reshape it to `(1, 1, seq_len, seq_len)` so that its interaction with the `(batch, heads, seq, seq)` attention score matrix is explicit. PyTorch broadcasting aligns trailing dimensions, so a 2-D causal mask would also broadcast correctly, but the 4-D shape makes the intent unambiguous.

We have successfully implemented the backbone of the Transformer. Scaled dot-product attention lets the model learn complex relationships between tokens; multi-head projections let it learn several such relationships in parallel; and causal masks enable autoregressive sequence generation. You now have the fundamental building block required for the next stage of our architecture.
Up next: Positional Encoding Architectures, where we will explore how to inject sequence order information into these attention layers, which on their own do not encode token positions.