名称: 推测解码 描述: 使用推测解码、Medusa多头和前瞻解码技术加速LLM推理。适用于优化推理速度(1.5-3.6倍加速)、降低实时应用延迟或部署计算资源有限的模型。涵盖草案模型、树状注意力、Jacobi迭代、并行令牌生成和生产部署策略。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [新兴技术, 推测解码, Medusa, 前瞻解码, 快速推理, 草案模型, 树注意力, 并行生成, 延迟降低, 推理优化] 依赖项: [transformers, torch]
推测解码:加速LLM推理
何时使用此技能
在以下情况使用推测解码:
- 加速推理 1.5-3.6倍,无质量损失
- 降低延迟 适用于实时应用(聊天机器人、代码生成)
- 优化吞吐量 用于高流量服务
- 高效部署 在有限硬件上
- 更快生成 无需更改模型架构
关键技术:草案模型推测解码、Medusa(多头)、前瞻解码(Jacobi迭代)
论文:Medusa(arXiv 2401.10774)、前瞻解码(ICML 2024)、推测解码综述(ACL 2024)
安装
# 标准推测解码(transformers)
pip install transformers accelerate
# Medusa(多头解码)
git clone https://github.com/FasterDecoding/Medusa
cd Medusa
pip install -e .
# 前瞻解码
git clone https://github.com/hao-ai-lab/LookaheadDecoding
cd LookaheadDecoding
pip install -e .
# 可选:vLLM 带推测解码
pip install vllm
快速入门
基础推测解码(草案模型)
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载目标模型(大、慢)
target_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf",
device_map="auto",
torch_dtype=torch.float16
)
# 加载草案模型(小、快)
draft_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
device_map="auto",
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
# 使用推测解码生成
prompt = "用简单术语解释量子计算:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# Transformers 4.36+ 支持辅助生成
outputs = target_model.generate(
**inputs,
assistant_model=draft_model, # 启用推测解码
max_new_tokens=256,
do_sample=True,
temperature=0.7,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
Medusa(多头解码)
from medusa.model.medusa_model import MedusaModel
# 加载Medusa增强模型
model = MedusaModel.from_pretrained(
"FasterDecoding/medusa-vicuna-7b-v1.3", # 预训练带Medusa头
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")
# 使用Medusa生成(2-3倍加速)
prompt = "写一个Python函数计算斐波那契数列:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.medusa_generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
posterior_threshold=0.09, # 接受阈值
posterior_alpha=0.3, # 树构建参数
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
前瞻解码(Jacobi迭代)
from lookahead.lookahead_decoding import LookaheadDecoding
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 初始化前瞻解码
lookahead = LookaheadDecoding(
model=model,
tokenizer=tokenizer,
window_size=15, # 前瞻窗口(W)
ngram_size=5, # N-gram大小(N)
guess_size=5 # 并行猜测数
)
# 生成(1.5-2.3倍加速)
prompt = "用Python实现快速排序:"
output = lookahead.generate(prompt, max_new_tokens=256)
print(output)
核心概念
1. 推测解码(草案模型)
想法:用小草案模型生成候选,大目标模型并行验证。
算法:
- 草案模型推测生成K个令牌
- 目标模型并行评估所有K个令牌(单次前向传播)
- 接受草案和目标一致的令牌
- 拒绝首个不一致,从那里继续
def speculative_decode(target_model, draft_model, prompt, K=4):
"""推测解码算法。"""
# 1. 生成K个草案令牌
draft_tokens = draft_model.generate(prompt, max_new_tokens=K)
# 2. 目标模型在一次前向传播中评估所有K个令牌
target_logits = target_model(draft_tokens) # 并行!
# 3. 基于概率匹配接受/拒绝
accepted = []
for i in range(K):
p_draft = softmax(draft_model.logits[i])
p_target = softmax(target_logits[i])
# 接受概率
if random.random() < min(1, p_target[draft_tokens[i]] / p_draft[draft_tokens[i]]):
accepted.append(draft_tokens[i])
else:
break # 拒绝,从目标重新采样
return accepted
性能:
- 加速:1.5-2倍,草案模型好时
- 零质量损失(数学等价于目标模型)
- 最佳时草案模型比目标小5-10倍
2. Medusa(多头解码)
来源:arXiv 2401.10774(2024)
创新:为现有模型添加多个预测头,预测未来令牌无需单独草案模型。
架构:
输入 → 基础LLM(冻结) → 隐藏状态
├→ 头1(预测令牌t+1)
├→ 头2(预测令牌t+2)
├→ 头3(预测令牌t+3)
└→ 头4(预测令牌t+4)
训练:
- Medusa-1:冻结基础LLM,仅训练头
- 2.2倍加速,无损
- Medusa-2:微调基础LLM和头一起
- 2.3-3.6倍加速,更好质量
树状注意力:
# Medusa构建候选树
# 示例:预测2步前,每步前2个
# 根
# / \
# T1a T1b (步骤1:2个候选)
# / \ / \
# T2a T2b T2c T2d (步骤2:总4个候选)
# 单次前向传播评估整个树!
优势:
- 无需单独草案模型
- 最小训练(仅头)
- 兼容任何LLM
3. 前瞻解码(Jacobi迭代)
来源:ICML 2024
核心想法:将自回归解码重新表述为求解方程组,使用Jacobi迭代并行求解。
数学公式:
传统: y_t = f(x, y_1, ..., y_{t-1}) (顺序)
Jacobi: y_t^{(k+1)} = f(x, y_1^{(k)}, ..., y_{t-1}^{(k)}) (并行)
两个分支:
-
前瞻分支:并行生成n-gram
- 窗口大小W:向前看多少步
- N-gram大小N:使用多少过去令牌
-
验证分支:验证有前景的n-gram
- 匹配n-gram与生成令牌
- 如果首令牌匹配则接受
class LookaheadDecoding:
def __init__(self, model, window_size=15, ngram_size=5):
self.model = model
self.W = window_size # 前瞻窗口
self.N = ngram_size # N-gram大小
def generate_step(self, tokens):
# 前瞻分支:生成 W × N 候选
candidates = {}
for w in range(1, self.W + 1):
for n in range(1, self.N + 1):
# 从位置w开始生成长度为n的n-gram
ngram = self.generate_ngram(tokens, start=w, length=n)
candidates[(w, n)] = ngram
# 验证分支:查找匹配的n-gram
verified = []
for ngram in candidates.values():
if ngram[0] == tokens[-1]: # 首令牌匹配最后输入
if self.verify(tokens, ngram):
verified.append(ngram)
# 接受最长验证的n-gram
return max(verified, key=len) if verified else [self.model.generate_next(tokens)]
性能:
- 加速:1.5-2.3倍(代码生成可达3.6倍)
- 无需草案模型或训练
- 开箱即用,兼容任何模型
方法比较
| 方法 | 加速 | 需要训练 | 草案模型 | 质量损失 |
|---|---|---|---|---|
| 草案模型推测 | 1.5-2倍 | 否 | 是(外部) | 无 |
| Medusa | 2-3.6倍 | 最小(仅头) | 否(内置头) | 无 |
| 前瞻解码 | 1.5-2.3倍 | 无 | 否 | 无 |
| 朴素批处理 | 1.2-1.5倍 | 无 | 否 | 无 |
高级模式
训练Medusa头
from medusa.model.medusa_model import MedusaModel
from medusa.model.kv_cache import initialize_past_key_values
import torch.nn as nn
# 1. 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(
"lmsys/vicuna-7b-v1.3",
torch_dtype=torch.float16
)
# 2. 添加Medusa头
num_heads = 4
medusa_heads = nn.ModuleList([
nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False)
for _ in range(num_heads)
])
# 3. 训练循环(冻结基础模型为Medusa-1)
for param in base_model.parameters():
param.requires_grad = False # 冻结基础
optimizer = torch.optim.Adam(medusa_heads.parameters(), lr=1e-3)
for batch in dataloader:
# 前向传播
hidden_states = base_model(**batch, output_hidden_states=True).hidden_states[-1]
# 用每个头预测未来令牌
loss = 0
for i, head in enumerate(medusa_heads):
logits = head(hidden_states)
# 目标:令牌偏移 (i+1) 位置
target = batch['input_ids'][:, i+1:]
loss += F.cross_entropy(logits[:, :-i-1], target)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
混合:推测解码 + Medusa
# 使用Medusa作为推测解码的草案模型
draft_medusa = MedusaModel.from_pretrained("medusa-vicuna-7b")
target_model = AutoModelForCausalLM.from_pretrained("vicuna-33b")
# 草案用Medusa生成多个候选
draft_tokens = draft_medusa.medusa_generate(prompt, max_new_tokens=5)
# 目标在单次前向传播中验证
outputs = target_model.generate(
prompt,
assistant_model=draft_medusa, # 使用Medusa作为草案
max_new_tokens=256
)
# 结合优势:Medusa速度 + 大模型质量
最优草案模型选择
def select_draft_model(target_model_size, target):
"""为推测解码选择最优草案模型。"""
# 规则:草案应比目标小5-10倍
if target_model_size == "70B":
return "7B" # 10倍小
elif target_model_size == "33B":
return "7B" # 5倍小
elif target_model_size == "13B":
return "1B" # 13倍小
else:
return None # 目标太小,改用Medusa/前瞻解码
# 示例
draft = select_draft_model("70B", target_model)
# 返回 "7B" → 使用Llama-2-7b作为Llama-2-70b的草案
最佳实践
1. 选择正确方法
# 新部署 → Medusa(总体加速最佳,无草案模型)
if deploying_new_model:
use_method = "Medusa"
# 现有部署有小模型可用 → 草案推测
elif have_small_version_of_model:
use_method = "Draft Model Speculative"
# 想要零训练/设置 → 前瞻解码
elif want_plug_and_play:
use_method = "Lookahead Decoding"
2. 超参数调优
草案模型推测:
# K = 推测令牌数
K = 4 # 好默认
K = 2 # 保守(更高接受率)
K = 8 # 激进(接受率低,但接受时更多)
# 规则:更大K → 更多加速,如果草案模型好
Medusa:
# 后验阈值(接受置信度)
posterior_threshold = 0.09 # 标准(来自论文)
posterior_threshold = 0.05 # 更保守(较慢,更高质量)
posterior_threshold = 0.15 # 更激进(更快,可能降低质量)
# 树深度(向前看多少步)
medusa_choices = [[0], [0, 0], [0, 1], [0, 0, 0]] # 深度3(标准)
前瞻解码:
# 窗口大小 W(前瞻距离)
# N-gram大小 N(生成上下文)
# 7B模型(更多资源)
W, N = 15, 5
# 13B模型(适中)
W, N = 10, 5
# 33B+模型(有限资源)
W, N = 7, 5
3. 生产部署
# vLLM 带推测解码
from vllm import LLM, SamplingParams
# 用草案模型初始化
llm = LLM(
model="meta-llama/Llama-2-70b-hf",
speculative_model="meta-llama/Llama-2-7b-hf", # 草案模型
num_speculative_tokens=5,
use_v2_block_manager=True,
)
# 生成
prompts = ["告诉我关于AI:", "解释量子物理:"]
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.outputs[0].text)
资源
- Medusa论文:https://arxiv.org/abs/2401.10774
- Medusa GitHub:https://github.com/FasterDecoding/Medusa
- 前瞻解码(ICML 2024):https://lmsys.org/blog/2023-11-21-lookahead-decoding/
- 前瞻解码 GitHub:https://github.com/hao-ai-lab/LookaheadDecoding
- 推测解码综述(ACL 2024):https://aclanthology.org/2024.findings-acl.456.pdf
- 综合综述:https://arxiv.org/abs/2401.07851
另请参阅
references/draft_model.md- 草案模型选择和训练references/medusa.md- Medusa架构和训练references/lookahead.md- 前瞻解码实现细节