Learn to scale models beyond single-GPU memory limits. Master Tensor Parallelism, Pipeline Parallelism, and activation checkpointing for efficient training.
Previously in this course, we explored Data Parallelism Strategies: Scaling PyTorch with DDP. While Data Parallelism is excellent for scaling throughput by replicating the model, it fails when your model is too large to fit into the memory (VRAM) of a single GPU.
In this lesson, we move beyond data replication to model partitioning. We’ll cover how to split your model across devices using Tensor and Pipeline Parallelism, and how to use activation checkpointing to trade compute for memory.
When a model’s parameters, gradients, and optimizer states exceed the VRAM of a single GPU, you must employ Model Parallelism. Unlike Data Parallelism, where each GPU holds a full copy of the model, Model Parallelism divides the model's structure itself.
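Tensor Parallelism splits individual layers across multiple GPUs. For a dense layer $Y = XA$, we can partition the weight matrix $A$ along its columns ($A = [A_1, A_2]$). Each GPU computes a partial output ($Y_1 = XA_1$, $Y_2 = XA_2$), and an AllGather concatenates the shards into the full output $Y = [Y_1, Y_2]$. This is highly effective for large Transformer layers, such as the attention projections or MLP blocks. In an MLP block, following a column-split layer with a row-split layer lets the intermediate result stay sharded, so the pair needs only a single AllReduce instead of a gather.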
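The arithmetic behind the column split can be checked on a single device. The sketch below is a single-process simulation that assumes two ranks and toy sizes. It uses `torch.cat` as a stand-in for AllGather and a plain sum as a stand-in for AllReduce; a real multi-GPU implementation would call collectives such as `torch.distributed.all_gather` across processes instead.

```python
import torch

torch.manual_seed(0)

# Toy sizes (hypothetical): batch 4, in_features 8, out_features 6, 2 simulated ranks
X = torch.randn(4, 8)
A = torch.randn(8, 6)
world_size = 2

# Column-parallel: each rank holds a slice of A's columns, A = [A_1, A_2]
A_shards = torch.chunk(A, world_size, dim=1)
Y_shards = [X @ A_i for A_i in A_shards]  # Y_i = X A_i, computed locally on rank i

# AllGather stand-in: concatenate the partial outputs along the feature dimension
Y_colpar = torch.cat(Y_shards, dim=1)

# Row-parallel (for comparison): split A by rows and X by columns;
# the partial products are summed, which is what an AllReduce would do
X_cols = torch.chunk(X, world_size, dim=1)
A_rows = torch.chunk(A, world_size, dim=0)
Y_rowpar = sum(x_i @ a_i for x_i, a_i in zip(X_cols, A_rows))

Y_reference = X @ A
print(torch.allclose(Y_colpar, Y_reference, atol=1e-5))  # True
print(torch.allclose(Y_rowpar, Y_reference, atol=1e-5))  # True
```

Both splits reproduce the unsharded result; they differ in which collective is needed and in whether the output ends up sharded or replicated.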
Pipeline Parallelism partitions the model vertically by layers. If you have a 48-layer transformer, you might place layers 1-12 on GPU 0, 13-24 on GPU 1, and so on. This creates a "pipeline" where data flows from one device to the next.
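A naive two-stage split in plain PyTorch shows the basic mechanics. This is a minimal sketch that assumes two visible GPUs and falls back to CPU so it runs anywhere; `TwoStageModel` and the layer sizes are hypothetical.

```python
import torch
import torch.nn as nn

# Use two GPUs if available; otherwise fall back to CPU so the sketch still runs.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")


class TwoStageModel(nn.Module):
    """Hypothetical MLP split into two contiguous stages."""

    def __init__(self, dim=1024):
        super().__init__()
        # Stage 0: the first modules live on dev0
        self.stage0 = nn.Sequential(nn.Linear(dim, dim), nn.ReLU()).to(dev0)
        # Stage 1: the remaining modules live on dev1
        self.stage1 = nn.Sequential(nn.Linear(dim, dim), nn.ReLU()).to(dev1)

    def forward(self, x):
        x = self.stage0(x.to(dev0))
        # Activations cross the device boundary here
        return self.stage1(x.to(dev1))


model = TwoStageModel(dim=256)
out = model(torch.randn(8, 256))
out.sum().backward()  # autograd routes gradients back across the .to() boundary
print(out.shape, out.device)
```

No extra code is needed for the backward pass. In this naive form, however, only one stage does useful work at any given moment.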
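Pipeline Parallelism introduces the "bubble" problem. While GPU 1 waits for GPU 0 to finish its forward pass, GPU 1 sits idle, and GPU 0 idles in turn during the backward pass. We mitigate this by splitting each batch into smaller "micro-batches." As soon as GPU 0 hands off micro-batch 1, it can start on micro-batch 2, so several GPUs work on different micro-batches at the same time. The bubble shrinks as the number of micro-batches grows, but it never disappears entirely.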
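To see how micro-batching shrinks the bubble, here is a toy model of a GPipe-style fill-and-drain forward schedule. It assumes every stage takes exactly one time step per micro-batch and ignores the backward pass. `gpipe_forward_schedule` and `bubble_fraction` are hypothetical helpers for illustration.

```python
def gpipe_forward_schedule(num_stages, num_micro_batches):
    """grid[t][s] = micro-batch processed by stage s at time step t, or None if idle."""
    total_steps = num_micro_batches + num_stages - 1
    grid = []
    for t in range(total_steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # stage s starts micro-batch mb once stage s-1 has finished it
            row.append(mb if 0 <= mb < num_micro_batches else None)
        grid.append(row)
    return grid


def bubble_fraction(num_stages, num_micro_batches):
    """Fraction of (time step, stage) slots that are idle."""
    grid = gpipe_forward_schedule(num_stages, num_micro_batches)
    slots = len(grid) * num_stages
    idle = sum(cell is None for row in grid for cell in row)
    return idle / slots


print(bubble_fraction(4, 1))            # 0.75: one micro-batch leaves 3 of 4 stages idle
print(round(bubble_fraction(4, 8), 3))  # 0.273
```

Under these assumptions, the idle fraction works out to $(p-1)/(m+p-1)$ for $p$ stages and $m$ micro-batches. Raising $m$ relative to $p$ is what keeps the pipeline busy.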
DeepSpeed provides a `PipelineModule` abstraction that partitions a list of layers into stages and schedules micro-batches through them. Here is how you define a simple two-stage pipeline:
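```python
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec

# PipelineModule needs an initialized process group; launch with, e.g.,
# `deepspeed --num_gpus 2 train.py` so that each GPU runs one stage.
deepspeed.init_distributed()

# Define your model as a flat list of layers. LayerSpec defers construction,
# so each rank only builds the layers assigned to its own stage.
layers = [
    LayerSpec(nn.Linear, 1024, 1024),
    LayerSpec(nn.ReLU),
    LayerSpec(nn.Linear, 1024, 1024),
]

# Partition the model across 2 GPUs, balancing parameter counts per stage
model = PipelineModule(
    layers=layers,
    num_stages=2,
    partition_method='parameters',
    loss_fn=nn.MSELoss(),  # applied on the last stage to compute the loss
)
```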
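DeepSpeed's pipeline engine uses `gradient_accumulation_steps` as the number of micro-batches per training step. It expects `train_batch_size` to equal `train_micro_batch_size_per_gpu * gradient_accumulation_steps *` the data-parallel degree. The sketch below shows a consistent config, assuming 2 GPUs that are both used as pipeline stages, so the data-parallel degree is 1. The values are illustrative, and `check_batch_config` is a hypothetical helper rather than part of DeepSpeed.

```python
# Illustrative DeepSpeed config (values are assumptions, not recommendations)
ds_config = {
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,  # = micro-batches per step in pipeline mode
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}


def check_batch_config(cfg, world_size, num_stages):
    """Hypothetical helper: verify the batch-size identity for a pipeline run."""
    dp_size = world_size // num_stages  # GPUs not used for stages replicate the pipeline
    expected = (
        cfg["train_micro_batch_size_per_gpu"]
        * cfg["gradient_accumulation_steps"]
        * dp_size
    )
    return expected == cfg["train_batch_size"]


print(check_batch_config(ds_config, world_size=2, num_stages=2))  # True
```

In recent DeepSpeed versions, a dict like this can be passed as the `config` argument of `deepspeed.initialize` together with the `PipelineModule`.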
Even with parallelism, storing activations for backpropagation consumes large amounts of VRAM. Activation Checkpointing (also called gradient checkpointing) addresses this by changing what the forward pass keeps. It saves only the inputs to selected segments of the network and discards the intermediate activations inside those segments. During the backward pass, it recomputes those activations on the fly.
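With checkpoints placed every $\sqrt{L}$ layers, this reduces activation memory from $O(L)$ to $O(\sqrt{L})$, where $L$ is the number of layers. The cost is roughly one extra forward pass of compute.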
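The $O(\sqrt{L})$ figure comes from a short derivation. Assume $L$ identical layers whose activations each cost one unit of memory, split into segments of equal length $k$. The forward pass keeps only the $L/k$ segment inputs, and the backward pass recomputes one segment at a time:

$$
\begin{aligned}
M(k) &= \underbrace{\frac{L}{k}}_{\text{stored segment inputs}} + \underbrace{k}_{\text{one recomputed segment}} \\
\frac{dM}{dk} &= -\frac{L}{k^{2}} + 1 = 0 \;\Rightarrow\; k^{*} = \sqrt{L} \\
M(k^{*}) &= 2\sqrt{L} = O(\sqrt{L})
\end{aligned}
$$

For the 48-layer example above, $k^{*} \approx 7$ gives about 14 units of activation memory instead of 48. Each layer's forward pass runs at most twice, which is where the "roughly one extra forward pass" cost comes from.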
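```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TransformerBlock(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        # Instead of self.layer(x), use checkpoint: the layer's internal
        # activations are freed after the forward pass and recomputed in backward.
        return checkpoint(self.layer, x, use_reentrant=False)
```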
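Checkpointing changes memory use, not results. A quick CPU sanity check is to run the same weights with and without `checkpoint` and compare gradients. The sketch below assumes a hypothetical stack of four small MLP sub-blocks standing in for transformer layers.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical toy stack: 4 MLP sub-blocks standing in for transformer layers
layers = nn.ModuleList([nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4)])
layers_ckpt = copy.deepcopy(layers)  # identical weights for a fair comparison


def run(stack, x, use_ckpt):
    for layer in stack:
        # With use_ckpt, each sub-block's internal activations are recomputed in backward
        x = checkpoint(layer, x, use_reentrant=False) if use_ckpt else layer(x)
    return x.sum()


x = torch.randn(8, 64)
loss_plain = run(layers, x, use_ckpt=False)
loss_plain.backward()
loss_ckpt = run(layers_ckpt, x, use_ckpt=True)
loss_ckpt.backward()

grads_match = all(
    torch.allclose(p.grad, q.grad, atol=1e-6)
    for p, q in zip(layers.parameters(), layers_ckpt.parameters())
)
print(grads_match)  # True
```

On a GPU, the memory side of the trade-off can be observed by comparing `torch.cuda.max_memory_allocated()` across the two runs.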
To put these techniques into practice:

1. Initialize a `torch.distributed` process group.
2. Use `torch.nn.Sequential` to wrap your layers.
3. Place the first two layers on `cuda:0` and the last two on `cuda:1`. Implement a forward pass that moves tensors between devices using `.to('cuda:x')`.
4. Wrap a block in `torch.utils.checkpoint` and monitor VRAM usage with `torch.cuda.memory_allocated()` before and after the change.
5. Ensure `micro_batch_size * num_micro_batches` matches your total batch size to keep the pipeline full.

By combining these techniques with the data-parallel strategies we discussed previously, you can train models far larger than the memory of any single GPU.
Up next: Efficient Dataset Loading and Prefetching — we'll ensure your data pipeline keeps up with your newly scaled model.