name: mamba-architecture description: 状态空间模型,复杂度为O(n),而Transformer为O(n²)。推理速度快5倍,支持百万令牌序列,无需KV缓存。选择性SSM,硬件感知设计。Mamba-1 (d_state=16) 和 Mamba-2 (d_state=128, 多头部)。模型130M-2.8B在HuggingFace上可用。 version: 1.0.0 author: Orchestra Research license: MIT tags: [模型架构, Mamba, 状态空间模型, SSM, 线性复杂度, 长上下文, 高效推理, 硬件感知, Transformer替代品] dependencies: [mamba-ssm, torch, transformers, causal-conv1d]
Mamba - 选择性状态空间模型
快速开始
Mamba是一种状态空间模型架构,实现序列建模的O(n)线性复杂度。
安装:
# 安装 causal-conv1d(可选,用于效率)
pip install causal-conv1d>=1.4.0
# 安装 Mamba
pip install mamba-ssm
# 或两者一起安装
pip install mamba-ssm[causal-conv1d]
先决条件:Linux, NVIDIA GPU, PyTorch 1.12+, CUDA 11.6+
基本用法(Mamba块):
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
d_model=dim, # 模型维度
d_state=16, # SSM状态维度
d_conv=4, # Conv1d核大小
expand=2 # 扩展因子
).to("cuda")
y = model(x) # O(n)复杂度!
assert y.shape == x.shape
常见工作流
工作流1:使用Mamba-2的语言模型
完整的LM与生成:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
import torch
# 配置Mamba-2 LM
config = MambaConfig(
d_model=1024, # 隐藏维度
n_layer=24, # 层数
vocab_size=50277, # 词汇表大小
ssm_cfg=dict(
layer="Mamba2", # 使用Mamba-2
d_state=128, # Mamba-2的更大状态
headdim=64, # 头部维度
ngroups=1 # 组数
)
)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
# 生成文本
input_ids = torch.randint(0, 1000, (1, 20), device="cuda", dtype=torch.long)
output = model.generate(
input_ids=input_ids,
max_length=100,
temperature=0.7,
top_p=0.9
)
工作流2:使用预训练的Mamba模型
从HuggingFace加载:
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# 加载预训练模型
model_name = "state-spaces/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") # 使用兼容的tokenizer
model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)
# 生成
prompt = "The future of AI is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
output_ids = model.generate(
input_ids=input_ids,
max_length=200,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2
)
generated_text = tokenizer.decode(output_ids[0])
print(generated_text)
可用模型:
state-spaces/mamba-130mstate-spaces/mamba-370mstate-spaces/mamba-790mstate-spaces/mamba-1.4bstate-spaces/mamba-2.8b
工作流3:Mamba-1 vs Mamba-2
Mamba-1(较小状态):
from mamba_ssm import Mamba
model = Mamba(
d_model=256,
d_state=16, # 较小状态维度
d_conv=4,
expand=2
).to("cuda")
Mamba-2(多头部,较大状态):
from mamba_ssm import Mamba2
model = Mamba2(
d_model=256,
d_state=128, # 较大状态维度
d_conv=4,
expand=2,
headdim=64, # 多头部的头部维度
ngroups=1 # 并行组
).to("cuda")
关键区别:
- 状态大小:Mamba-1 (d_state=16) vs Mamba-2 (d_state=128)
- 架构:Mamba-2具有多头部结构
- 归一化:Mamba-2使用RMSNorm
- 分布式:Mamba-2支持张量并行
工作流4:基准测试 vs Transformers
生成速度比较:
# 基准测试Mamba
python benchmarks/benchmark_generation_mamba_simple.py \
--model-name "state-spaces/mamba-2.8b" \
--prompt "The future of machine learning is" \
--topp 0.9 --temperature 0.7 --repetition-penalty 1.2
# 基准测试Transformer
python benchmarks/benchmark_generation_mamba_simple.py \
--model-name "EleutherAI/pythia-2.8b" \
--prompt "The future of machine learning is" \
--topp 0.9 --temperature 0.7 --repetition-penalty 1.2
预期结果:
- Mamba:推理速度快5倍
- 内存:无需KV缓存
- 扩展性:与序列长度线性相关
何时使用 vs 替代方案
使用Mamba当:
- 需要长序列(100K+令牌)
- 希望推理速度比Transformer快
- 内存受限(无KV缓存)
- 构建流式应用
- 线性扩展重要
优势:
- O(n)复杂度:线性 vs 二次
- 5倍更快推理:无注意力开销
- 无KV缓存:内存使用低
- 百万令牌序列:硬件高效
- 流式:每个令牌恒定内存
使用替代方案代替:
- Transformers:需要最佳性能,有计算资源
- RWKV:想要RNN+Transformer混合
- RetNet:需要基于保留的架构
- Hyena:想要基于卷积的方法
常见问题
问题:CUDA内存不足
减小批大小或使用梯度检查点:
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
model.gradient_checkpointing_enable() # 启用检查点
问题:安装慢
安装二进制轮子(非源码):
pip install mamba-ssm --no-build-isolation
问题:缺少causal-conv1d
单独安装:
pip install causal-conv1d>=1.4.0
问题:模型无法从HuggingFace加载
使用MambaLMHeadModel.from_pretrained(非AutoModel):
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")
高级主题
选择性SSM:参见references/selective-ssm.md了解数学公式、状态空间方程以及选择性如何实现O(n)复杂度。
Mamba-2架构:参见references/mamba2-details.md了解多头部结构、张量并行和分布式训练设置。
性能优化:参见references/performance.md了解硬件感知设计、CUDA内核和内存效率技术。
硬件要求
- GPU:NVIDIA with CUDA 11.6+
- VRAM:
- 130M模型:2GB
- 370M模型:4GB
- 790M模型:8GB
- 1.4B模型:14GB
- 2.8B模型:28GB (FP16)
- 推理:比Transformer快5倍
- 内存:无KV缓存(比Transformer低)
性能(vs Transformers):
- 速度:推理速度快5倍
- 内存:减少50%(无KV缓存)
- 扩展性:线性 vs 二次
资源
- 论文 (Mamba-1): https://arxiv.org/abs/2312.00752 (Dec 2023)
- 论文 (Mamba-2): https://arxiv.org/abs/2405.21060 (May 2024)
- GitHub: https://github.com/state-spaces/mamba ⭐ 13,000+
- 模型: https://huggingface.co/state-spaces
- 文档: 仓库README和wiki