Master Knowledge Distillation to transfer intelligence from massive teacher models to efficient student models, optimizing your AI systems for production.
Previously in this course, we explored Model Pruning Techniques: Reducing Size and Increasing Latency to remove redundant parameters from our neural networks. While pruning focuses on zeroing out existing weights, Distillation takes a different approach: it trains a smaller, "student" architecture to mimic the behavioral output of a pre-trained, high-capacity "teacher" model.
Distillation is arguably the most powerful tool for Model Compression when you need to maintain high accuracy on constrained hardware. By forcing a student model to learn not just the ground-truth labels, but also the "dark knowledge" contained in the teacher's soft probability distributions, the student typically reaches higher accuracy than the same architecture trained from scratch on hard labels alone.
In standard supervised learning, a model learns to map inputs to hard labels (e.g., one-hot vectors). However, the teacher model’s output logits contain valuable information about the relationships between classes. For example, in a classification task, a teacher might indicate that an image of a "dog" is 90% "dog," 9% "cat," and 1% "car." The 9% "cat" signal tells the student that the features of a dog are somewhat similar to a cat, but very different from a car.
Distillation captures this via a modified loss function:
$$L_{total} = \alpha L_{distill} + (1 - \alpha) L_{student}$$
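Here $L_{distill}$ is the KL divergence between the teacher's and the student's softened output distributions, $L_{student}$ is the standard cross-entropy against the ground-truth labels, and $\alpha \in [0, 1]$ balances the two terms. Both softened distributions are computed by dividing the logits by a temperature $T$ before the softmax. With $T > 1$, we flatten the probability distribution, exposing more of the "dark knowledge" that would otherwise be hidden in the low-probability tails.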
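To make the flattening effect concrete, the sketch below applies a temperature-scaled softmax, $p_i = \exp(z_i / T) / \sum_j \exp(z_j / T)$, to a set of logits. The helper name `softmax_with_temperature` and the logit values are assumptions chosen for illustration; the logits are picked so that at $T = 1$ they roughly reproduce the 90% / 9% / 1% "dog" / "cat" / "car" example above. It uses only the standard library so it can be run without PyTorch.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for the classes ("dog", "cat", "car")
teacher_logits = [4.5, 2.2, 0.0]

for T in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}:", [round(p, 3) for p in probs])
```

As $T$ grows, probability mass moves from "dog" toward "cat" and "car" while the ranking of the classes stays the same, which is the signal the student is meant to pick up.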
To implement this, we need a custom loss function that handles the soft targets from the teacher. We apply the same temperature to both the teacher's and the student's logits to soften them before calculating the KL divergence.
```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """
    Implements the distillation loss combining KL Divergence and Cross Entropy.
    """
    # Distillation loss: KL divergence between the temperature-softened distributions.
    # F.kl_div expects log-probabilities as its input (the student)
    # and probabilities as its target (the teacher).
    soft_student = F.log_softmax(student_logits / T, dim=1)
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T**2)

    # Standard cross-entropy loss against the hard labels
    student_loss = F.cross_entropy(student_logits, labels)

    return alpha * distill_loss + (1 - alpha) * student_loss
```
Note the multiplication by $T^2$. When we divide logits by $T$, the gradients produced by the soft targets scale by $1/T^2$. Multiplying by $T^2$ ensures that the relative contribution of the distillation loss remains consistent when we change the temperature.
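A brief derivation sketch shows where the $1/T^2$ factor comes from. The symbols are introduced here for illustration only: $z_i$ and $v_i$ are the student and teacher logits for class $i$, $q_i$ and $p_i$ are the corresponding softmax outputs at temperature $T$, and $N$ is the number of classes. The gradient of the distillation term with respect to a student logit is:

$$\frac{\partial L_{distill}}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right)$$

When $T$ is large compared with the logit magnitudes, and assuming the logits of each example are zero-mean, the approximation $e^{x} \approx 1 + x$ gives $q_i \approx \frac{1}{N}\left(1 + z_i / T\right)$ and $p_i \approx \frac{1}{N}\left(1 + v_i / T\right)$, so:

$$\frac{\partial L_{distill}}{\partial z_i} \approx \frac{1}{N T^2}\left(z_i - v_i\right)$$

In this regime the soft-target gradient shrinks roughly as $1/T^2$, which is the factor the $T^2$ multiplier compensates for.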
The training loop for distillation is similar to standard training, but you must keep the teacher model in eval() mode and ensure you don't compute gradients for its parameters.
```python
# Assuming teacher and student are pre-defined models
teacher.eval()
student.train()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)

for inputs, labels in dataloader:
    optimizer.zero_grad()

    # The teacher only provides targets, so no gradients are tracked for it
    with torch.no_grad():
        teacher_logits = teacher(inputs)

    student_logits = student(inputs)
    loss = distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.7)

    loss.backward()
    optimizer.step()
```
As an exercise, extend the distillation_loss function above to implement a dynamic temperature schedule. Start with a high $T$ (e.g., 5.0) and decay it to 1.0 over the course of training.

Knowledge Distillation is a powerful technique for Efficiency in production ML. By leveraging the teacher's soft probability distributions, we provide the student with richer information than hard labels alone. Balancing the distillation loss against the standard cross-entropy loss allows us to deploy high-performing models on hardware that would otherwise be unable to run the original teacher architecture.
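As a possible starting point for the temperature-schedule exercise above, here is a minimal sketch of a linear decay. The function name `temperature_at` and its parameters are hypothetical, and a linear schedule is only one option; an exponential or cosine decay would fit the exercise equally well.

```python
def temperature_at(step, total_steps, t_start=5.0, t_end=1.0):
    """Linearly decay the temperature from t_start (first step) to t_end (last step)."""
    if total_steps <= 1:
        return t_end
    # Fraction of training completed, clamped to [0, 1]
    frac = min(max(step / (total_steps - 1), 0.0), 1.0)
    return t_start + frac * (t_end - t_start)

# Example: inspect the schedule over a hypothetical 5-step run
for step in range(5):
    print(step, temperature_at(step, total_steps=5))
```

Inside the training loop, the value could be passed as `T=temperature_at(step, total_steps)` when calling distillation_loss, with `step` counting optimizer updates.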
Up next: We will discuss how to deploy these models using optimized inference runtimes like vLLM, which further enhances the serving throughput of our distilled student models.