名称: 模型剪枝 描述: 使用如Wanda和SparseGPT等剪枝技术减小大型语言模型(LLM)尺寸并加速推理。适用于无需重新训练的模型压缩,实现50%稀疏度且精度损失最小,或在硬件加速器上实现更快的推理。涵盖非结构化剪枝、结构化剪枝、N:M稀疏度、幅度剪枝和一次性方法。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [新兴技术, 模型剪枝, Wanda, SparseGPT, 稀疏度, 模型压缩, N:M稀疏度, 一次性剪枝, 结构化剪枝, 非结构化剪枝, 快速推理] 依赖项: [transformers, torch]
模型剪枝:压缩LLMs
何时使用此技能
使用模型剪枝当您需要:
- 减小模型尺寸 40-60%,精度损失小于1%
- 加速推理 使用硬件友好的稀疏度(2-4倍加速)
- 部署在受限硬件上(移动设备、边缘设备)
- 无需重新训练压缩 使用一次性方法
- 实现高效服务 减少内存占用
关键技术: Wanda(权重 × 激活)、SparseGPT(二阶)、结构化剪枝、N:M稀疏度
论文: Wanda ICLR 2024 (arXiv 2306.11695)、SparseGPT (arXiv 2301.00774)
安装
# Wanda实现
git clone https://github.com/locuslab/wanda
cd wanda
pip install -r requirements.txt
# 可选: SparseGPT
git clone https://github.com/IST-DASLab/sparsegpt
cd sparsegpt
pip install -e .
# 依赖项
pip install torch transformers accelerate
快速开始
Wanda剪枝(一次性,无需重新训练)
来源: ICLR 2024 (arXiv 2306.11695)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 校准数据(用于激活统计的小型数据集)
calib_data = [
"The quick brown fox jumps over the lazy dog.",
"Machine learning is transforming the world.",
"Artificial intelligence powers modern applications.",
]
# Wanda剪枝函数
def wanda_prune(model, calib_data, sparsity=0.5):
"""
Wanda: 通过权重幅度 × 输入激活进行剪枝。
参数:
sparsity: 要剪枝的权重比例(0.5 = 50%)
"""
# 1. 收集激活统计
activations = {}
def hook_fn(name):
def hook(module, input, output):
# 存储输入激活范数
activations[name] = input[0].detach().abs().mean(dim=0)
return hook
# 为所有线性层注册钩子
hooks = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
hooks.append(module.register_forward_hook(hook_fn(name)))
# 运行校准数据
model.eval()
with torch.no_grad():
for text in calib_data:
inputs = tokenizer(text, return_tensors="pt").to(model.device)
model(**inputs)
# 移除钩子
for hook in hooks:
hook.remove()
# 2. 基于 |权重| × 激活剪枝权重
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and name in activations:
W = module.weight.data
act = activations[name]
# 计算重要性: |权重| × 激活
importance = W.abs() * act.unsqueeze(0)
# 展平并找到阈值
threshold = torch.quantile(importance.flatten(), sparsity)
# 创建掩码
mask = importance >= threshold
# 应用掩码(剪枝)
W *= mask.float()
return model
# 应用Wanda剪枝(50%稀疏度,一次性,无需重新训练)
pruned_model = wanda_prune(model, calib_data, sparsity=0.5)
# 保存
pruned_model.save_pretrained("./llama-2-7b-wanda-50")
SparseGPT(二阶剪枝)
来源: arXiv 2301.00774
from sparsegpt import SparseGPT
# 加载模型
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# 初始化SparseGPT
pruner = SparseGPT(model)
# 校准数据
calib_data = load_calibration_data() # 约128个样本
# 剪枝(一次性,逐层重建)
pruned_model = pruner.prune(
calib_data=calib_data,
sparsity=0.5, # 50%稀疏度
prunen=0, # 非结构化(0)或N:M结构化
prunem=0,
percdamp=0.01, # Hessian逆的阻尼
)
# 结果: 50%稀疏度下近乎无损的剪枝
N:M结构化剪枝(硬件加速器)
def nm_prune(weight, n=2, m=4):
"""
N:M剪枝: 每M个连续权重中保留N个权重。
示例: 2:4 = 每4个权重中保留2个。
与NVIDIA稀疏张量核心兼容(2:4, 4:8)。
"""
# 将权重重塑为M组
shape = weight.shape
weight_flat = weight.flatten()
# 填充到M的倍数
pad_size = (m - weight_flat.numel() % m) % m
weight_padded = F.pad(weight_flat, (0, pad_size))
# 重塑为 (组数, m)
weight_grouped = weight_padded.reshape(-1, m)
# 在每组中找到前N个
_, indices = torch.topk(weight_grouped.abs(), n, dim=-1)
# 创建掩码
mask = torch.zeros_like(weight_grouped)
mask.scatter_(1, indices, 1.0)
# 应用掩码
weight_pruned = weight_grouped * mask
# 重塑回来
weight_pruned = weight_pruned.flatten()[:weight_flat.numel()]
return weight_pruned.reshape(shape)
# 应用2:4稀疏度(NVIDIA硬件)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
module.weight.data = nm_prune(module.weight.data, n=2, m=4)
# 50%稀疏度,在A100上使用稀疏张量核心实现2倍加速
核心概念
1. 剪枝标准
幅度剪枝(基线):
# 剪枝绝对值最小的权重
importance = weight.abs()
threshold = torch.quantile(importance, sparsity)
mask = importance >= threshold
Wanda(权重 × 激活):
# 重要性 = |权重| × 输入激活
importance = weight.abs() * activation
# 比单独使用幅度更好(考虑使用情况)
SparseGPT(二阶):
# 使用Hessian(二阶导数)计算重要性
# 更准确但计算成本高
importance = weight^2 / diag(Hessian)
2. 结构化与非结构化
非结构化(细粒度):
- 剪枝单个权重
- 更高质量(更好精度)
- 无硬件加速(不规则稀疏度)
结构化(粗粒度):
- 剪枝整个神经元、头或层
- 较低质量(更多精度损失)
- 硬件加速(规则稀疏度)
半结构化 (N:M):
- 两全其美
- 50%稀疏度(2:4)→ 在NVIDIA GPU上实现2倍加速
- 最小精度损失
3. 稀疏度模式
# 非结构化(随机)
# [1, 0, 1, 0, 1, 1, 0, 0]
# 优点: 灵活、高质量
# 缺点: 无加速
# 结构化(块)
# [1, 1, 0, 0, 1, 1, 0, 0]
# 优点: 硬件友好
# 缺点: 更多精度损失
# N:M(半结构化)
# [1, 0, 1, 0] [1, 1, 0, 0] (2:4模式)
# 优点: 硬件加速 + 良好质量
# 缺点: 需要特定硬件(NVIDIA)
剪枝策略
策略1: 渐进幅度剪枝
def gradual_prune(model, initial_sparsity=0.0, final_sparsity=0.5, num_steps=100):
"""在训练期间逐渐增加稀疏度。"""
for step in range(num_steps):
# 当前稀疏度
current_sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * (step / num_steps)
# 以当前稀疏度剪枝
for module in model.modules():
if isinstance(module, torch.nn.Linear):
weight = module.weight.data
threshold = torch.quantile(weight.abs().flatten(), current_sparsity)
mask = weight.abs() >= threshold
weight *= mask.float()
# 训练一步
train_step(model)
return model
策略2: 逐层剪枝
def layer_wise_prune(model, sparsity_per_layer):
"""不同层使用不同稀疏度。"""
# 早期层: 较少剪枝(更重要)
# 后期层: 较多剪枝(较不关键)
sparsity_schedule = {
"layer.0": 0.3, # 30%稀疏度
"layer.1": 0.4,
"layer.2": 0.5,
"layer.3": 0.6, # 60%稀疏度
}
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
# 找到层索引
for layer_name, sparsity in sparsity_schedule.items():
if layer_name in name:
# 以层特定稀疏度剪枝
prune_layer(module, sparsity)
break
return model
策略3: 迭代剪枝 + 微调
def iterative_prune_finetune(model, target_sparsity=0.5, iterations=5):
"""逐渐剪枝,在迭代之间进行微调。"""
current_sparsity = 0.0
sparsity_increment = target_sparsity / iterations
for i in range(iterations):
# 增加稀疏度
current_sparsity += sparsity_increment
# 剪枝
prune_model(model, sparsity=current_sparsity)
# 微调(恢复精度)
fine_tune(model, epochs=2, lr=1e-5)
return model
# 结果: 在高稀疏度下比一次性剪枝精度更好
生产部署
完整剪枝流程
from transformers import Trainer, TrainingArguments
def production_pruning_pipeline(
model_name="meta-llama/Llama-2-7b-hf",
target_sparsity=0.5,
method="wanda", # 或 "sparsegpt"
):
# 1. 加载模型
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 2. 加载校准数据
calib_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")
# 3. 应用剪枝
if method == "wanda":
pruned_model = wanda_prune(model, calib_dataset, sparsity=target_sparsity)
elif method == "sparsegpt":
pruner = SparseGPT(model)
pruned_model = pruner.prune(calib_dataset, sparsity=target_sparsity)
# 4. (可选)微调以恢复精度
training_args = TrainingArguments(
output_dir="./pruned-model",
num_train_epochs=1,
per_device_train_batch_size=4,
learning_rate=1e-5,
bf16=True,
)
trainer = Trainer(
model=pruned_model,
args=training_args,
train_dataset=finetune_dataset,
)
trainer.train()
# 5. 保存
pruned_model.save_pretrained("./pruned-llama-7b-50")
tokenizer.save_pretrained("./pruned-llama-7b-50")
return pruned_model
# 使用
pruned_model = production_pruning_pipeline(
model_name="meta-llama/Llama-2-7b-hf",
target_sparsity=0.5,
method="wanda"
)
评估
from lm_eval import evaluator
# 评估剪枝与原始模型
original_results = evaluator.simple_evaluate(
model="hf",
model_args="pretrained=meta-llama/Llama-2-7b-hf",
tasks=["arc_easy", "hellaswag", "winogrande"],
)
pruned_results = evaluator.simple_evaluate(
model="hf",
model_args="pretrained=./pruned-llama-7b-50",
tasks=["arc_easy", "hellaswag", "winogrande"],
)
# 比较
print(f"原始: {original_results['results']['arc_easy']['acc']:.3f}")
print(f"剪枝: {pruned_results['results']['arc_easy']['acc']:.3f}")
print(f"退化: {(original_results - pruned_results):.3f}")
# 50%稀疏度下的典型结果:
# - Wanda: 精度损失小于1%
# - SparseGPT: 精度损失小于0.5%
# - 幅度剪枝: 精度损失2-3%
最佳实践
1. 稀疏度选择
# 保守(安全)
sparsity = 0.3 # 30%,精度损失小于0.5%
# 平衡(推荐)
sparsity = 0.5 # 50%,精度损失约1%
# 激进(有风险)
sparsity = 0.7 # 70%,精度损失2-5%
# 极端(模型依赖)
sparsity = 0.9 # 90%,显著退化
2. 方法选择
# 一次性,无需重新训练 → Wanda或SparseGPT
if no_retraining_budget:
use_method = "wanda" # 更快
# 最佳质量 → SparseGPT
if need_best_quality:
use_method = "sparsegpt" # 更准确
# 硬件加速 → N:M结构化
if need_speedup:
use_method = "nm_prune" # 2:4或4:8
3. 避免常见陷阱
# ❌ 错误: 不使用校准数据进行剪枝
prune_random(model) # 无激活统计
# ✅ 良好: 使用校准数据
prune_wanda(model, calib_data)
# ❌ 错误: 一次性剪枝稀疏度过高
prune(model, sparsity=0.9) # 大量精度损失
# ✅ 良好: 渐进或迭代
iterative_prune(model, target=0.9, steps=10)
性能比较
50%稀疏度下的剪枝方法(LLaMA-7B):
| 方法 | 精度损失 | 速度 | 内存 | 是否需要重新训练 |
|---|---|---|---|---|
| 幅度剪枝 | -2.5% | 1.0× | -50% | 否 |
| Wanda | -0.8% | 1.0× | -50% | 否 |
| SparseGPT | -0.4% | 1.0× | -50% | 否 |
| N:M (2:4) | -1.0% | 2.0× | -50% | 否 |
| 结构化 | -3.0% | 2.0× | -50% | 否 |
来源: Wanda论文(ICLR 2024)、SparseGPT论文
资源
- Wanda论文(ICLR 2024): https://arxiv.org/abs/2306.11695
- Wanda GitHub: https://github.com/locuslab/wanda
- SparseGPT论文: https://arxiv.org/abs/2301.00774
- SparseGPT GitHub: https://github.com/IST-DASLab/sparsegpt
- NVIDIA稀疏张量核心: https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/