推测解码Skill speculative-decoding

这个技能用于加速大型语言模型(LLM)的推理过程,通过推测解码、Medusa多头和前瞻解码等技术,实现1.5到3.6倍的推理速度提升,降低延迟,适用于实时应用如聊天机器人、代码生成等。关键词包括:LLM推理加速、推测解码、Medusa、前瞻解码、并行生成、延迟优化、AI推理优化。

AI应用 0 次安装 0 次浏览 更新于 3/21/2026

名称: 推测解码 描述: 使用推测解码、Medusa多头和前瞻解码技术加速LLM推理。适用于优化推理速度(1.5-3.6倍加速)、降低实时应用延迟或部署计算资源有限的模型。涵盖草案模型、树状注意力、Jacobi迭代、并行令牌生成和生产部署策略。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [新兴技术, 推测解码, Medusa, 前瞻解码, 快速推理, 草案模型, 树注意力, 并行生成, 延迟降低, 推理优化] 依赖项: [transformers, torch]

推测解码:加速LLM推理

何时使用此技能

在以下情况使用推测解码:

  • 加速推理 1.5-3.6倍,无质量损失
  • 降低延迟 适用于实时应用(聊天机器人、代码生成)
  • 优化吞吐量 用于高流量服务
  • 高效部署 在有限硬件上
  • 更快生成 无需更改模型架构

关键技术:草案模型推测解码、Medusa(多头)、前瞻解码(Jacobi迭代)

论文:Medusa(arXiv 2401.10774)、前瞻解码(ICML 2024)、推测解码综述(ACL 2024)

安装

# 标准推测解码(transformers)
pip install transformers accelerate

# Medusa(多头解码)
git clone https://github.com/FasterDecoding/Medusa
cd Medusa
pip install -e .

# 前瞻解码
git clone https://github.com/hao-ai-lab/LookaheadDecoding
cd LookaheadDecoding
pip install -e .

# 可选:vLLM 带推测解码
pip install vllm

快速入门

基础推测解码(草案模型)

from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载目标模型(大、慢)
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    device_map="auto",
    torch_dtype=torch.float16
)

# 加载草案模型(小、快)
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")

# 使用推测解码生成
prompt = "用简单术语解释量子计算:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Transformers 4.36+ 支持辅助生成
outputs = target_model.generate(
    **inputs,
    assistant_model=draft_model,  # 启用推测解码
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

Medusa(多头解码)

from medusa.model.medusa_model import MedusaModel

# 加载Medusa增强模型
model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",  # 预训练带Medusa头
    torch_dtype=torch.float16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")

# 使用Medusa生成(2-3倍加速)
prompt = "写一个Python函数计算斐波那契数列:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.medusa_generate(
    **inputs,
    max_new_tokens=256,
    temperature=0.7,
    posterior_threshold=0.09,  # 接受阈值
    posterior_alpha=0.3,       # 树构建参数
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)

前瞻解码(Jacobi迭代)

from lookahead.lookahead_decoding import LookaheadDecoding

# 加载模型
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# 初始化前瞻解码
lookahead = LookaheadDecoding(
    model=model,
    tokenizer=tokenizer,
    window_size=15,    # 前瞻窗口(W)
    ngram_size=5,      # N-gram大小(N)
    guess_size=5       # 并行猜测数
)

# 生成(1.5-2.3倍加速)
prompt = "用Python实现快速排序:"
output = lookahead.generate(prompt, max_new_tokens=256)
print(output)

核心概念

1. 推测解码(草案模型)

想法:用小草案模型生成候选,大目标模型并行验证。

算法

  1. 草案模型推测生成K个令牌
  2. 目标模型并行评估所有K个令牌(单次前向传播)
  3. 接受草案和目标一致的令牌
  4. 拒绝首个不一致,从那里继续
