Mahamudul Hasan Rubel
Lesson 6 of the Advanced AI/ML: Deep Learning, LLMs & Production Systems course
AI/ML · June 27, 2026 · 4 min read

Implementing Multi-Head Attention: A Deep Dive into Transformers

Master the Attention Mechanism by implementing Multi-Head Attention from scratch. Learn to code scaled dot-product attention and causal masks in PyTorch.

Tags: Attention Mechanism, Transformer, Self-Attention, PyTorch, Deep Learning, ai, machine-learning, python

Previously in this course, we covered Normalization Techniques at Scale: Implementing RMSNorm, providing the foundation for stable activations. In this lesson, we move to the heart of the Transformer architecture: the Attention Mechanism. Specifically, we will implement Multi-Head Attention (MHA) from first principles, ensuring you understand the mechanics of projection, scaling, and causal masking.

The Attention Mechanism: First Principles

At its core, the attention mechanism allows a model to dynamically focus on different parts of an input sequence. We represent this as a query ($Q$), key ($K$), and value ($V$) retrieval system. For a given query, we compute similarity scores against all keys, normalize these scores, and use them as weights to produce a weighted sum of values.

The mathematical formulation for Scaled Dot-Product Attention is: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
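
To make the formula concrete, here is a minimal sketch of attention for a single query in plain Python, with no tensors involved. The helper names `softmax` and `attend` and the toy vectors are illustrative assumptions, not part of the lesson's module.

```python
import math

def softmax(scores):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention for one query vector (illustrative helper)."""
    d_k = len(query)
    # Similarity of the query with every key, scaled by sqrt(d_k).
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    weights = softmax(scores)
    # Weighted sum of the value vectors.
    dim_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim_v)]
    return weights, output

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

weights, output = attend(query, keys, values)
print(weights)
print(output)
```

The first key points in the same direction as the query, so it receives the largest weight and the output is pulled toward the first value vector.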

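We divide by $\sqrt{d_k}$ to keep the dot products from growing with the head dimension. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$, so unscaled scores spread out as $d_k$ increases. Large scores push the softmax into saturated regions where gradients are extremely small, which stalls learning.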
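
The effect of the scale factor can be seen numerically. The sketch below assumes random Gaussian queries and keys (an illustrative setup, not trained weights) and compares the largest softmax probability with and without dividing by $\sqrt{d_k}$.

```python
import math
import random

def softmax(scores):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

random.seed(0)
d_k = 512  # assumed head dimension, chosen large to make the effect visible
query = [random.gauss(0.0, 1.0) for _ in range(d_k)]
keys = [[random.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(8)]

# Raw dot products have a standard deviation of roughly sqrt(d_k).
raw_scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

max_prob_raw = max(softmax(raw_scores))
max_prob_scaled = max(softmax(scaled_scores))
print(f"max prob without scaling: {max_prob_raw:.4f}")
print(f"max prob with scaling:    {max_prob_scaled:.4f}")
```

Dividing the scores by a constant greater than one can only flatten the softmax, so the scaled distribution is always less concentrated than the unscaled one.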

Managing Multi-Head Projections

Instead of performing a single attention function over the full $d_{model}$ dimensions, we project $Q$, $K$, and $V$ into $h$ subspaces of dimension $d_k = d_{model} / h$ and run attention in each one independently. This lets the model attend to information from different representation subspaces simultaneously (for example, one head might focus on syntax while another focuses on semantic relationships).

In practice, we don't create separate projection layers for each head, or even for $Q$, $K$, and $V$. A single linear layer maps $d_{model}$ to $3 \times d_{model}$, and we reshape the result to separate the three tensors and their heads.
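
A short sketch of why the fused projection is equivalent to separate ones: slicing the first $d_{model}$ rows out of the fused weight and applying them on their own reproduces the same $Q$. The sizes below are small assumed values chosen for readability.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 8, 2  # assumed toy sizes
d_k = d_model // num_heads
x = torch.randn(2, 5, d_model)  # [batch, seq, d_model]

# One fused projection producing Q, K, and V for every head.
qkv_proj = nn.Linear(d_model, 3 * d_model)
qkv = qkv_proj(x).reshape(2, 5, 3, num_heads, d_k)
q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: [batch, heads, seq, d_k]

# The same Q computed from only the first d_model rows of the fused weight.
w_q, b_q = qkv_proj.weight[:d_model], qkv_proj.bias[:d_model]
q_separate = (x @ w_q.T + b_q).reshape(2, 5, num_heads, d_k).transpose(1, 2)

print(q.shape)
print(torch.allclose(q, q_separate, atol=1e-5))
```

The fused layer does the same arithmetic in one matrix multiplication, which is typically friendlier to the hardware than three smaller ones.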

Implementing Scaled Dot-Product Attention

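Let's build the core component in PyTorch. We need to handle tensor shapes carefully: inside the attention computation, $Q$, $K$, and $V$ each have shape [batch, heads, seq_len, head_dim], where head_dim is self.d_k in the code below.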

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        
        # Single linear layer for all projections
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        
        # 1. Project and split into heads
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4) # [3, batch, heads, seq, d_k]

        # 2. Scaled Dot-Product Attention
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
            
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = (attn_probs @ v)
        
        # 3. Concatenate heads and project back
        output = output.transpose(1, 2).reshape(batch_size, seq_len, d_model)
        return self.out_proj(output)
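
One way to sanity-check step 2 is to compare it against PyTorch's built-in kernel. The sketch below assumes PyTorch 2.0 or newer, where `F.scaled_dot_product_attention` is available; `manual_attention` is a hypothetical standalone copy of the score, mask, and softmax logic rather than part of the class above.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, mask=None):
    # Same arithmetic as step 2 of the forward pass (illustrative helper).
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
# Assumed toy shapes: [batch, heads, seq, d_k]
q, k, v = (torch.randn(2, 4, 10, 16) for _ in range(3))
causal = torch.tril(torch.ones(10, 10)).view(1, 1, 10, 10)

ours = manual_attention(q, k, v, mask=causal)
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(ours, reference, atol=1e-5))
```

Agreement with the built-in function is a useful regression test while refactoring, since shape mistakes often still produce tensors of the right size.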

Implementing Causal Masking

In decoder-only models (like GPT), we must ensure that a token cannot "see" future tokens in the sequence. We achieve this by applying a lower-triangular mask to the attention scores before the softmax operation.

PYTHON
def create_causal_mask(seq_len):
    # Creates a mask where positions above the diagonal are set to 0
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.view(1, 1, seq_len, seq_len)
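
To see what the mask does, the sketch below applies it to uniform (all-zero) scores for an assumed sequence length of 4. With uniform scores, each position should attend equally to itself and to earlier positions, and not at all to later ones.

```python
import torch
import torch.nn.functional as F

def create_causal_mask(seq_len):
    # 1 on and below the diagonal (visible), 0 above it (future tokens).
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.view(1, 1, seq_len, seq_len)

mask = create_causal_mask(4)
print(mask[0, 0])

# Uniform scores isolate the effect of the mask from the content of Q and K.
scores = torch.zeros(1, 1, 4, 4)
probs = F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)
print(probs[0, 0])
```

The first row puts all of its weight on position 0, the second row splits it evenly between positions 0 and 1, and so on. Every entry above the diagonal is exactly zero.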

Hands-on Exercise

  1. Modify the MultiHeadAttention class to accept a dropout probability and apply dropout to the attention probabilities after the softmax. Attention dropout is a common regularization technique for reducing overfitting.
  2. Create a dummy input tensor of shape (2, 10, 512) and pass it through your module with a causal mask. Verify that the output shape remains (2, 10, 512).
  3. Project Advancement: Define a TransformerBlock class that combines your MultiHeadAttention with the RMSNorm implementation from the earlier Normalization Techniques at Scale lesson.

Common Pitfalls

  • Forgetting the Scale Factor: Omitting the $1/\sqrt{d_k}$ scaling is a common error. Without it, the softmax output becomes "peaky" for large $d_k$, concentrating nearly all probability mass on the single highest score; the resulting small gradients hurt training stability.
  • Permutation Errors: When dealing with multi-head tensors, keeping track of dimensions is crucial. Always use .transpose() or .permute() explicitly and verify shapes using print(tensor.shape) during development.
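  • Masking Logic: Ensure your mask is broadcastable against the (batch, heads, seq, seq) attention score tensor. A (seq_len, seq_len) mask already broadcasts because trailing dimensions are aligned, and viewing it as (1, 1, seq_len, seq_len) simply makes the intent explicit. A per-example mask of shape (batch, seq_len), however, must be reshaped to (batch, 1, 1, seq_len); otherwise it fails to broadcast or, worse, lines up with the wrong dimension.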
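
The broadcasting point can be checked directly. The sketch below combines the causal mask with a hypothetical per-example padding mask (1 for real tokens, 0 for padding; the padding mask is an assumption for illustration and is not part of the lesson's module).

```python
import torch

batch, heads, seq_len = 2, 4, 5  # assumed toy sizes
scores = torch.randn(batch, heads, seq_len, seq_len)

# Causal mask: identical for every example and head.
causal = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

# Hypothetical padding mask: the second example has two padded positions.
padding = torch.tensor([[1, 1, 1, 1, 1],
                        [1, 1, 1, 0, 0]])
padding = padding.view(batch, 1, 1, seq_len)  # broadcast over heads and queries

combined = causal * padding  # [batch, 1, seq_len, seq_len]
masked = scores.masked_fill(combined == 0, float("-inf"))
print(combined.shape, masked.shape)
```

The singleton dimensions tell PyTorch which axes to repeat along: the causal mask is shared across batch and heads, while the padding mask is shared across heads and query positions.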

Recap

We have successfully implemented the backbone of the Transformer. By using scaled dot-product attention, we allow the model to learn complex relationships; by using multi-head projections, we allow it to learn these relationships in parallel; and by using causal masks, we enable autoregressive sequence generation. You now have the fundamental building block required for the next stage of our architecture.

Up next: Positional Encoding Architectures, where we will explore how to inject sequence order information into attention layers, which on their own have no notion of token order.
