Master activation checkpointing to train massive models by trading redundant compute for memory. Learn to implement selective recomputation in your PyTorch pipelines.
Previously in this course, we explored Tensor and Pipeline Parallelism: Scaling Large Model Training to distribute model state across GPU fleets. While those strategies address model parameter storage, they do not solve the "activation explosion" that occurs during the forward pass of deep networks. This lesson adds a critical layer of memory optimization: Activation Checkpointing, also known as gradient checkpointing, which allows you to fit significantly larger models or batch sizes into the same hardware footprint.
In standard backpropagation, the framework must store the intermediate activations (the outputs of each layer) generated during the forward pass, because they are required to calculate gradients during the backward pass. For a transformer with $L$ layers, this creates a memory footprint that scales linearly with depth and at least linearly with sequence length ($O(L \times N_{seq})$), with an additional $O(N_{seq}^2)$ term per layer whenever the attention score matrix is materialized.
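To make the linear scaling concrete, here is a back-of-the-envelope estimate in plain Python. It is a rough sketch, not a profiler: `stored_activation_bytes` and its `tensors_per_layer` parameter are hypothetical names, the count of eight saved tensors per layer is an assumption chosen for illustration, and the quadratic attention-score term is deliberately ignored.

```python
def stored_activation_bytes(num_layers, batch, seq_len, hidden,
                            tensors_per_layer=8, bytes_per_elem=2):
    """Rough bytes of activations kept for backward without checkpointing.

    tensors_per_layer is an assumed number of (batch, seq_len, hidden)-sized
    tensors saved per layer; real blocks differ. bytes_per_elem=2 models
    16-bit activations.
    """
    per_layer = tensors_per_layer * batch * seq_len * hidden * bytes_per_elem
    return num_layers * per_layer


base = stored_activation_bytes(num_layers=24, batch=4, seq_len=2048, hidden=4096)
deeper = stored_activation_bytes(num_layers=48, batch=4, seq_len=2048, hidden=4096)
longer = stored_activation_bytes(num_layers=24, batch=4, seq_len=4096, hidden=4096)

print(f"base:   {base / 2**30:.1f} GiB")    # base:   12.0 GiB
print(f"deeper: {deeper / 2**30:.1f} GiB")  # deeper: 24.0 GiB
print(f"longer: {longer / 2**30:.1f} GiB")  # longer: 24.0 GiB
```

Doubling either the depth or the sequence length doubles the estimate, which is the $O(L \times N_{seq})$ behavior described above.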
Activation Checkpointing breaks this dependency. Instead of storing every activation, we store only a subset of "checkpoints" (e.g., at the input of each transformer block). When the backward pass reaches a point where an activation is missing, the model re-runs the forward pass for that specific segment using the saved checkpoint.
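The recompute-on-demand behavior can be checked on a toy example. The sketch below assumes a small `torch.nn.Sequential` stand-in for a transformer block (the `segment` and `grads_for` names are illustrative, not part of any library) and compares gradients from a plain forward pass against gradients from a checkpointed one.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Toy "segment": a stand-in for one transformer block.
segment = torch.nn.Sequential(
    torch.nn.Linear(16, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 16),
)


def grads_for(use_checkpoint: bool):
    """Run forward + backward and return input and parameter gradients."""
    segment.zero_grad(set_to_none=True)
    x = torch.ones(8, 16, requires_grad=True)
    if use_checkpoint:
        # Only the segment input is kept; the intermediate activations are
        # recomputed from it when backward() reaches this segment.
        out = checkpoint(segment, x, use_reentrant=False)
    else:
        out = segment(x)
    out.sum().backward()
    return x.grad.clone(), [p.grad.clone() for p in segment.parameters()]


x_grad_plain, p_grads_plain = grads_for(use_checkpoint=False)
x_grad_ckpt, p_grads_ckpt = grads_for(use_checkpoint=True)

same_input_grad = torch.allclose(x_grad_plain, x_grad_ckpt)
same_param_grads = all(
    torch.allclose(a, b) for a, b in zip(p_grads_plain, p_grads_ckpt)
)
print(same_input_grad, same_param_grads)
```

Checkpointing changes what is held in memory between the forward and backward passes, not the gradients themselves, so the two runs should agree.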
While `torch.utils.checkpoint` can be applied as a black box around entire blocks, production-grade training usually calls for selective checkpointing. You shouldn't checkpoint everything; target the layers whose activations dominate memory (the attention and feed-forward sub-layers).
Here is a simplified implementation of a checkpoint-aware Transformer block:
```python
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedTransformerBlock(torch.nn.Module):
    """Wraps a block so its internal activations are recomputed in backward."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x, *args, **kwargs):
        # checkpoint() saves only the inputs and re-runs self.block during
        # the backward pass. use_reentrant=False is the modern, recommended
        # approach; in this mode keyword arguments are forwarded to the block.
        return checkpoint(self.block, x, *args, use_reentrant=False, **kwargs)
```
The decision to checkpoint is a sliding scale. In large-scale training, we categorize layers by their activation size.
| Layer Type | Memory Usage | Checkpoint Priority |
|---|---|---|
| Attention Projections | Low | Low |
| Attention Scores ($QK^\top$) / Softmax, $O(N_{seq}^2)$ | Very High | Critical |
| Feed-Forward (MLP) | Moderate | Medium |
| Layer Norms | Minimal | None |
For a standard LLM architecture, we prioritize recomputing the attention softmax and the MLP intermediate states rather than storing them. These represent the bulk of the activations. By checkpointing only these, you recover most of the memory savings of full checkpointing while keeping the recomputation overhead low.
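One way to express this prioritization in code is to give each sub-layer its own checkpoint flag. The following is a minimal sketch under stated assumptions: `SelectiveBlock`, its pre-norm layout, and the `checkpoint_attn` / `checkpoint_mlp` flags are hypothetical and not taken from any particular model; only `torch.utils.checkpoint.checkpoint` and the standard `torch.nn` modules are real APIs.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class SelectiveBlock(nn.Module):
    """Hypothetical pre-norm block with per-sub-layer checkpoint flags."""

    def __init__(self, dim=32, heads=4, checkpoint_attn=True, checkpoint_mlp=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.checkpoint_attn = checkpoint_attn
        self.checkpoint_mlp = checkpoint_mlp

    def _attend(self, h):
        # Self-attention; attention weights are not returned.
        out, _ = self.attn(h, h, h, need_weights=False)
        return out

    def _maybe_checkpoint(self, fn, h, enabled):
        # Recompute only while training; otherwise call the sub-layer directly.
        if enabled and self.training:
            return checkpoint(fn, h, use_reentrant=False)
        return fn(h)

    def forward(self, x):
        # Layer norms run outside checkpoint(), so they are never recomputed.
        x = x + self._maybe_checkpoint(self._attend, self.norm1(x), self.checkpoint_attn)
        x = x + self._maybe_checkpoint(self.mlp, self.norm2(x), self.checkpoint_mlp)
        return x


# Sanity check: selective checkpointing should not change the gradients.
torch.manual_seed(0)
selective = SelectiveBlock()
baseline = SelectiveBlock(checkpoint_attn=False, checkpoint_mlp=False)
baseline.load_state_dict(selective.state_dict())

x_sel = torch.randn(2, 5, 32, requires_grad=True)
x_base = x_sel.detach().clone().requires_grad_(True)
selective(x_sel).sum().backward()
baseline(x_base).sum().backward()
grads_match = torch.allclose(x_sel.grad, x_base.grad, atol=1e-6)
print(grads_match)
```

Keeping the flags separate makes it easy to profile each combination and settle on the cheapest one that fits in memory.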
When scaling to billions of parameters, you must integrate checkpointing with your distributed strategy. If you are already using Quantized LoRA (QLoRA): Fine-tuning Massive Models on Consumer GPUs or standard DDP, ensure your checkpointing implementation does not conflict with distributed synchronization primitives.
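Pro-tip: Pass `use_reentrant=False` explicitly to PyTorch's `checkpoint` function if you are using modern PyTorch (2.0+). The older re-entrant mode has documented limitations (for example, it does not forward keyword arguments to the wrapped function and needs at least one input that requires gradients), and it can cause issues with `torch.compile`, leading to subtle bugs that are hard to trace.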
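One practical difference between the two modes is keyword-argument handling. The sketch below uses a hypothetical `scaled_layer` function with a keyword-only `scale` argument to show that, with `use_reentrant=False`, extra keyword arguments passed to `checkpoint` are forwarded to the wrapped function.

```python
import torch
from torch.utils.checkpoint import checkpoint


def scaled_layer(x, weight, *, scale=1.0):
    # Hypothetical segment with a keyword-only argument.
    return torch.tanh(x @ weight) * scale


torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

# In non-reentrant mode, scale=0.5 is passed through to scaled_layer.
out = checkpoint(scaled_layer, x, w, use_reentrant=False, scale=0.5)
out.sum().backward()

# Compare against a direct call with the same keyword argument.
reference = scaled_layer(x.detach(), w.detach(), scale=0.5)
matches = torch.allclose(out.detach(), reference)
print(matches)
```

This matters for wrappers around blocks that take arguments such as attention masks by keyword.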
To verify the savings in practice:

- Measure peak memory with `torch.cuda.max_memory_allocated()`.
- Wrap your blocks with the `CheckpointedTransformerBlock` defined above.

Watch for these pitfalls:

- Older `checkpoint` versions that require re-entrant logic can cause gradient discrepancies when using non-deterministic operations or custom autograd functions.
- `checkpoint` will re-run the forward pass during the backward phase. If the dropout RNG state is not restored before recomputation, you get different dropout masks and therefore inconsistent, unstable gradients. PyTorch's `checkpoint` restores the RNG state by default (`preserve_rng_state=True`); keep that default when the segment contains dropout, or use deterministic dropout if needed.

Activation checkpointing is your primary tool for fitting large models into limited VRAM. By selectively recomputing activations, you trade a modest increase in compute time (at most one extra forward pass through the checkpointed segments) for a large reduction in memory, enabling the training of deeper, wider, or longer-context models than would otherwise be possible on your hardware.
Up next: We will dive into Mixed Precision Training (FP8/BF16), where we reduce the precision of our tensors to further slash memory usage and accelerate throughput.