def speculative_decode(target_model, draft_model, prompt, K=4):
    """推测解码算法。"""
    # 1. 生成K个草案令牌
    draft_tokens = draft_model.generate(prompt, max_new_tokens=K)

    # 2. 目标模型在一次前向传播中评估所有K个令牌
    target_logits = target_model(draft_tokens)  # 并行!

    # 3. 基于概率匹配接受/拒绝
    accepted = []
    for i in range(K):
        p_draft = softmax(draft_model.logits[i])
        p_target = softmax(target_logits[i])

        # 接受概率
        if random.random() < min(1, p_target[draft_tokens[i]] / p_draft[draft_tokens[i]]):
            accepted.append(draft_tokens[i])
        else:
            break  # 拒绝,从目标重新采样

    return accepted

性能

  • 加速:1.5-2倍,草案模型好时
  • 零质量损失(数学等价于目标模型)
  • 最佳时草案模型比目标小5-10倍

2. Medusa(多头解码)

来源:arXiv 2401.10774(2024)

创新:为现有模型添加多个预测头,预测未来令牌无需单独草案模型。

架构

输入 → 基础LLM(冻结) → 隐藏状态
                                ├→ 头1(预测令牌t+1)
                                ├→ 头2(预测令牌t+2)
                                ├→ 头3(预测令牌t+3)
                                └→ 头4(预测令牌t+4)

训练

  • Medusa-1:冻结基础LLM,仅训练头
    • 2.2倍加速,无损
  • Medusa-2:微调基础LLM和头一起
    • 2.3-3.6倍加速,更好质量

树状注意力

# Medusa构建候选树
# 示例:预测2步前,每步前2个

#         根
#        /    \
#      T1a    T1b  (步骤1:2个候选)
#     /  \    / \
#  T2a  T2b T2c T2d  (步骤2:总4个候选)

# 单次前向传播评估整个树!

优势

  • 无需单独草案模型
  • 最小训练(仅头)
  • 兼容任何LLM

3. 前瞻解码(Jacobi迭代)

来源:ICML 2024

核心想法:将自回归解码重新表述为求解方程组,使用Jacobi迭代并行求解。

数学公式

传统:  y_t = f(x, y_1, ..., y_{t-1})  (顺序)
Jacobi: y_t^{(k+1)} = f(x, y_1^{(k)}, ..., y_{t-1}^{(k)})  (并行)

两个分支

  1. 前瞻分支:并行生成n-gram

    • 窗口大小W:向前看多少步
    • N-gram大小N:使用多少过去令牌
  2. 验证分支:验证有前景的n-gram

    • 匹配n-gram与生成令牌
    • 如果首令牌匹配则接受
class LookaheadDecoding:
    def __init__(self, model, window_size=15, ngram_size=5):
        self.model = model
        self.W = window_size  # 前瞻窗口
        self.N = ngram_size   # N-gram大小

    def generate_step(self, tokens):
        # 前瞻分支:生成 W × N 候选
        candidates = {}
        for w in range(1, self.W + 1):
            for n in range(1, self.N + 1):
                # 从位置w开始生成长度为n的n-gram
                ngram = self.generate_ngram(tokens, start=w, length=n)
                candidates[(w, n)] = ngram

        # 验证分支:查找匹配的n-gram
        verified = []
        for ngram in candidates.values():
            if ngram[0] == tokens[-1]:  # 首令牌匹配最后输入
                if self.verify(tokens, ngram):
                    verified.append(ngram)

        # 接受最长验证的n-gram
        return max(verified, key=len) if verified else [self.model.generate_next(tokens)]

性能

  • 加速:1.5-2.3倍(代码生成可达3.6倍)
  • 无需草案模型或训练
  • 开箱即用,兼容任何模型

方法比较

方法 加速 需要训练 草案模型 质量损失
草案模型推测 1.5-2倍 是(外部)
Medusa 2-3.6倍 最小(仅头) 否(内置头)
前瞻解码 1.5-2.3倍
朴素批处理 1.2-1.5倍

高级模式

训练Medusa头

from medusa.model.medusa_model import MedusaModel
from medusa.model.kv_cache import initialize_past_key_values
import torch.nn as nn

