知识蒸馏Skill knowledge-distillation

知识蒸馏是一种压缩大型语言模型的技术,通过从大型教师模型向小型学生模型传递知识,以在部署时保持高性能并降低推理成本。适用于模型压缩、能力迁移、成本优化和专业模型创建等场景。关键词:知识蒸馏,模型压缩,LLM,教师-学生模型,温度缩放,软目标,反向KLD,大语言模型,人工智能,深度学习。

大模型微调 0 次安装 0 次浏览 更新于 3/21/2026

名称: 知识蒸馏 描述: 使用从教师模型到学生模型的知识蒸馏来压缩大型语言模型。当需要部署较小模型以保持性能、将GPT-4能力迁移到开源模型或降低推理成本时使用。涵盖温度缩放、软目标、反向KLD、logit蒸馏和MiniLLM训练策略。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [新兴技术, 知识蒸馏, 模型压缩, 教师-学生模型, MiniLLM, 反向KLD, 软目标, 温度缩放, Logit蒸馏, 模型迁移] 依赖项: [transformers, torch, datasets]

知识蒸馏:压缩LLMs

何时使用此技能

使用知识蒸馏当您需要:

  • 压缩模型 从70B → 7B 同时保持90%+性能
  • 迁移能力 从专有模型(GPT-4)到开源模型(LLaMA, Mistral)
  • 降低推理成本 通过部署较小的学生模型
  • 创建专业模型 通过蒸馏领域特定知识
  • 改进小模型 使用来自大型教师的合成数据

关键技术: 温度缩放,软目标,反向KLD(MiniLLM),logit蒸馏,响应蒸馏

论文: Hinton et al. 2015 (arXiv 1503.02531), MiniLLM (arXiv 2306.08543), KD Survey (arXiv 2402.13116)

安装

# 标准transformers
pip install transformers datasets accelerate

# 用于训练
pip install torch deepspeed wandb

# 可选: MiniLLM实现
git clone https://github.com/microsoft/LMOps
cd LMOps/minillm
pip install -e .

快速开始

基本知识蒸馏

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# 1. 加载教师(大)和学生(小)模型
教师 = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",  # 大型教师
    torch_dtype=torch.float16,
    device_map="auto"
)

学生 = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # 小型学生
    torch_dtype=torch.float16,
    device_map="cuda:0"
)

分词器 = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")

