Learn how to implement RLHF to align LLMs. We cover training reward models from first principles and the mechanics of PPO policy optimization.
Previously in this course, we explored parameter-efficient fine-tuning with LoRA and Quantized LoRA (QLoRA) to adapt models to specific domains. While those methods excel at teaching models what to know, they don't inherently teach them how to behave according to human intent. That is where RLHF (Reinforcement Learning from Human Feedback) comes in.
In this lesson, we move beyond static supervised fine-tuning (SFT) to dynamic alignment. You’ll learn how to treat the model as an agent, define a reward signal, and optimize that agent using the Proximal Policy Optimization (PPO) algorithm.
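Alignment is essentially a three-step process: supervised fine-tuning (SFT) of a base model, training a reward model on human preference data, and optimizing the SFT policy against that reward model with PPO.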
A reward model is typically a transformer-based regressor. Instead of predicting the next token, it takes a prompt and a response, then outputs a scalar score. We train it on a dataset of preferences—pairs of responses (chosen vs. rejected) for the same prompt—using a pairwise ranking loss:
$$Loss = -\log(\sigma(r_{chosen} - r_{rejected}))$$
This forces the model to assign a higher scalar value to the response preferred by humans.
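As a minimal sketch of the loss above, the function below evaluates it on plain Python floats so the arithmetic is visible. The name `pairwise_ranking_loss` and the example scores are illustrative assumptions, not part of any library; in a real training loop the scores would be tensors produced by the reward model.

```python
import math


def pairwise_ranking_loss(r_chosen, r_rejected):
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over a batch of score pairs."""
    losses = []
    for chosen, rejected in zip(r_chosen, r_rejected):
        margin = chosen - rejected
        # -log(sigmoid(m)) equals log(1 + exp(-m)); this form avoids overflow
        losses.append(max(-margin, 0.0) + math.log1p(math.exp(-abs(margin))))
    return sum(losses) / len(losses)


# Hypothetical reward-model scores for three preference pairs
correct_order = pairwise_ranking_loss([2.0, 1.5, 0.3], [0.5, -1.0, 0.1])
wrong_order = pairwise_ranking_loss([0.5, -1.0, 0.1], [2.0, 1.5, 0.3])
print(correct_order < wrong_order)  # True
```

When the two scores are equal the loss is log 2, and it approaches zero as the chosen response's score pulls ahead of the rejected one.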
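Once the reward model is frozen, we use PPO to update our policy (the LLM). PPO is an "actor-critic" method. In our context, the actor is the LLM policy that generates responses, and the critic is a value function (often a value head on top of the same network) that estimates the reward a response is expected to earn. The gap between the actual and expected reward, the advantage, tells the actor which responses to make more or less likely.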
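PPO is preferred over vanilla policy gradients because it uses a "clipped" objective function. The probability ratio between the updated policy and the policy that generated the samples is clipped to the range [1 - ε, 1 + ε], so a single update gains nothing from pushing the policy far outside that range. This prevents massive, destructive updates during training.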
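To make the effect of clipping concrete, here is a small scalar sketch of the clipped surrogate for a single sample. The function name `clipped_surrogate` and the numbers are illustrative assumptions; `ratio` stands for the new policy's probability of a response divided by the old policy's, and `eps` is the clipping range.

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """PPO clipped surrogate for one sample (higher is better for the policy)."""
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    # Take the more pessimistic of the unclipped and clipped estimates
    return min(ratio * advantage, clipped_ratio * advantage)


# Good response (positive advantage): gains stop growing once ratio exceeds 1 + eps
print(clipped_surrogate(1.5, 1.0) == clipped_surrogate(3.0, 1.0))  # True

# Bad response (negative advantage): raising its probability is still fully penalized
print(clipped_surrogate(3.0, -1.0) < clipped_surrogate(1.5, -1.0))  # True
```

The asymmetry is the point of the design: the objective caps the benefit of moving far from the old policy, but it never hides the cost of a move in the wrong direction.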
In a production setting, you would use libraries like TRL (Transformer Reinforcement Learning). Here is the conceptual flow of the PPO update for a single step:
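```python
import torch


def ppo_step(model, reward_model, batch, old_logprobs, values, eps=0.2):
    # `batch` holds prompts and the responses sampled from the policy
    # before this update; `old_logprobs` are the log-probabilities the
    # policy assigned to those responses at sampling time, and `values`
    # are the critic's reward estimates for them.
    queries, responses = batch

    # 1. Log-probabilities of the responses under the current policy
    #    (get_logprobs is a helper assumed to be defined elsewhere)
    logprobs = get_logprobs(model, queries, responses)

    # 2. Score the responses with the frozen reward model
    with torch.no_grad():
        rewards = reward_model(queries, responses)

    # 3. Advantage: how much better the response did than the critic expected
    #    (a simplified estimate; full PPO typically uses GAE)
    advantages = rewards - values.detach()

    # 4. Probability ratio pi_new / pi_old
    ratio = torch.exp(logprobs - old_logprobs.detach())

    # 5. Clipped objective: prevents the policy from changing too drastically
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    loss = -torch.min(surr1, surr2).mean()
    return loss
```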
When training with RLHF, you'll encounter a phenomenon called "Reward Hacking." If the reward model is imperfect, the policy model will find shortcuts to maximize the score—such as being overly polite or repeating specific "trigger" words that the reward model associates with high scores.
To combat this, we include a KL divergence penalty. We calculate the KL divergence between the output distribution of our current policy and that of the initial SFT model, scale it by a coefficient β, and subtract it from the reward. The further the policy drifts from the original, the larger the penalty:
$$Reward_{final} = Reward_{model} - \beta \cdot KL(Policy_{curr} || Policy_{SFT})$$
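A small numeric sketch shows how the penalty can change which response the optimizer prefers. The function name `penalized_reward`, the scores, the KL values, and the choice of β = 0.1 are all made-up assumptions for illustration.

```python
def penalized_reward(reward_model_score, kl, beta=0.1):
    """Reward_final = Reward_model - beta * KL(Policy_curr || Policy_SFT)."""
    return reward_model_score - beta * kl


# A response close to the SFT model's behavior: modest score, small drift
on_distribution = penalized_reward(1.0, kl=0.5)

# A reward-hacked response: higher raw score, but far from the SFT model
reward_hacked = penalized_reward(1.6, kl=8.0)

print(on_distribution > reward_hacked)  # True
```

Without the penalty the hacked response would win on raw score (1.6 vs. 1.0). With it, the large drift costs more than the extra score is worth. A larger β keeps the policy closer to the SFT model, while a smaller one allows more drift in exchange for reward.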
Your task is to implement the KL penalty calculation. Using the torch.distributions module, calculate the KL divergence between two sets of logits:
Take a log_probs_curr tensor and a log_probs_ref tensor.
Compute kl_div = (log_probs_curr - log_probs_ref).mean().

Alignment via RLHF is the bridge between raw capability and human utility. By training a reward model to interpret preferences and using PPO to constrain policy updates, we can steer LLMs toward safer and more helpful behaviors. Remember that the reward model is the "source of truth": if it's biased, your model's alignment will be biased too.
Up next: We will explore Direct Preference Optimization (DPO), a modern alternative to RLHF that simplifies the alignment process by removing the need for an explicit reward model and PPO.
Alignment with RLHF