名称: 知识蒸馏 描述: 使用从教师模型到学生模型的知识蒸馏来压缩大型语言模型。当需要部署较小模型以保持性能、将GPT-4能力迁移到开源模型或降低推理成本时使用。涵盖温度缩放、软目标、反向KLD、logit蒸馏和MiniLLM训练策略。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [新兴技术, 知识蒸馏, 模型压缩, 教师-学生模型, MiniLLM, 反向KLD, 软目标, 温度缩放, Logit蒸馏, 模型迁移] 依赖项: [transformers, torch, datasets]
知识蒸馏:压缩LLMs
何时使用此技能
使用知识蒸馏当您需要:
- 压缩模型 从70B → 7B 同时保持90%+性能
- 迁移能力 从专有模型(GPT-4)到开源模型(LLaMA, Mistral)
- 降低推理成本 通过部署较小的学生模型
- 创建专业模型 通过蒸馏领域特定知识
- 改进小模型 使用来自大型教师的合成数据
关键技术: 温度缩放,软目标,反向KLD(MiniLLM),logit蒸馏,响应蒸馏
论文: Hinton et al. 2015 (arXiv 1503.02531), MiniLLM (arXiv 2306.08543), KD Survey (arXiv 2402.13116)
安装
# 标准transformers
pip install transformers datasets accelerate
# 用于训练
pip install torch deepspeed wandb
# 可选: MiniLLM实现
git clone https://github.com/microsoft/LMOps
cd LMOps/minillm
pip install -e .
快速开始
基本知识蒸馏
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
# 1. 加载教师(大)和学生(小)模型
教师 = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf", # 大型教师
torch_dtype=torch.float16,
device_map="auto"
)
学生 = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", # 小型学生
torch_dtype=torch.float16,
device_map="cuda:0"
)
分词器 = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
# 2. 定义蒸馏损失
def distillation_loss(学生_logits, 教师_logits, 标签, temperature=2.0, alpha=0.5):
"""
结合硬损失(交叉熵)与软损失(KL散度)。
参数:
temperature: 软化概率分布(更高 = 更软)
alpha: 蒸馏损失的权重(1-alpha用于硬损失)
"""
# 硬损失: 与真实标签的标准交叉熵
hard_loss = F.cross_entropy(学生_logits.view(-1, 学生_logits.size(-1)), 标签.view(-1))
# 软损失: 学生与教师之间的KL散度
soft_targets = F.softmax(教师_logits / temperature, dim=-1)
soft_student = F.log_softmax(学生_logits / temperature, dim=-1)
soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (temperature ** 2)
# 组合损失
return alpha * soft_loss + (1 - alpha) * hard_loss
# 3. 训练循环
for batch in dataloader:
# 教师前向(无梯度)
with torch.no_grad():
teacher_outputs = 教师(**batch)
teacher_logits = teacher_outputs.logits
# 学生前向
student_outputs = 学生(**batch)
student_logits = student_outputs.logits
# 计算蒸馏损失
loss = distillation_loss(
student_logits,
teacher_logits,
batch['labels'],
temperature=2.0,
alpha=0.7 # 70% 软, 30% 硬
)
# 反向传播和优化
loss.backward()
optimizer.step()
optimizer.zero_grad()
MiniLLM(反向KLD)
来源: arXiv 2306.08543 (2024)
创新: 使用反向KLD代替前向KLD以获得更好的生成模型蒸馏。
def reverse_kl_loss(学生_logits, 教师_logits, temperature=1.0):
"""
反向KL散度: KL(教师 || 学生)
对于生成模型比前向KL更好。
"""
# 教师分布(目标)
p_teacher = F.softmax(教师_logits / temperature, dim=-1)
# 学生分布(模型)
log_p_student = F.log_softmax(学生_logits / temperature, dim=-1)
# 反向KL: 在教师上求和,学生学习覆盖教师的所有模式
reverse_kl = -(p_teacher * log_p_student).sum(dim=-1).mean()
return reverse_kl * (temperature ** 2)
# 使用MiniLLM训练
for batch in dataloader:
with torch.no_grad():
teacher_logits = 教师(**batch).logits
student_logits = 学生(**batch).logits
# 反向KLD(对生成更好)
loss = reverse_kl_loss(student_logits, teacher_logits, temperature=1.0)
loss.backward()
optimizer.step()
为什么用反向KL?
- 前向KL(标准): 学生学习匹配教师的均值
- 反向KL(MiniLLM): 学生学习覆盖教师的所有模式
- 对多样化文本生成更好
响应蒸馏
# 从教师生成合成数据,训练学生模仿
# 1. 从教师生成合成响应
提示 = ["解释AI:", "什么是ML?", "定义NLP:"]
教师_响应 = []
for 提示 in 提示:
inputs = 分词器(提示, return_tensors='pt').to(教师.device)
outputs = 教师.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
response = 分词器.decode(outputs[0], skip_special_tokens=True)
教师_响应.append(response)
# 2. 在教师的响应上训练学生(标准微调)
train_dataset = [
{"text": f"{提示}
{响应}"}
for 提示, 响应 in zip(提示, 教师_响应)
]
# 3. 微调学生
trainer = Trainer(
model=学生,
args=TrainingArguments(output_dir="./学生", num_train_epochs=3, learning_rate=2e-5),
train_dataset=train_dataset,
)
trainer.train()
核心概念
1. 温度缩放
目的: 软化概率分布以暴露教师的不确定性。
# 低温度 (T=1): 尖锐分布
logits = [3.0, 2.0, 1.0]
probs_T1 = softmax(logits / 1.0) # [0.67, 0.24, 0.09]
# 高温度 (T=4): 软分布
probs_T4 = softmax(logits / 4.0) # [0.42, 0.34, 0.24]
# 更高T揭示更多关于相对排名的信息
规则: 用于蒸馏时使用T=2-5(2是常见默认值)。
2. 损失函数组件
# 总损失 = alpha * soft_loss + (1 - alpha) * hard_loss
# 软损失: 从教师知识学习
soft_loss = KL(学生 || 教师)
# 硬损失: 从真实标签学习
hard_loss = CrossEntropy(学生_output, 真实_标签)
# 典型值:
alpha = 0.5 # 平衡
alpha = 0.7 # 更强调教师
alpha = 0.3 # 更强调标签
3. 前向 vs 反向KLD
# 前向KL: KL(学生 || 教师)
# - 学生匹配教师的平均行为
# - 模式寻求: 学生聚焦于教师的最高概率模式
# - 对分类好
# 反向KL: KL(教师 || 学生)
# - 学生覆盖教师的所有行为
# - 模式覆盖: 学生学习多样行为
# - 对生成好(MiniLLM)
训练策略
策略1: Logit蒸馏
# 训练学生直接匹配教师的logits
def logit_distillation_trainer(学生, 教师, dataloader, temperature=2.0):
optimizer = torch.optim.AdamW(学生.parameters(), lr=2e-5)
for epoch in range(3):
for batch in dataloader:
# 获取logits
with torch.no_grad():
teacher_logits = 教师(**batch).logits
student_logits = 学生(**batch).logits
# logits上的MSE(KLD的替代)
loss = F.mse_loss(student_logits, teacher_logits)
# 或使用KLD
# loss = F.kl_div(
# F.log_softmax(student_logits/temperature, dim=-1),
# F.softmax(teacher_logits/temperature, dim=-1),
# reduction='batchmean'
# ) * (temperature ** 2)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return 学生
策略2: 两阶段蒸馏
# 阶段1: 从教师蒸馏
学生 = distill(教师, 学生, epochs=5)
# 阶段2: 在任务特定数据上微调
学生 = fine_tune(学生, task_data, epochs=3)
# 结果比单阶段有更好的任务性能
策略3: 多教师蒸馏
# 从多个专家教师学习
def multi_teacher_distillation(学生, 教师s, batch):
"""从教师集合蒸馏。"""
teacher_logits_list = []
# 从所有教师获取logits
with torch.no_grad():
for 教师 in 教师s:
logits = 教师(**batch).logits
teacher_logits_list.append(logits)
# 平均教师预测
avg_teacher_logits = torch.stack(teacher_logits_list).mean(dim=0)
# 学生从集合学习
student_logits = 学生(**batch).logits
loss = F.kl_div(
F.log_softmax(student_logits, dim=-1),
F.softmax(avg_teacher_logits, dim=-1),
reduction='batchmean'
)
return loss
生产部署
完整训练脚本
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
def train_distilled_model(
teacher_name="meta-llama/Llama-2-70b-hf",
student_name="meta-llama/Llama-2-7b-hf",
output_dir="./distilled-llama-7b",
temperature=2.0,
alpha=0.7,
):
# 加载模型
teacher = AutoModelForCausalLM.from_pretrained(teacher_name, torch_dtype=torch.float16, device_map="auto")
student = AutoModelForCausalLM.from_pretrained(student_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
# 带有蒸馏的自定义训练器
class DistillationTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
# 学生前向
outputs_student = model(**inputs)
student_logits = outputs_student.logits
# 教师前向(无梯度)
with torch.no_grad():
outputs_teacher = teacher(**inputs)
teacher_logits = outputs_teacher.logits
# 蒸馏损失
soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (temperature ** 2)
# 硬损失
hard_loss = outputs_student.loss
# 组合
loss = alpha * soft_loss + (1 - alpha) * hard_loss
return (loss, outputs_student) if return_outputs else loss
# 训练参数
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
learning_rate=2e-5,
warmup_steps=500,
logging_steps=100,
save_steps=1000,
bf16=True,
gradient_checkpointing=True,
)
# 训练
trainer = DistillationTrainer(
model=student,
args=training_args,
train_dataset=train_dataset,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
student.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# 使用
train_distilled_model(
teacher_name="meta-llama/Llama-2-70b-hf",
student_name="meta-llama/Llama-2-7b-hf",
temperature=2.0,
alpha=0.7
)
最佳实践
1. 超参数选择
# 温度
T = 1.0 # 尖锐(更少知识转移)
T = 2.0 # 标准(良好平衡)
T = 5.0 # 软(更多知识转移)
# Alpha(权重)
alpha = 0.5 # 平衡
alpha = 0.7 # 强调教师知识
alpha = 0.9 # 强蒸馏
# 规则: 更高T + 更高alpha = 更强蒸馏
2. 模型大小比率
# 好比率(教师/学生)
70B / 7B = 10× # 优秀
13B / 1B = 13× # 好
7B / 1B = 7× # 可接受
# 避免太大差距
70B / 1B = 70× # 太大,无效
3. 数据质量
# 最佳: 使用教师生成数据 + 真实数据
train_data = {
"教师_生成": 70%, # 多样化,高质量
"真实_数据": 30% # 真实情况
}
# 避免: 仅真实数据(未充分利用教师)
评估
from transformers import pipeline
# 比较学生vs教师
教师_pipe = pipeline("text-generation", model=教师)
学生_pipe = pipeline("text-generation", model=学生)
提示s = ["解释量子计算:", "什么是AI?"]
for 提示 in 提示s:
teacher_out = 教师_pipe(提示, max_new_tokens=100)
student_out = 学生_pipe(提示, max_new_tokens=100)
print(f"提示: {提示}")
print(f"教师: {teacher_out[0]['generated_text']}")
print(f"学生: {student_out[0]['generated_text']}")
print(f"匹配质量: {calculate_similarity(teacher_out, student_out):.2f}")
资源
- Hinton et al. 2015 (基础): https://arxiv.org/abs/1503.02531
- MiniLLM (反向KLD): https://arxiv.org/abs/2306.08543
- LLMs的KD调查 (2024): https://arxiv.org/abs/2402.13116
- MiniLLM GitHub: https://github.com/microsoft/LMOps/tree/main/minillm