Master advanced Weight Initialization in PyTorch. Learn to control gradient flow and stabilize deep network training using custom variance-scaling techniques.
Previously in this course, we explored the lifecycle of Project Initialization: Defining the Machine Learning Prediction Problem. While that lesson focused on the business and data logic of starting a model, today we move into the "plumbing" of deep learning. Specifically, we will look at how the initial values of your weights dictate whether your network learns at all or immediately collapses into vanishing or exploding gradients.
If you initialize weights from a standard normal distribution $\mathcal{N}(0, 1)$, the variance of the activations is multiplied by roughly the layer's fan-in at every layer, so it grows exponentially with depth. Weights that are too small shrink it just as quickly. In a deep network, this leads to the "vanishing gradient" or "exploding gradient" problem.
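A quick way to see this growth is to push random data through a stack of bias-free linear maps. The sketch below is illustrative only: the helper name `layer_stds`, the width of 256, and the depth of 10 are assumptions chosen for the demo, not values from this lesson.

```python
import torch

torch.manual_seed(0)


def layer_stds(weight_std_fn, width=256, depth=10, batch=1024):
    """Push unit-variance inputs through `depth` bias-free linear maps and
    record the standard deviation of the activations after each one."""
    x = torch.randn(batch, width)
    stds = []
    for _ in range(depth):
        # Fresh random weight matrix with the requested standard deviation
        w = torch.randn(width, width) * weight_std_fn(width)
        x = x @ w
        stds.append(x.std().item())
    return stds


naive = layer_stds(lambda n: 1.0)          # W ~ N(0, 1)
scaled = layer_stds(lambda n: n ** -0.5)   # W ~ N(0, 1 / n_in)

for i, (a, b) in enumerate(zip(naive, scaled), start=1):
    print(f"layer {i:2d}  N(0,1): {a:10.3e}   N(0,1/n_in): {b:6.3f}")
```

With unit-variance weights the standard deviation is multiplied by about $\sqrt{n_{in}}$ per layer, while the $1 / n_{in}$ variance keeps it close to its starting value.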
To maintain gradient flow, we want the variance of the activations and the variance of the gradients to remain consistent across layers. This is the core principle behind modern Weight Initialization strategies.
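For a linear layer $y = Wx$ with $n_{in}$ inputs, zero-mean weights, and zero-mean inputs that are independent of the weights, the variance of each output is $n_{in} \cdot Var(w) \cdot Var(x)$. We want the variance of the output to equal the variance of the input: $$Var(y) = n_{in} \cdot Var(w) \cdot Var(x) = Var(x)$$ This implies $Var(w) = 1 / n_{in}$. This is the intuition behind Xavier (Glorot) initialization, which balances this forward-pass condition against the matching backward-pass condition $Var(w) = 1 / n_{out}$ by using $Var(w) = 2 / (n_{in} + n_{out})$. However, the Xavier derivation assumes activations that behave approximately linearly around zero. ReLU zeros out the negative half of its input, which halves the second moment of the signal passed to the next layer. This is why we need Kaiming (He) initialization, which compensates with $Var(w) = 2 / n_{in}$ for ReLU and, more generally, scales the weights by the gain of the activation function.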
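The factor of two for ReLU follows in a few lines. This is a sketch under stated assumptions: the weights are zero-mean and independent of the inputs, and the previous layer's pre-activation $z$ is symmetric around zero. The symbols $z$ and $a$ are introduced here for illustration.

$$
\begin{aligned}
Var(y) &= n_{in} \cdot Var(w) \cdot E[a^2], \qquad a = \max(0, z) \\
E[a^2] &= \tfrac{1}{2} E[z^2] = \tfrac{1}{2} Var(z) \\
Var(y) &= \tfrac{n_{in}}{2} \cdot Var(w) \cdot Var(z) = Var(z) \;\Longrightarrow\; Var(w) = \frac{2}{n_{in}}
\end{aligned}
$$

The second line is where the symmetry assumption enters: ReLU keeps only the positive half of $z$, so exactly half of the second moment survives.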
PyTorch provides torch.nn.init, but in production, you often need to implement custom gain factors for non-standard activation functions (like SwiGLU or custom Gated Linear Units).
Here is a concrete example of a custom Kaiming-style initializer that allows you to pass a specific gain factor based on your architecture's activation function:
```python
import torch
import torch.nn as nn


def custom_kaiming_init(module, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """
    Custom initialization applying variance scaling with a specific gain.
    'a' is the negative slope of the rectifier used (only relevant for 'leaky_relu').
    """
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        # kaiming_normal_ derives the gain from `nonlinearity` and `a` internally
        # (the same value nn.init.calculate_gain(nonlinearity, a) returns) and
        # draws weights with std = gain / sqrt(fan).
        # mode='fan_in':  scale by the number of input units
        # mode='fan_out': scale by the number of output units
        nn.init.kaiming_normal_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)


# Applying to our running project's model
model = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 128),
)
model.apply(lambda m: custom_kaiming_init(m, nonlinearity='relu'))
```
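The initializer above selects its gain through a named nonlinearity, which only works for the activations that `nn.init.calculate_gain` recognizes. For an activation outside that set, one option is to pass the gain as a number. The following is a minimal sketch of that idea; the helper `init_with_gain` and the gain value of 1.2 are hypothetical and are not part of PyTorch or of this lesson's project.

```python
import math
import torch.nn as nn


def init_with_gain(module, gain=1.0):
    """Hypothetical helper: draw weights from N(0, (gain / sqrt(fan_in))^2)."""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        w = module.weight
        # fan_in = input features/channels times the receptive-field size
        # (the receptive field is 1 for nn.Linear).
        fan_in = w.size(1) * w[0][0].numel()
        std = gain / math.sqrt(fan_in)
        nn.init.normal_(w, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Illustrative gain only; derive or measure the right value for your activation.
example_gain = 1.2
layer = nn.Linear(1024, 512)
init_with_gain(layer, gain=example_gain)

expected_std = example_gain / math.sqrt(1024)
print(f"expected std {expected_std:.4f}, observed {layer.weight.std().item():.4f}")
```

Computing the standard deviation by hand keeps the scaling rule visible, at the cost of no longer supporting the `mode='fan_out'` option that `kaiming_normal_` offers.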
The goal of these strategies is to keep the signal variance stable. If your initialization is too "small," the signal dies within 5-10 layers. If it's too "large," the gradients explode.
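To make "too small" and "too large" concrete, the sketch below rescales Kaiming-initialized weights and measures the output of a 10-layer ReLU stack. The function name `final_activation_std`, the scale factors, and the layer sizes are assumptions picked for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def final_activation_std(scale, width=512, depth=10, batch=256):
    """Std of the final output of a ReLU MLP whose Kaiming-initialized
    weights are multiplied by `scale` (1.0 means unmodified Kaiming)."""
    layers = []
    for _ in range(depth):
        linear = nn.Linear(width, width, bias=False)
        nn.init.kaiming_normal_(linear.weight, nonlinearity='relu')
        with torch.no_grad():
            linear.weight.mul_(scale)  # deliberately mis-scale the weights
        layers += [linear, nn.ReLU()]
    net = nn.Sequential(*layers)
    with torch.no_grad():
        return net(torch.randn(batch, width)).std().item()


for scale in (0.5, 1.0, 2.0):
    print(f"scale {scale}: output std = {final_activation_std(scale):.3e}")
```

Halving or doubling the weight scale compounds once per layer, so after ten layers the outputs differ by several orders of magnitude while the unscaled version stays near its input scale.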
Pro-tip: In production, use a forward hook to monitor activation statistics during your first few training steps. If the mean of your pre-activations (the outputs of your linear or convolutional layers) shifts significantly away from 0, or their variance collapses or blows up from layer to layer, your initialization strategy is mismatched with your architecture.
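One possible way to set up such monitoring uses `register_forward_hook` on each linear layer. This is a sketch, not a production logger: the `stats` dictionary, the `make_hook` helper, and the random batches standing in for real training data are all assumptions made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 128),
)

stats = {}  # layer name -> list of (mean, std), one entry per forward pass


def make_hook(name):
    def hook(module, inputs, output):
        # Detach so that logging never touches the autograd graph
        out = output.detach()
        stats.setdefault(name, []).append((out.mean().item(), out.std().item()))
    return hook


handles = [
    layer.register_forward_hook(make_hook(name))
    for name, layer in model.named_modules()
    if isinstance(layer, nn.Linear)
]

# Stand-in for the first few training steps: forward passes on random batches.
for _ in range(3):
    model(torch.randn(64, 1024))

for name, records in stats.items():
    for step, (mean, std) in enumerate(records):
        print(f"layer {name} step {step}: mean={mean:+.3f} std={std:.3f}")

# Remove the hooks once the check is done so they add no overhead later.
for handle in handles:
    handle.remove()
```

Keeping the returned handles makes it easy to switch the monitoring off after the warm-up steps.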
Initialize the weights with torch.randn (standard normal) for one run and your custom_kaiming_init for the second.
Watch the activation statistics of the randn version drift away from a stable scale, while the Kaiming version remains stable.
If you use nn.init.calculate_gain('relu') but your layer uses tanh, you are scaling your weights by $\sqrt{2} \approx 1.41$ instead of the $5/3 \approx 1.67$ that PyTorch recommends for tanh. A mismatched gain like this can lead to training instability in the first few epochs.
Proper Weight Initialization can be the difference between a model that converges in hours and one that never learns. By scaling variance using the fan_in or fan_out of your layers and applying the correct activation gain, you preserve the signal through the deepest parts of your network.
Up next: Normalization Techniques at Scale, where we move from static initialization to dynamic activation control using RMSNorm and LayerNorm.