Learn how to implement ZeRO-3 optimization to shard optimizer states across nodes. Master distributed training memory efficiency for massive LLMs.
Previously in this course, in "Mixed Precision Training (FP8/BF16): A Practitioner's Guide", we used lower-precision formats to reduce memory bandwidth and storage requirements. In this lesson, we address the "memory wall" that remains even after precision tuning: the large memory footprint of optimizer states, gradients, and parameters when training large-scale models.
When training with standard Data Parallelism (such as PyTorch's DistributedDataParallel, discussed in "Data Parallelism Strategies: Scaling PyTorch with DDP"), every GPU keeps a full copy of the model parameters, the gradients, and, most critically, the optimizer states.
If you use Adam, the optimizer stores two additional values for every model weight: the first moment (momentum) and the second moment (variance). Kept in FP32 (4 bytes each), these states alone take 7B × 2 × 4 bytes = 56 GB of VRAM for a 7B-parameter model, before accounting for the parameters and gradients themselves. Under standard data parallelism, every GPU holds its own copy of those 56 GB. ZeRO (Zero Redundancy Optimizer) eliminates this redundancy by partitioning these states across the available GPUs.
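The 56 GB figure can be checked with a few lines of Python. This is a back-of-the-envelope sketch that assumes both Adam moments are stored in FP32 (4 bytes each) and reports decimal gigabytes (1 GB = 10^9 bytes); the function name is illustrative.

```python
def adam_moment_bytes(num_params: int, bytes_per_value: int = 4) -> int:
    """Bytes for Adam's two moment buffers, assuming one value of each per weight."""
    # First moment and second moment (PyTorch names them exp_avg and exp_avg_sq)
    moments_per_param = 2
    return num_params * moments_per_param * bytes_per_value


if __name__ == "__main__":
    total = adam_moment_bytes(7_000_000_000)  # FP32 moments for a 7B model
    print(f"{total / 1e9:.0f} GB")  # prints "56 GB"
```

Note that this counts only the moments; mixed-precision training usually also keeps an FP32 master copy of the weights, which adds another 4 bytes per parameter.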
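The ZeRO optimization suite is divided into three stages, each increasing the degree of memory reduction:

- Stage 1 (ZeRO-1): optimizer states are partitioned across GPUs.
- Stage 2 (ZeRO-2): gradients are partitioned as well.
- Stage 3 (ZeRO-3): model parameters are partitioned as well, so each GPU persistently stores only its own shard of the parameters, gradients, and optimizer states.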
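In ZeRO-3, the model's parameters are not replicated on every GPU. Instead, each parameter tensor is split into shards across the GPUs, and each GPU persistently stores only its own shard. During the forward and backward passes, the full parameters of a layer are reconstructed with an all-gather just before that layer runs and are freed immediately afterward. In the backward pass, gradients are reduce-scattered, so each GPU keeps only the gradient shard that matches its parameter shard.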
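To make the gather, compute, free cycle concrete, here is a single-process NumPy simulation. It is an illustrative sketch, not DeepSpeed's implementation: the helper names (`shard_parameter`, `all_gather`, `layer_forward`) are hypothetical, the "ranks" are just list entries, and no real communication happens. Padding the flattened tensor so it divides evenly by the world size is one common way to get equal-sized shards.

```python
import numpy as np


def shard_parameter(flat: np.ndarray, world_size: int) -> list:
    """Pad a flattened parameter so it divides evenly, then split it into one shard per rank."""
    pad = (-flat.size) % world_size
    padded = np.concatenate([flat, np.zeros(pad, dtype=flat.dtype)])
    return np.split(padded, world_size)


def all_gather(shards: list, numel: int, shape: tuple) -> np.ndarray:
    """Simulated all-gather: concatenate every rank's shard, drop padding, restore the shape."""
    return np.concatenate(shards)[:numel].reshape(shape)


def layer_forward(x: np.ndarray, weight_shards: list, numel: int, shape: tuple) -> np.ndarray:
    # Materialize the full weight only for the duration of this layer's computation
    full_weight = all_gather(weight_shards, numel, shape)
    y = x @ full_weight
    # Drop the gathered copy; in ZeRO-3 the gathered parameters are freed right after use
    del full_weight
    return y


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    world_size = 4
    weight = rng.standard_normal((5, 3))  # 15 elements, not divisible by 4
    shards = shard_parameter(weight.ravel(), world_size)
    print([s.size for s in shards])  # [4, 4, 4, 4] (one padding element in the last shard)
    x = rng.standard_normal((2, 5))
    y = layer_forward(x, shards, weight.size, weight.shape)
    print(np.allclose(y, x @ weight))  # True
```

Each simulated rank keeps only about numel / world_size elements between layers; the full weight exists only inside `layer_forward`.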
DeepSpeed, the Microsoft library in which ZeRO was introduced, is its reference implementation. To integrate it into your training loop, you pass your model (plus either an optimizer or an optimizer entry in the config) to deepspeed.initialize, which returns an engine that wraps both. Unlike manual DDP, DeepSpeed orchestrates the communication collectives required to fetch parameter shards.
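```python
import deepspeed
import torch

# DeepSpeed configuration for ZeRO stage 3
ds_config = {
    "train_batch_size": 32,
    # An optimizer must come from here or be passed to deepspeed.initialize()
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,          # overlap gradient reduction with backward computation
        "contiguous_gradients": True,  # copy gradients into a contiguous buffer to avoid fragmentation
        "sub_group_size": 1e9,         # tile size for parameter processing
        "reduce_bucket_size": 5e8,     # number of elements reduced at a time
    },
    "fp16": {"enabled": True},
}

deepspeed.init_distributed()

# Build the model under zero.Init so each rank materializes only its parameter shards
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = MyLargeTransformer()  # placeholder for your own model class

# Initialize the DeepSpeed engine (the optimizer is built from the config)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

criterion = torch.nn.CrossEntropyLoss()

# Training loop (data_loader is a placeholder for your own DataLoader)
for inputs, labels in data_loader:
    inputs = inputs.to(model_engine.local_rank)
    labels = labels.to(model_engine.local_rank)
    outputs = model_engine(inputs)
    loss = criterion(outputs, labels)
    # DeepSpeed handles loss scaling and the sharded gradient reduction
    model_engine.backward(loss)
    model_engine.step()
```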
To understand the impact, let $P$ be the number of parameters and $N$ be the number of GPUs.
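Under the accounting used in the ZeRO paper for mixed-precision Adam training (2 bytes per FP16 parameter, 2 bytes per FP16 gradient, and $K = 12$ bytes of FP32 optimizer state per parameter for master weights, momentum, and variance), the per-GPU memory for model states, in bytes, is approximately as follows. Activations and temporary communication buffers are not included.

$$
\begin{aligned}
M_{\text{DDP}} &= (2 + 2 + K)\,P = 16P \\
M_{\text{ZeRO-1}} &= 2P + 2P + \frac{K P}{N} \\
M_{\text{ZeRO-2}} &= 2P + \frac{(2 + K)\,P}{N} \\
M_{\text{ZeRO-3}} &= \frac{(2 + 2 + K)\,P}{N} = \frac{16P}{N}
\end{aligned}
$$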
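By distributing the states, the per-GPU memory for model states scales inversely with the number of devices: ZeRO-3 divides the full footprint by $N$. This is the primary reason models with hundreds of billions of parameters can be trained on clusters of GPUs with limited individual memory: if you grow the number of GPUs in proportion to the model size, the per-GPU memory for model states stays roughly constant. This applies to model states only; activation memory depends on the per-GPU batch and is not reduced by ZeRO-3 sharding.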
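A small calculator makes the inverse scaling concrete. It is a sketch that assumes the accounting used in the ZeRO paper for mixed-precision Adam (2 bytes per FP16 parameter, 2 per FP16 gradient, K = 12 bytes of FP32 optimizer state per parameter) and ignores activations and communication buffers. The function name and the use of stage 0 for plain data parallelism are illustrative choices.

```python
def model_state_bytes_per_gpu(num_params: int, num_gpus: int, stage: int, k: int = 12) -> float:
    """Approximate per-GPU bytes for parameters, gradients, and optimizer state."""
    p, n = num_params, num_gpus
    if stage == 0:  # plain data parallelism: everything replicated
        return (2 + 2 + k) * p
    if stage == 1:  # optimizer states sharded
        return 2 * p + 2 * p + k * p / n
    if stage == 2:  # optimizer states and gradients sharded
        return 2 * p + (2 + k) * p / n
    if stage == 3:  # parameters sharded as well
        return (2 + 2 + k) * p / n
    raise ValueError("stage must be 0, 1, 2 or 3")


if __name__ == "__main__":
    for gpus in (8, 64):
        gb = model_state_bytes_per_gpu(7_000_000_000, gpus, stage=3) / 1e9
        print(f"7B model, ZeRO-3 on {gpus} GPUs: {gb:.2f} GB per GPU")
```

For a 7B model this gives 112 GB per GPU under plain data parallelism, versus 14 GB on 8 GPUs and 1.75 GB on 64 GPUs with ZeRO-3; a 70B model on 80 GPUs lands back at 14 GB per GPU.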
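To try this yourself:

1. Make sure you have deepspeed installed and a multi-GPU environment available.
2. Create a ds_config.json file and set "stage" to 3 under "zero_optimization".
3. Call torch.cuda.memory_summary() before and after initializing the model with deepspeed.initialize.
4. Compare a standard nn.DataParallel setup with the DeepSpeed ZeRO-3 setup for a model with at least 1B parameters. You should observe a significant drop in reserved memory on individual cards.

Two practical tips:

- If you still run out of memory, enable optimizer offloading in the DeepSpeed config ("offload_optimizer" with "device": "cpu" in current versions; older releases used a cpu_offload flag). This moves optimizer states to system RAM, trading speed for much larger capacity.
- Save checkpoints with DeepSpeed's save_checkpoint method (and restore them with load_checkpoint), because the model is physically split across GPUs. Loading a standard PyTorch state dict into a ZeRO-3 engine will fail.

ZeRO-3 is the cornerstone of modern distributed training. By sharding optimizer states, gradients, and parameters, it shifts the bottleneck from per-GPU memory capacity to communication and compute: the extra parameter all-gathers raise communication volume to about 1.5× that of standard data parallelism, according to the ZeRO paper. When combined with Tensor and Pipeline Parallelism (see "Tensor and Pipeline Parallelism: Scaling Large Model Training"), you can in principle scale to any model size, provided the aggregate GPU memory in your cluster is large enough.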
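If you prefer to generate ds_config.json from Python rather than writing it by hand, a helper like the one below works. It is a minimal sketch: the function name is hypothetical, the batch size and learning rate are placeholder values, and the optional block uses DeepSpeed's offload_optimizer setting to keep optimizer states in pinned host RAM.

```python
import json
from pathlib import Path


def write_zero3_config(path: str, offload_optimizer: bool = False) -> dict:
    """Write a minimal ZeRO-3 DeepSpeed config to `path` and return it."""
    config = {
        "train_batch_size": 32,  # placeholder; must match micro-batch x accumulation x world size
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},  # placeholder learning rate
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 3, "overlap_comm": True},
    }
    if offload_optimizer:
        # Keep optimizer states in pinned host memory instead of GPU memory
        config["zero_optimization"]["offload_optimizer"] = {"device": "cpu", "pin_memory": True}
    Path(path).write_text(json.dumps(config, indent=2))
    return config


if __name__ == "__main__":
    cfg = write_zero3_config("ds_config.json", offload_optimizer=True)
    print(cfg["zero_optimization"]["stage"])  # 3
```

The resulting file can then be passed to DeepSpeed, for example with deepspeed.initialize(config="ds_config.json", ...).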
Up next: Gradient Accumulation and Batch Sizing, where we will discuss how to simulate large batch sizes when memory constraints still limit your per-device batch size.