Master Mixed Precision training with BF16 and FP8. Learn how to implement loss scaling, ensure numerical stability, and accelerate deep learning workloads.
Previously in this course, we explored Advanced Activation Checkpointing: Memory Optimization for Deep Learning to trade compute for memory. While checkpointing reduces the memory footprint, it does not accelerate the raw arithmetic throughput of your model. This lesson adds Mixed Precision: the technique of running most of the computation in lower-precision formats like BF16 and FP8. This shrinks the memory taken by activations and significantly increases achievable FLOPS on modern GPU architectures with low-precision Tensor Core support.
Standard deep learning models traditionally rely on FP32 (32-bit floating point). While numerically robust, FP32 is computationally expensive and memory-heavy. Mixed Precision training (introduced in the context of Tensor and Pipeline Parallelism: Scaling Large Model Training) involves using lower-precision formats for the bulk of the computation—specifically matrix multiplications—while maintaining master copies of weights in FP32 to preserve optimization stability.
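The sketch below shows why the master copy of the weights stays in FP32. It emulates BF16 in pure Python by rounding the FP32 bit pattern (via the standard `struct` module) to its top 16 bits. It then applies many small updates to a weight stored only in BF16 and to an FP32 master copy. The helpers `to_bf16` and `to_fp32` are hypothetical illustrations, not a library API, and NaN inputs are not handled.

```python
import struct

def to_fp32(x):
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round to the nearest BF16 value (ties to even); NaN is not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # BF16 is the top 16 bits of FP32, so round away the low 16 bits
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

update = 1e-4            # a small per-step weight update (lr * grad)
w_bf16 = to_bf16(1.0)    # weight stored only in BF16
w_master = to_fp32(1.0)  # FP32 master copy of the same weight

for _ in range(1000):
    # BF16 spacing near 1.0 is 2**-7, so a 1e-4 update rounds straight back to 1.0
    w_bf16 = to_bf16(w_bf16 + update)
    # The FP32 master copy (spacing 2**-23 near 1.0) accumulates every update
    w_master = to_fp32(w_master + update)

print(w_bf16)    # 1.0
print(w_master)  # close to 1.1
```

The BF16-only weight never moves, while the FP32 copy ends up close to the expected 1.1. This is why optimizers apply updates to FP32 master weights even when the forward and backward passes run in BF16.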
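| Format | Bits (sign/exponent/mantissa) | Range | Precision | Primary Use Case |
|---|---|---|---|---|
| FP32 | 32 (1/8/23) | Wide (max ≈ 3.4 × 10^38) | High | Master weights, accumulation |
| BF16 | 16 (1/8/7) | Wide (same exponent range as FP32) | Low | Mixed-precision compute with FP32-like stability |
| FP8 | 8 (E4M3: 1/4/3, E5M2: 1/5/2) | Narrow (max 448 for E4M3, 57,344 for E5M2) | Very Low | Maximum throughput (H100+) |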
The "Brain Floating Point" (BF16) format shares the same exponent range as FP32, making it highly resistant to the overflow/underflow issues common in standard FP16. FP8, conversely, is an aggressive format supported by newer architectures (like NVIDIA Hopper) that requires careful management of dynamic ranges via scaling factors.
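The range/precision trade-off, and the reason FP8 needs scaling factors, can be checked numerically. The pure-Python sketch below uses three emulations, and all helper names are hypothetical:

- FP16 via `struct`'s half-precision `"e"` format.
- BF16 by rounding the FP32 bit pattern.
- FP8 E4M3 with a simplified rounding helper that saturates at 448 and ignores NaN.

The per-tensor "amax" scaling shown is one common recipe. It is not necessarily what a given FP8 library does internally.

```python
import math
import struct

def to_fp16(x):
    # struct's "e" format is IEEE half precision (raises OverflowError beyond about 65504)
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    # BF16 keeps the top 16 bits of FP32; round the low 16 bits to nearest-even (NaN not handled)
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

FP8_E4M3_MAX = 448.0

def to_fp8_e4m3(x):
    # Simplified E4M3 rounding: 3 mantissa bits, subnormals below 2**-6, saturate at 448
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), FP8_E4M3_MAX)
    _, e = math.frexp(a)      # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)      # exponent of the leading bit, clamped at the subnormal range
    step = 2.0 ** (exp - 3)   # spacing between neighbouring E4M3 values
    return sign * min(round(a / step) * step, FP8_E4M3_MAX)

# Range: BF16 keeps a tiny gradient that FP16 flushes to zero
print(to_fp16(1e-8))  # 0.0
print(to_bf16(1e-8))  # roughly 1e-8

# Precision: FP16 has 10 mantissa bits, BF16 only 7
print(to_fp16(1 + 2**-9))  # 1.001953125
print(to_bf16(1 + 2**-9))  # 1.0

# FP8 without scaling: small activations underflow
values = [1e-4, 5e-4, 2e-3]
print([to_fp8_e4m3(v) for v in values])  # [0.0, 0.0, 0.001953125]

# Per-tensor scaling: map the largest magnitude (amax) onto the FP8 maximum
amax = max(abs(v) for v in values)
scale = FP8_E4M3_MAX / amax
quantized = [to_fp8_e4m3(v * scale) for v in values]
dequantized = [q / scale for q in quantized]  # all three values now survive
```

The scaling factor is kept in higher precision alongside the FP8 tensor and divided back out after the low-precision operation. This is the "careful management of dynamic ranges" that FP8 requires.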
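In PyTorch, we use the torch.amp (Automatic Mixed Precision) package. Its autocast context manager chooses the precision per operation: matrix multiplications and convolutions run in the lower-precision dtype, while precision-sensitive operations such as softmax and most reductions stay in FP32. Its GradScaler class handles loss scaling.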
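A quick way to observe autocast's per-op behavior without a GPU is to run a layer on the CPU, where torch.autocast also supports BF16. The sketch below assumes a recent PyTorch release, and the layer sizes are arbitrary. The linear layer's output comes back in BF16, while its parameters, and the gradients accumulated into them, stay in FP32. That is the "low-precision matmuls, FP32 master weights" split described earlier.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(16, 8)  # parameters are created in FP32
x = torch.randn(4, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)                     # linear is on autocast's lower-precision list
    loss = y.float().pow(2).mean()   # explicitly reduce in FP32

loss.backward()

print(y.dtype)                  # torch.bfloat16
print(layer.weight.dtype)       # torch.float32: autocast casts a copy, not the parameter
print(layer.weight.grad.dtype)  # torch.float32: the gradient flows back through the cast
```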
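Here is a training step that uses torch.amp.autocast (the scaler calls only do real work when training in FP16):

```python
import torch

# Assume model, optimizer, and criterion (the loss function) are already defined.
# Prefer BF16 where the GPU supports it; fall back to FP16 otherwise.
amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# GradScaler is required for FP16 and usually unnecessary for BF16, so enable it only for FP16.
# (torch.amp.GradScaler needs PyTorch 2.3+; older releases use torch.cuda.amp.GradScaler.)
scaler = torch.amp.GradScaler("cuda", enabled=(amp_dtype == torch.float16))

def train_step(model, data, target, optimizer):
    optimizer.zero_grad()

    # autocast runs eligible ops (e.g. matmuls) in amp_dtype; the weights stay in FP32
    with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
        output = model(data)
        loss = criterion(output, target)

    # Scale the loss to prevent gradient underflow (a pass-through when the scaler is disabled)
    scaler.scale(loss).backward()
    # Unscale the gradients and call optimizer.step(), skipping it if any are Inf/NaN
    scaler.step(optimizer)
    # Adjust the scale factor for the next iteration
    scaler.update()
    return loss.detach()
```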
While BF16 is generally stable, aggressive FP8 or legacy FP16 training can go wrong in two ways. Small gradients can fall below the smallest representable value and underflow to zero, silently stalling learning. Large values can overflow to Inf, which then propagates as NaN through the loss and gradients.
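Loss Scaling counters underflow by multiplying the loss by a large factor (the scale) before computing gradients. By the chain rule, every gradient is multiplied by the same factor, which shifts small values into the representable range. After the backward pass, we unscale the gradients (divide them by the scale) before the optimizer update. Dynamic loss scaling, as done by GradScaler, additionally lowers the scale and skips the step whenever the scaled gradients overflow.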
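The sketch below walks through loss scaling in pure Python, again emulating FP16 with `struct`'s `"e"` format. It first shows a 1e-8 gradient vanishing in FP16 and surviving once scaled. It then sketches a dynamic policy: skip the step and shrink the scale when scaled gradients overflow, and grow the scale after a run of clean steps. The `DynamicLossScaler` class is a hypothetical illustration, not PyTorch's implementation. Its default arguments mirror GradScaler's documented defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000).

```python
import math
import struct

def to_fp16(x):
    """Round to FP16; out-of-range values become +/-inf, as they would on a GPU."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

# 1) Static view: a tiny gradient vanishes in FP16 unless the loss is scaled first
grad = 1e-8
print(to_fp16(grad))            # 0.0
scale = 2.0 ** 16
scaled = to_fp16(grad * scale)  # about 6.55e-4, well inside FP16's range
recovered = scaled / scale      # unscale in higher precision before the optimizer step

# 2) Dynamic view: adjust the scale as training progresses
class DynamicLossScaler:
    """Hypothetical sketch of a GradScaler-style policy (not PyTorch's code)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def unscale(self, scaled_grads):
        """Return unscaled grads, or None if any scaled grad overflowed."""
        if not all(math.isfinite(g) for g in scaled_grads):
            return None
        return [g / self.scale for g in scaled_grads]

    def update(self, found_inf):
        if found_inf:
            # Overflow: the step was skipped; back off the scale and restart the count
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                # Enough clean steps in a row: try a larger scale
                self.scale *= self.growth_factor
                self._good_steps = 0

scaler = DynamicLossScaler(growth_interval=3)
steps_taken = 0
for raw_grads in ([1e-8, 2e-8], [1.0, 5.0], [1e-6, 3e-7]):
    scaled_grads = [to_fp16(g * scaler.scale) for g in raw_grads]
    grads = scaler.unscale(scaled_grads)
    if grads is not None:
        steps_taken += 1  # an optimizer.step() using the unscaled grads would go here
    scaler.update(found_inf=grads is None)

print(steps_taken, scaler.scale)  # 2 32768.0
```

The second batch overflows at a scale of 2**16, so its step is skipped and the scale halves. The first and third batches update normally.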
If you see NaN or Inf values early in training, your loss scale is likely too high (causing overflow). If most gradients come out as zero, it may be too low (causing underflow).

Not all GPUs support BF16 natively; use torch.cuda.is_bf16_supported() to gate your implementation.

Exercise: Modify your current training loop in the course project to incorporate torch.amp.

1. Wrap the forward pass and loss computation in torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16).
2. Add a GradScaler to your training loop (required if you switch to FP16, optional with BF16).
3. Compare memory usage before and after the change (using nvidia-smi or torch.cuda.memory_allocated()).

Common pitfalls:

- Manual casting: prefer autocast context managers. Manually casting layers can lead to subtle bugs where operations that require FP32 precision (like LayerNorm or Softmax) are cast down, causing numerical collapse.
- Skipping the overflow check: the optimizer step must be skipped whenever the gradients contain NaN or Inf values. GradScaler does this automatically; avoid writing your own unless you have specific, non-standard requirements.

Mixed Precision training is a standard requirement for efficient LLM training. By using BF16 for activations and FP32 for master weights, you balance throughput and stability. As you advance, remember that hardware-specific formats like FP8 provide the next frontier for speed, provided your scaling logic (loss scaling, plus FP8's per-tensor scaling factors) remains robust.
Up next: Distributed Optimizer States, where we will look at how to shard optimizer states across multiple GPUs to train models that exceed the memory capacity of a single device.
Mixed Precision Training (FP8/BF16)