How to Shrink Large Language Models Using Knowledge Distillation and Reinforcement Learning
Imagine you’ve built an incredibly smart AI model that can write text so well, it’s almost like it can read minds. But there’s a catch: it’s very expensive to run and too big to use easily in everyday applications, like answering user questions in real-time or working on small devices. This is where a technique called “knowledge distillation” comes in handy. It helps create a smaller, more efficient version of your big AI model without losing much of its smarts.
The DeepSeek-R1 vs. OpenAI o1 Situation
A recent development in the AI landscape has brought knowledge distillation into sharp focus. DeepSeek, a Chinese AI startup, unveiled its R1 reasoning model, and the DeepSeek assistant app built on it quickly climbed to the top of Apple’s App Store and Google’s Play Store, overtaking OpenAI’s ChatGPT. What set DeepSeek-R1 apart was its reportedly fast and low-cost development, which much of the coverage attributed, at least in part, to “distillation”: training a new AI system on an existing system’s responses to a large number of questions, so that a competitive model can be built quickly and far more cheaply.
However, this approach has sparked significant controversy. OpenAI has claimed that DeepSeek used its proprietary technology to create a competing AI model, raising concerns about intellectual property theft. OpenAI suspects that DeepSeek employed distillation techniques to enhance its AI model by learning from larger, more advanced models, possibly breaching OpenAI’s terms of service.
Enhancing Knowledge Distillation with Reinforcement Learning
Incorporating reinforcement learning (RL) into the distillation process can push distilled models further. With RL, the student learns not only to mimic the teacher model’s outputs but also to optimize its own generations against a feedback signal (a reward), which can compensate for some of what pure imitation misses. Combining knowledge distillation with reinforcement learning is an active research direction, and the final part of this article sketches a toy version of it.
In this article, I’ll walk you through knowledge distillation for GPT-2-style language models, so you can compress all that AI goodness into a cheaper, faster, and more deployable package.
NOTE: Because of computational constraints, the example below distills GPT-2 Medium into GPT-2 😞
Why Do We Distill Language Models?
If you’ve ever tried deploying a large language model, you know it’s not exactly a picnic. Here’s why you should care about distillation:
- Costs and Speed: Big models gulp down GPU memory and can rack up huge bills. Distilled models are typically smaller and need fewer FLOPs per generated token.
- Deployment Ease: Not every environment has a beefy GPU cluster on demand. Distilled models can often run on standard CPUs or smaller GPUs, letting you deploy in more places.
- Energy Efficiency: Smaller models mean less power consumption. Over time, that can add up to meaningful sustainability gains (and fewer panicked looks from your CFO).
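To put rough numbers on the cost argument, here is a back-of-the-envelope sketch for the two models used later in this article. It assumes fp32 weights (4 bytes per parameter) and the common rule of thumb that a decoder-only forward pass costs about 2 × (number of parameters) FLOPs per token, ignoring attention and activation memory, so treat the results as ballpark figures only. The helper names are illustrative.

```python
# Back-of-the-envelope cost comparison.
# Assumptions: fp32 weights (4 bytes/param) and ~2 * N FLOPs per token for a forward pass.
BYTES_PER_PARAM_FP32 = 4

models = {
    "gpt2-medium (teacher)": 355e6,  # approximate parameter count
    "gpt2 (student)": 124e6,         # approximate parameter count
}


def weight_memory_gb(n_params, bytes_per_param=BYTES_PER_PARAM_FP32):
    """Memory needed just to hold the weights, in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


def forward_gflops_per_token(n_params):
    """Rule-of-thumb forward-pass cost: about 2 FLOPs per parameter per token."""
    return 2 * n_params / 1e9


for name, n in models.items():
    print(f"{name}: ~{weight_memory_gb(n):.2f} GB of weights, "
          f"~{forward_gflops_per_token(n):.2f} GFLOPs per token")

ratio = models["gpt2-medium (teacher)"] / models["gpt2 (student)"]
print(f"Under these assumptions the student is ~{ratio:.1f}x cheaper per token")
```

Real-world savings also depend on batch size, sequence length, precision, and hardware, so measure on your own setup before budgeting.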
The Essence of Knowledge Distillation
Distillation is the process of transferring the so-called “dark knowledge” of a teacher model (the big one) into a student model (the smaller one): the information hidden in how the teacher spreads probability across the answers it did not pick, not just the one it did. This doesn’t mean we simply train the smaller model on the same labeled dataset. Instead, we feed the student both:
- Ground Truth Labels (the standard cross-entropy loss with your actual data).
- Teacher Predictions (the teacher’s probability distribution over tokens — like how the teacher “thinks” about each token, not just the final answer).
We also introduce a temperature factor τ to “soften” the teacher’s distribution, so the student doesn’t just copy the teacher’s top-1 guess but also learns from the teacher’s nuanced sense of possibility across all tokens.
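To see what the temperature does, here is a small pure-Python sketch with made-up logits for a four-token vocabulary (the numbers are illustrative, not taken from a real model). Dividing the logits by τ before the softmax keeps the ranking of the tokens unchanged but flattens the distribution, so lower-ranked tokens receive visibly more probability mass.

```python
import math


def softmax_with_temperature(logits, tau=1.0):
    """Softmax of logits / tau, stabilized by subtracting the max before exp."""
    scaled = [z / tau for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for four candidate next tokens
teacher_logits = [4.0, 2.0, 1.0, 0.5]

for tau in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(teacher_logits, tau)
    print(f"tau={tau}: " + ", ".join(f"{p:.3f}" for p in probs))
```

As τ grows, the top token’s probability drops and the tail rises; τ = 1 recovers the ordinary softmax. The configuration later in this article uses τ = 2.0.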
Our Setup: GPT-2 Medium → GPT-2
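In a real project, your teacher could be OpenAI’s o1 or any other giant language model, and your student could be something more lightweight (like DistilGPT-2 or GPT-2 small). Keep in mind that the logit-level distillation shown below needs the teacher’s token probabilities; with a closed API model you usually get only its generated text, so you would be limited to training the student on those outputs. For simplicity, I’m demonstrating with: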
- Teacher: GPT-2 Medium (~355M parameters).
- Student: GPT-2 (~124M parameters).
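Note: This is still fairly large, so if your GPU or budget can’t handle it, try a smaller teacher (like GPT-2 or even DistilGPT-2) and an even smaller student. Whatever pair you pick, teacher and student must share a tokenizer, because the KL term compares their probability distributions over the same vocabulary.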
Step 1: Setting Up and Loading Data
Let’s load WikiText-2 (raw) as an example. We’ll drop lines of 100 characters or fewer, because they tend not to provide much useful context, then split the rest into training and validation sets. Finally, we tokenize the data with the teacher’s tokenizer (GPT-2 and GPT-2 Medium share the same one), padding or truncating each sample to a fixed length (MAX_LENGTH = 128).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


class Config:
    TEACHER = "gpt2-medium"
    STUDENT = "gpt2"
    MAX_LENGTH = 128
    ALPHA = 0.7          # weight on the teacher (KL) term; 1 - ALPHA goes to cross-entropy
    TEMP = 2.0           # softmax temperature used for distillation
    EPOCHS = 10
    BATCH_SIZE = 2
    LOGGING_STEPS = 10
    PPO_EPOCHS = 10
    PPO_MAX_NEW_TOKENS = 50


def load_and_prepare_data():
    raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # Keep only lines longer than 100 characters
    raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 100)
    # Split into train/val
    split_ds = raw_dataset.train_test_split(test_size=0.1, seed=42)
    train_raw, val_raw = split_ds["train"], split_ds["test"]
    # Subset for demonstration (avoid big memory usage)
    train_raw = train_raw.select(range(min(len(train_raw), 1000)))
    val_raw = val_raw.select(range(min(len(val_raw), 200)))

    tokenizer = AutoTokenizer.from_pretrained(Config.TEACHER)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 doesn't have a pad token

    def tokenize_fn(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=Config.MAX_LENGTH,
        )

    train_tok = train_raw.map(tokenize_fn, batched=True)
    val_tok = val_raw.map(tokenize_fn, batched=True)

    # For language modeling, labels = input_ids, except that padding positions
    # are set to -100 so the cross-entropy loss ignores them
    def set_labels(example):
        example["labels"] = [
            tok if mask == 1 else -100
            for tok, mask in zip(example["input_ids"], example["attention_mask"])
        ]
        return example

    train_tok = train_tok.map(set_labels)
    val_tok = val_tok.map(set_labels)
    train_tok = train_tok.remove_columns(["text"])
    val_tok = val_tok.remove_columns(["text"])
    train_tok.set_format(type="torch")
    val_tok.set_format(type="torch")
    return train_tok, val_tok, tokenizer


train_data, val_data, tokenizer = load_and_prepare_data()
```
Step 2: Custom Trainer for Distillation
Hugging Face’s Trainer is super convenient. By subclassing it, we can override how the loss is computed so that it combines cross-entropy and KL divergence.
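```python
class DistillationTrainer(Trainer):
    def __init__(self, teacher, alpha, temp, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher
        self.alpha = alpha
        self.temp = temp

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Student forward pass; the model computes token-level cross-entropy from "labels"
        outputs_student = model(**inputs)
        student_loss_ce = outputs_student.loss

        # Teacher forward pass (no gradients, labels not needed)
        with torch.no_grad():
            teacher_logits = self.teacher(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            ).logits

        # Per-token KL(teacher || student) on temperature-softened distributions,
        # summed over the vocabulary -> shape (batch, seq_len)
        kl_per_token = F.kl_div(
            F.log_softmax(outputs_student.logits / self.temp, dim=-1),
            F.log_softmax(teacher_logits / self.temp, dim=-1),
            reduction="none",
            log_target=True,
        ).sum(dim=-1)

        # Average over real (non-padding) tokens so the KL term is on the same
        # per-token scale as the cross-entropy; T^2 keeps gradient sizes comparable
        mask = inputs["attention_mask"].to(kl_per_token.dtype)
        loss_kl = (kl_per_token * mask).sum() / mask.sum() * (self.temp ** 2)

        # Combined loss
        loss = (1 - self.alpha) * student_loss_ce + self.alpha * loss_kl
        return (loss, outputs_student) if return_outputs else loss
```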
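For reference, here is the per-token objective that this loss combines: the classic soft-target distillation term at temperature τ (scaled by τ²), mixed with ordinary cross-entropy using the α weighting from Config. Here z^S and z^T are the student and teacher logits at one position, y is the true next token, α = Config.ALPHA, and τ = Config.TEMP. Note that F.kl_div(log q, p) computes KL(p ‖ q), which is why the teacher distribution is passed as the second argument.

```latex
\begin{aligned}
p_i^{S}(\tau) &= \frac{\exp(z_i^{S}/\tau)}{\sum_j \exp(z_j^{S}/\tau)},
\qquad
p_i^{T}(\tau) = \frac{\exp(z_i^{T}/\tau)}{\sum_j \exp(z_j^{T}/\tau)} \\
\mathcal{L} &= (1-\alpha)\,\bigl(-\log p_{y}^{S}(1)\bigr)
  + \alpha\,\tau^{2}\,\mathrm{KL}\bigl(p^{T}(\tau)\,\big\|\,p^{S}(\tau)\bigr),
\qquad
\mathrm{KL}(p\,\|\,q) = \sum_i p_i \log\frac{p_i}{q_i}
\end{aligned}
```

In training, these per-token terms are aggregated over all positions in the batch.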
We then define a function to actually run the distillation:
```python
def run_distillation(teacher, student, train_data, val_data, tokenizer):
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        num_train_epochs=Config.EPOCHS,
        logging_steps=Config.LOGGING_STEPS,
        eval_strategy="epoch",  # older transformers releases call this `evaluation_strategy`
        save_strategy="no",
        fp16=False,
        report_to="none",
    )

    trainer = DistillationTrainer(
        teacher=teacher,
        alpha=Config.ALPHA,
        temp=Config.TEMP,
        model=student,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        tokenizer=tokenizer,
    )

    trainer.train()
    eval_metrics = trainer.evaluate()
    print("[Distillation] Validation metrics:", eval_metrics)
    return trainer
```
Step 3: Distill the Knowledge
Now we load the teacher and student models, move them to the available device (the GPU, if there is one), and let the distillation process unfold.
```python
teacher = AutoModelForCausalLM.from_pretrained(Config.TEACHER).to(device)
student = AutoModelForCausalLM.from_pretrained(Config.STUDENT).to(device)

# The teacher is frozen: no dropout and no gradients
teacher.eval()
teacher.requires_grad_(False)

trainer = run_distillation(teacher, student, train_data, val_data, tokenizer)
print("Distillation completed.")
```
What’s happening behind the scenes? On every training batch, the student sees both:
- Its own cross-entropy loss against the ground truth.
- KL-divergence showing how far it is from the teacher’s distribution.
Over time, the student starts to “think” more like the teacher. The goal is for it to outperform the same small model fine-tuned on the ground-truth labels alone, which makes that model a useful baseline to train and compare against.
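To make the two signals concrete, here is a pure-Python sketch that computes the combined loss for a single token position, using made-up logits over a four-token vocabulary and the article’s α = 0.7 and τ = 2.0. It mirrors the structure of the distillation loss (cross-entropy on the true token at τ = 1, plus τ² times the KL divergence from the softened teacher to the softened student); the logits, the true-token index, and the helper names are purely illustrative.

```python
import math


def softmax(logits, tau=1.0):
    """Softmax of logits / tau, stabilized by subtracting the max."""
    scaled = [z / tau for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    s = sum(exps)
    return [e / s for e in exps]


def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, true_idx, alpha=0.7, tau=2.0):
    # Hard-label term: cross-entropy of the student at temperature 1
    ce = -math.log(softmax(student_logits)[true_idx])
    # Soft-label term: KL(teacher || student) at temperature tau
    kl = kl_divergence(softmax(teacher_logits, tau), softmax(student_logits, tau))
    # Weighted mix; tau^2 rescales the soft term
    return (1 - alpha) * ce + alpha * (tau ** 2) * kl, ce, kl


# Hypothetical logits for one position; the true next token is index 0
student = [2.0, 1.5, 0.2, -1.0]
teacher = [4.0, 2.0, 1.0, 0.5]
total, ce, kl = distillation_loss(student, teacher, true_idx=0)
print(f"CE={ce:.4f}  KL={kl:.4f}  combined={total:.4f}")
```

If the student’s logits matched the teacher’s exactly, the KL term would vanish and only the (weighted) cross-entropy would remain.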
Step 4: Checking Perplexity
Perplexity is a common metric in language modeling. It effectively tells you how well the model predicts text:
PPL = exp(average per-token cross-entropy loss).
Lower perplexity = better model.
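A quick sanity check on the formula: if a model spreads its probability uniformly over k candidate tokens at every step, its per-token cross-entropy is log k, so its perplexity is exactly k. That is why perplexity is often read as the model’s “effective number of choices” per token. The sketch below (pure Python, with hypothetical per-token probabilities and an illustrative helper name) computes perplexity from the probabilities a model assigned to the true tokens.

```python
import math


def perplexity(true_token_probs):
    """exp of the mean negative log-likelihood of the true tokens."""
    nll = [-math.log(p) for p in true_token_probs]
    return math.exp(sum(nll) / len(nll))


# A model that is uniformly unsure among 4 tokens at every step
print(perplexity([0.25] * 10))
# Hypothetical probabilities from a sharper model
print(perplexity([0.9, 0.6, 0.8, 0.5]))
```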
```python
def evaluate_perplexity(model, dataset, tokenizer):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    batch_size = Config.BATCH_SIZE
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i : i + batch_size]
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # outputs.loss is a mean over predicted tokens; weight it by that count
        # (labels are shifted by one inside the model, and -100 positions are ignored)
        n_tokens = (labels[:, 1:] != -100).sum().item()
        if n_tokens == 0:
            continue
        total_loss += outputs.loss.item() * n_tokens
        total_tokens += n_tokens
    avg_loss = total_loss / total_tokens
    return float(np.exp(avg_loss))


teacher_ppl = evaluate_perplexity(teacher, val_data, tokenizer)
student_ppl = evaluate_perplexity(student, val_data, tokenizer)
print(f"Teacher perplexity: {teacher_ppl:.2f}")
print(f"Student perplexity: {student_ppl:.2f}")
```
Quick PPO Fine-Tuning
We can add a simple reinforcement learning (RL) loop to fine-tune the student further. This example compares the student’s generated text with the teacher’s, using the cosine similarity of their token embeddings as the reward signal. Below is a simplified (and I do mean really simplified) PPO-style loop that nudges the student toward the teacher’s outputs:
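```python
class PPOTrainerCustom:
    """Toy RL fine-tuning loop: a REINFORCE-style policy gradient with a
    running-mean baseline, not full PPO (no clipping, no value function)."""

    def __init__(self, student, teacher, tokenizer, device):
        self.student = student
        self.teacher = teacher
        self.tokenizer = tokenizer
        self.device = device
        self.baseline = None  # running mean of rewards, reduces gradient variance

    @torch.no_grad()
    def compute_reward(self, student_gen, teacher_gen, prompt_len):
        # Embed only the generated continuations (prompt excluded) with the
        # teacher's frozen token-embedding table. Both GPT-2 models share a
        # vocabulary, so both texts land in the same 1024-dim space and no
        # untrained projection layer is needed.
        wte = self.teacher.transformer.wte
        student_mean = wte(student_gen[:, prompt_len:]).mean(dim=1)
        teacher_mean = wte(teacher_gen[:, prompt_len:]).mean(dim=1)
        return F.cosine_similarity(student_mean, teacher_mean, dim=-1).mean().item()

    def train(self, prompts, epochs=1):
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=5e-5)
        all_losses = []
        self.student.eval()  # disables dropout; gradients still flow in eval mode
        self.teacher.eval()
        gen_kwargs = dict(
            max_new_tokens=Config.PPO_MAX_NEW_TOKENS,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        for _ in range(epochs):
            for prompt in prompts:
                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
                prompt_len = inputs["input_ids"].shape[1]
                with torch.no_grad():
                    # Student samples (exploration); teacher decodes greedily (target)
                    student_gen = self.student.generate(**inputs, do_sample=True, **gen_kwargs)
                    teacher_gen = self.teacher.generate(**inputs, do_sample=False, **gen_kwargs)

                # Cosine similarity-based reward, centered by a running baseline
                reward = self.compute_reward(student_gen, teacher_gen, prompt_len)
                if self.baseline is None:
                    self.baseline = reward
                advantage = reward - self.baseline
                self.baseline = 0.9 * self.baseline + 0.1 * reward

                # Log-probability of the sampled continuation under the student.
                # Logits at position t predict token t + 1, hence the shift.
                logits = self.student(student_gen).logits[:, :-1]
                log_probs = F.log_softmax(logits, dim=-1)
                targets = student_gen[:, 1:].unsqueeze(-1)
                token_logprobs = log_probs.gather(-1, targets).squeeze(-1)
                gen_logprob = token_logprobs[:, prompt_len - 1:].sum()

                # Policy-gradient loss: raise the likelihood of above-baseline samples
                loss = -(advantage * gen_logprob)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                all_losses.append(loss.item())
        return all_losses
```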
A tiny example of usage:
```python
ppo_trainer = PPOTrainerCustom(student, teacher, tokenizer, device=device)
test_prompts = ["The future of AI", "The capital of India is", "Machine learning is"]
ppo_losses = ppo_trainer.train(test_prompts, epochs=Config.PPO_EPOCHS)
print("Done with PPO Fine-Tuning!")
```
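Important caveat: Real PPO is more involved (advantage estimation, a clipped objective, a value function, several optimization passes over each batch of samples, and so on). This toy snippet is closer to a plain policy-gradient (REINFORCE) update; it just shows how you could push your student’s behavior closer to your teacher’s in an RL manner.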
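For a sense of what the “clipping” in real PPO refers to, here is a minimal pure-Python sketch of the clipped surrogate objective from the PPO paper (Schulman et al., 2017) for a single sampled token. The function name is illustrative, and ε = 0.2 is simply a commonly used value; a full implementation would average this over many tokens, estimate advantages (for example with a learned value function), and maximize the result by gradient ascent.

```python
import math


def ppo_clipped_objective(new_logprob, old_logprob, advantage, eps=0.2):
    """Clipped surrogate objective for one action (to be maximized).

    ratio = pi_new(a|s) / pi_old(a|s), computed from log-probabilities.
    Taking the min of the clipped and unclipped terms gives a pessimistic
    bound, so the update gains nothing from pushing the ratio outside
    [1 - eps, 1 + eps].
    """
    ratio = math.exp(new_logprob - old_logprob)
    unclipped = ratio * advantage
    clipped = min(max(ratio, 1 - eps), 1 + eps) * advantage
    return min(unclipped, clipped)


# A good token became 1.5x more likely: the credited gain is capped at ratio 1.2
print(ppo_clipped_objective(math.log(1.5), 0.0, advantage=1.0))
# A bad token became half as likely: the objective is computed as if the ratio were 0.8
print(ppo_clipped_objective(math.log(0.5), 0.0, advantage=-1.0))
```

The cap is what keeps each PPO update close to the policy that generated the samples, which the toy loop above does not enforce.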
Wrapping Up
Knowledge distillation is a game-changer when it comes to making large language models more practical. It allows you to keep a surprising amount of the teacher model’s wisdom, all while cutting down on hardware demands and response times.
- Why Distill? Because you get near-teacher performance, in a smaller, cheaper-to-run student model.
- How? By using a combination of cross-entropy (with real labels) and a KL-divergence that compares the student’s predictions to the teacher’s.
- Bonus Moves: Once your student is trained, you can tweak it further with RL or other domain-specific fine-tuning.
The code snippets above illustrate the core mechanics. In a production setting, you might scale up the dataset, train for longer, or tune hyperparameters like α and τ. But if you understand the essentials — distill, compare distributions, lighten your model — you’re well on your way to shipping a robust, efficient language model that fits within your budget and infrastructure constraints.
Now go forth and compress! Your users, your GPU bill, and the planet will thank you.
You can find the entire code here: https://github.com/neuralsorcerer/LLM-Distillation
Check out the original knowledge distillation paper, “Distilling the Knowledge in a Neural Network” by Hinton, Vinyals, and Dean (2015): https://arxiv.org/abs/1503.02531
Thank you for reading! If you found this article helpful, feel free to share it with others.