# 1. 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.3",
    torch_dtype=torch.float16
)

# 2. 添加Medusa头
num_heads = 4
medusa_heads = nn.ModuleList([
    nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False)
    for _ in range(num_heads)
])

# 3. 训练循环(冻结基础模型为Medusa-1)
for param in base_model.parameters():
    param.requires_grad = False  # 冻结基础

optimizer = torch.optim.Adam(medusa_heads.parameters(), lr=1e-3)

for batch in dataloader:
    # 前向传播
    hidden_states = base_model(**batch, output_hidden_states=True).hidden_states[-1]

    # 用每个头预测未来令牌
    loss = 0
    for i, head in enumerate(medusa_heads):
        logits = head(hidden_states)
        # 目标:令牌偏移 (i+1) 位置
        target = batch['input_ids'][:, i+1:]
        loss += F.cross_entropy(logits[:, :-i-1], target)

    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

混合:推测解码 + Medusa

# 使用Medusa作为推测解码的草案模型
draft_medusa = MedusaModel.from_pretrained("medusa-vicuna-7b")
target_model = AutoModelForCausalLM.from_pretrained("vicuna-33b")

# 草案用Medusa生成多个候选
draft_tokens = draft_medusa.medusa_generate(prompt, max_new_tokens=5)

# 目标在单次前向传播中验证
outputs = target_model.generate(
    prompt,
    assistant_model=draft_medusa,  # 使用Medusa作为草案
    max_new_tokens=256
)

# 结合优势:Medusa速度 + 大模型质量

最优草案模型选择

def select_draft_model(target_model_size, target):
    """为推测解码选择最优草案模型。"""
    # 规则:草案应比目标小5-10倍
    if target_model_size == "70B":
        return "7B"  # 10倍小
    elif target_model_size == "33B":
        return "7B"  # 5倍小
    elif target_model_size == "13B":
        return "1B"  # 13倍小
    else:
        return None  # 目标太小,改用Medusa/前瞻解码

# 示例
draft = select_draft_model("70B", target_model)
# 返回 "7B" → 使用Llama-2-7b作为Llama-2-70b的草案

最佳实践

1. 选择正确方法

# 新部署 → Medusa(总体加速最佳,无草案模型)
if deploying_new_model:
    use_method = "Medusa"

# 现有部署有小模型可用 → 草案推测
elif have_small_version_of_model:
    use_method = "Draft Model Speculative"

# 想要零训练/设置 → 前瞻解码
elif want_plug_and_play:
    use_method = "Lookahead Decoding"

2. 超参数调优

草案模型推测

# K = 推测令牌数
K = 4  # 好默认
K = 2  # 保守(更高接受率)
K = 8  # 激进(接受率低,但接受时更多)

# 规则:更大K → 更多加速,如果草案模型好

Medusa

# 后验阈值(接受置信度)
posterior_threshold = 0.09  # 标准(来自论文)
posterior_threshold = 0.05  # 更保守(较慢,更高质量)
posterior_threshold = 0.15  # 更激进(更快,可能降低质量)

# 树深度(向前看多少步)
medusa_choices = [[0], [0, 0], [0, 1], [0, 0, 0]]  # 深度3(标准)

前瞻解码

# 窗口大小 W(前瞻距离)
# N-gram大小 N(生成上下文)

# 7B模型(更多资源)
W, N = 15, 5

# 13B模型(适中)
W, N = 10, 5

# 33B+模型(有限资源)
W, N = 7, 5

3. 生产部署

# vLLM 带推测解码
from vllm import LLM, SamplingParams

# 用草案模型初始化
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    speculative_model="meta-llama/Llama-2-7b-hf",  # 草案模型
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)

# 生成
prompts = ["告诉我关于AI:", "解释量子物理:"]
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)

资源

另请参阅

  • references/draft_model.md - 草案模型选择和训练
  • references/medusa.md - Medusa架构和训练
  • references/lookahead.md - 前瞻解码实现细节