# 2. 定义蒸馏损失
def distillation_loss(学生_logits, 教师_logits, 标签, temperature=2.0, alpha=0.5):
    """
    结合硬损失(交叉熵)与软损失(KL散度)。

    参数:
        temperature: 软化概率分布(更高 = 更软)
        alpha: 蒸馏损失的权重(1-alpha用于硬损失)
    """
    # 硬损失: 与真实标签的标准交叉熵
    hard_loss = F.cross_entropy(学生_logits.view(-1, 学生_logits.size(-1)), 标签.view(-1))

    # 软损失: 学生与教师之间的KL散度
    soft_targets = F.softmax(教师_logits / temperature, dim=-1)
    soft_student = F.log_softmax(学生_logits / temperature, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (temperature ** 2)

    # 组合损失
    return alpha * soft_loss + (1 - alpha) * hard_loss

# 3. 训练循环
for batch in dataloader:
    # 教师前向(无梯度)
    with torch.no_grad():
        teacher_outputs = 教师(**batch)
        teacher_logits = teacher_outputs.logits

    # 学生前向
    student_outputs = 学生(**batch)
    student_logits = student_outputs.logits

    # 计算蒸馏损失
    loss = distillation_loss(
        student_logits,
        teacher_logits,
        batch['labels'],
        temperature=2.0,
        alpha=0.7  # 70% 软, 30% 硬
    )

    # 反向传播和优化
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

MiniLLM(反向KLD)

来源: arXiv 2306.08543 (2024)

创新: 使用反向KLD代替前向KLD以获得更好的生成模型蒸馏。

def reverse_kl_loss(学生_logits, 教师_logits, temperature=1.0):
    """
    反向KL散度: KL(教师 || 学生)
    对于生成模型比前向KL更好。
    """
    # 教师分布(目标)
    p_teacher = F.softmax(教师_logits / temperature, dim=-1)

    # 学生分布(模型)
    log_p_student = F.log_softmax(学生_logits / temperature, dim=-1)

    # 反向KL: 在教师上求和,学生学习覆盖教师的所有模式
    reverse_kl = -(p_teacher * log_p_student).sum(dim=-1).mean()

    return reverse_kl * (temperature ** 2)

# 使用MiniLLM训练
for batch in dataloader:
    with torch.no_grad():
        teacher_logits = 教师(**batch).logits

    student_logits = 学生(**batch).logits

    # 反向KLD(对生成更好)
    loss = reverse_kl_loss(student_logits, teacher_logits, temperature=1.0)

    loss.backward()
    optimizer.step()

为什么用反向KL?

  • 前向KL(标准): 学生学习匹配教师的均值
  • 反向KL(MiniLLM): 学生学习覆盖教师的所有模式
  • 对多样化文本生成更好

响应蒸馏

# 从教师生成合成数据,训练学生模仿

# 1. 从教师生成合成响应
提示 = ["解释AI:", "什么是ML?", "定义NLP:"]

教师_响应 = []
for 提示 in 提示:
    inputs = 分词器(提示, return_tensors='pt').to(教师.device)
    outputs = 教师.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
    response = 分词器.decode(outputs[0], skip_special_tokens=True)
    教师_响应.append(response)

# 2. 在教师的响应上训练学生(标准微调)
train_dataset = [
    {"text": f"{提示}
{响应}"}
    for 提示, 响应 in zip(提示, 教师_响应)
]

# 3. 微调学生
trainer = Trainer(
    model=学生,
    args=TrainingArguments(output_dir="./学生", num_train_epochs=3, learning_rate=2e-5),
    train_dataset=train_dataset,
)
trainer.train()

核心概念

1. 温度缩放

目的: 软化概率分布以暴露教师的不确定性。

# 低温度 (T=1): 尖锐分布
logits = [3.0, 2.0, 1.0]
probs_T1 = softmax(logits / 1.0)  # [0.67, 0.24, 0.09]

# 高温度 (T=4): 软分布
probs_T4 = softmax(logits / 4.0)  # [0.42, 0.34, 0.24]

# 更高T揭示更多关于相对排名的信息

规则: 用于蒸馏时使用T=2-5(2是常见默认值)。

2. 损失函数组件

# 总损失 = alpha * soft_loss + (1 - alpha) * hard_loss

# 软损失: 从教师知识学习
soft_loss = KL(学生 || 教师)

# 硬损失: 从真实标签学习
hard_loss = CrossEntropy(学生_output, 真实_标签)

# 典型值:
alpha = 0.5  # 平衡
alpha = 0.7  # 更强调教师
alpha = 0.3  # 更强调标签

3. 前向 vs 反向KLD

# 前向KL: KL(学生 || 教师)
# - 学生匹配教师的平均行为
# - 模式寻求: 学生聚焦于教师的最高概率模式
# - 对分类好

# 反向KL: KL(教师 || 学生)
# - 学生覆盖教师的所有行为
# - 模式覆盖: 学生学习多样行为
# - 对生成好(MiniLLM)

训练策略

策略1: Logit蒸馏

# 训练学生直接匹配教师的logits

def logit_distillation_trainer(学生, 教师, dataloader, temperature=2.0):
    optimizer = torch.optim.AdamW(学生.parameters(), lr=2e-5)

    for epoch in range(3):
        for batch in dataloader:
            # 获取logits
            with torch.no_grad():
                teacher_logits = 教师(**batch).logits

            student_logits = 学生(**batch).logits

            # logits上的MSE(KLD的替代)
            loss = F.mse_loss(student_logits, teacher_logits)

            # 或使用KLD
            # loss = F.kl_div(
            #     F.log_softmax(student_logits/temperature, dim=-1),
            #     F.softmax(teacher_logits/temperature, dim=-1),
            #     reduction='batchmean'
            # ) * (temperature ** 2)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    return 学生

策略2: 两阶段蒸馏

# 阶段1: 从教师蒸馏
学生 = distill(教师, 学生, epochs=5)

# 阶段2: 在任务特定数据上微调
学生 = fine_tune(学生, task_data, epochs=3)

# 结果比单阶段有更好的任务性能

策略3: 多教师蒸馏

# 从多个专家教师学习

def multi_teacher_distillation(学生, 教师s, batch):
    """从教师集合蒸馏。"""
    teacher_logits_list = []

    # 从所有教师获取logits
    with torch.no_grad():
        for 教师 in 教师s:
            logits = 教师(**batch).logits
            teacher_logits_list.append(logits)

    # 平均教师预测
    avg_teacher_logits = torch.stack(teacher_logits_list).mean(dim=0)

    # 学生从集合学习
    student_logits = 学生(**batch).logits
    loss = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(avg_teacher_logits, dim=-1),
        reduction='batchmean'
    )

    return loss

生产部署

完整训练脚本

from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

def train_distilled_model(
    teacher_name="meta-llama/Llama-2-70b-hf",
    student_name="meta-llama/Llama-2-7b-hf",
    output_dir="./distilled-llama-7b",
    temperature=2.0,
    alpha=0.7,
):
    # 加载模型
    teacher = AutoModelForCausalLM.from_pretrained(teacher_name, torch_dtype=torch.float16, device_map="auto")
    student = AutoModelForCausalLM.from_pretrained(student_name, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(teacher_name)

    # 带有蒸馏的自定义训练器
    class DistillationTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False):
            # 学生前向
            outputs_student = model(**inputs)
            student_logits = outputs_student.logits

            # 教师前向(无梯度)
            with torch.no_grad():
                outputs_teacher = teacher(**inputs)
                teacher_logits = outputs_teacher.logits

            # 蒸馏损失
            soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
            soft_student = F.log_softmax(student_logits / temperature, dim=-1)
            soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (temperature ** 2)

            # 硬损失
            hard_loss = outputs_student.loss

            # 组合
            loss = alpha * soft_loss + (1 - alpha) * hard_loss

            return (loss, outputs_student) if return_outputs else loss

    # 训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        warmup_steps=500,
        logging_steps=100,
        save_steps=1000,
        bf16=True,
        gradient_checkpointing=True,
    )

    # 训练
    trainer = DistillationTrainer(
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    trainer.train()
    student.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

# 使用
train_distilled_model(
    teacher_name="meta-llama/Llama-2-70b-hf",
    student_name="meta-llama/Llama-2-7b-hf",
    temperature=2.0,
    alpha=0.7
)

最佳实践

1. 超参数选择

# 温度
T = 1.0  # 尖锐(更少知识转移)
T = 2.0  # 标准(良好平衡)
T = 5.0  # 软(更多知识转移)

# Alpha(权重)
alpha = 0.5  # 平衡
alpha = 0.7  # 强调教师知识
alpha = 0.9  # 强蒸馏

# 规则: 更高T + 更高alpha = 更强蒸馏

2. 模型大小比率

# 好比率(教师/学生)
70B / 7B = 10×    # 优秀
13B / 1B = 13×    # 好
7B / 1B = 7×      # 可接受

# 避免太大差距
70B / 1B = 70×    # 太大,无效

3. 数据质量

# 最佳: 使用教师生成数据 + 真实数据
train_data = {
    "教师_生成": 70%,  # 多样化,高质量
    "真实_数据": 30%            # 真实情况
}

# 避免: 仅真实数据(未充分利用教师)

评估

from transformers import pipeline

# 比较学生vs教师
教师_pipe = pipeline("text-generation", model=教师)
学生_pipe = pipeline("text-generation", model=学生)

提示s = ["解释量子计算:", "什么是AI?"]

for 提示 in 提示s:
    teacher_out = 教师_pipe(提示, max_new_tokens=100)
    student_out = 学生_pipe(提示, max_new_tokens=100)

    print(f"提示: {提示}")
    print(f"教师: {teacher_out[0]['generated_text']}")
    print(f"学生: {student_out[0]['generated_text']}")
    print(f"匹配质量: {calculate_similarity(teacher_out, student_out):.2f}")

资源