名称: transformer-lens-interpretability 描述: 提供使用TransformerLens进行机制解释性研究的指南,通过钩子点和激活缓存来检查和操作Transformer内部。当需要逆向工程模型算法、研究注意力模式或执行激活修补实验时使用。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [机制解释性, TransformerLens, 激活修补, 电路分析] 依赖: [transformer-lens>=2.0.0, torch>=2.0.0]
TransformerLens: Transformers的机制解释性
TransformerLens 是用于GPT风格语言模型机制解释性研究的事实上的标准库。由Neel Nanda创建,Bryce Meyer维护,它提供了干净的接口,通过每个激活上的钩子点来检查和操作模型内部。
GitHub: TransformerLensOrg/TransformerLens (2,900+ 星)
何时使用TransformerLens
当您需要时使用TransformerLens:
- 逆向工程训练期间学习的算法
- 执行激活修补/因果追踪实验
- 研究注意力模式和信息流
- 分析电路(例如,归纳头、IOI电路)
- 缓存和检查中间激活
- 应用直接logit归因
考虑替代方案当:
- 您需要处理非Transformer架构 → 使用 nnsight 或 pyvene
- 您想训练/分析稀疏自编码器 → 使用 SAELens
- 您需要在大型模型上远程执行 → 使用带有NDIF的 nnsight
- 您想要更高级的因果干预抽象 → 使用 pyvene
安装
pip install transformer-lens
对于开发版本:
pip install git+https://github.com/TransformerLensOrg/TransformerLens
核心概念
HookedTransformer
主要的类,包装Transformer模型,在每个激活上都有钩子点:
from transformer_lens import HookedTransformer
# 加载一个模型
model = HookedTransformer.from_pretrained("gpt2-small")
# 对于门控模型(LLaMA、Mistral)
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
支持模型(50+)
| 家族 | 模型 |
|---|---|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m 到 pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| 其他 | phi, qwen, opt, gemma |
激活缓存
运行模型并缓存所有中间激活:
# 获取所有激活
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)
# 访问特定激活
residual = cache["resid_post", 5] # 第5层残差流
attn_pattern = cache["pattern", 3] # 第3层注意力模式
mlp_out = cache["mlp_out", 7] # 第7层MLP输出
# 过滤要缓存的激活(节省内存)
logits, cache = model.run_with_cache(
tokens,
names_filter=lambda name: "resid_post" in name
)
ActivationCache 键
| 键模式 | 形状 | 描述 |
|---|---|---|
resid_pre, layer |
[batch, pos, d_model] | 注意力前的残差 |
resid_mid, layer |
[batch, pos, d_model] | 注意力后的残差 |
resid_post, layer |
[batch, pos, d_model] | MLP后的残差 |
attn_out, layer |
[batch, pos, d_model] | 注意力输出 |
mlp_out, layer |
[batch, pos, d_model] | MLP输出 |
pattern, layer |
[batch, head, q_pos, k_pos] | 注意力模式(后softmax) |
q, layer |
[batch, pos, head, d_head] | 查询向量 |
k, layer |
[batch, pos, head, d_head] | 键向量 |
v, layer |
[batch, pos, head, d_head] | 值向量 |
工作流1: 激活修补(因果追踪)
通过将干净激活修补到损坏运行中,识别哪些激活因果影响模型输出。
步骤
from transformer_lens import HookedTransformer, patching
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# 1. 定义干净和损坏提示
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
# 2. 获取干净激活
_, clean_cache = model.run_with_cache(clean_tokens)
# 3. 定义指标(例如,logit差异)
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")
def metric(logits):
return logits[0, -1, paris_token] - logits[0, -1, rome_token]
# 4. 修补每个位置和层
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])
for layer in range(model.cfg.n_layers):
for pos in range(clean_tokens.shape[1]):
def patch_hook(activation, hook):
activation[0, pos] = clean_cache[hook.name][0, pos]
return activation
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
)
results[layer, pos] = metric(patched_logits)
# 5. 可视化结果(层 x 位置热图)
清单
- [ ] 定义最小差异的干净和损坏输入
- [ ] 选择捕捉行为差异的指标
- [ ] 缓存干净激活
- [ ] 系统修补每个(层,位置)组合
- [ ] 将结果可视化为热图
- [ ] 识别因果热点
工作流2: 电路分析(间接对象识别)
复制“Interpretability in the Wild”中的IOI电路发现。
步骤
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# IOI任务:"When John and Mary went to the store, Mary gave a bottle to"
# 模型应预测"John"(间接对象)
prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)
# 1. 获取基准logits
logits, cache = model.run_with_cache(tokens)
john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")
# 2. 计算logit差异(IO - S)
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit差异: {logit_diff.item():.3f}")
# 3. 通过头直接logit归因
def get_head_contribution(layer, head):
# 将头输出投影到logits
head_out = cache["z", layer][0, :, head, :] # [pos, d_head]
W_O = model.W_O[layer, head] # [d_head, d_model]
W_U = model.W_U # [d_model, vocab]
# 头对最终位置logits的贡献
contribution = head_out[-1] @ W_O @ W_U
return contribution[john_token] - contribution[mary_token]
# 4. 映射所有头
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
for head in range(model.cfg.n_heads):
head_contributions[layer, head] = get_head_contribution(layer, head)
# 5. 识别顶级贡献头(名称移动器、备份名称移动器)
清单
- [ ] 设置具有清晰IO/S令牌的任务
- [ ] 计算基准logit差异
- [ ] 按注意力头贡献分解
- [ ] 识别关键电路组件(名称移动器、S抑制、归纳)
- [ ] 用消融实验验证
工作流3: 归纳头检测
找到实现[A][B]…[A] → [B]模式的归纳头。
from transformer_lens import HookedTransformer
import torch
model = HookedTransformer.from_pretrained("gpt2-small")
# 创建重复序列:[A][B][A] 应预测 [B]
repeated_tokens = torch.tensor([[1000, 2000, 1000]]) # 任意令牌
_, cache = model.run_with_cache(repeated_tokens)
# 归纳头从最后一个[A]关注到第一个[B]
# 检查从位置2到位置1的注意力
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
pattern = cache["pattern", layer][0] # [head, q_pos, k_pos]
# 从位置2到位置1的注意力
induction_scores[layer] = pattern[:, 2, 1]
# 高分头是归纳头
top_heads = torch.topk(induction_scores.flatten(), k=5)
常见问题与解决方案
问题: 调试后钩子持久
# 错误: 旧钩子仍活跃
model.run_with_hooks(tokens, fwd_hooks=[...]) # 调试,添加新钩子
model.run_with_hooks(tokens, fwd_hooks=[...]) # 旧钩子还在!
# 正确: 始终重置钩子
model.reset_hooks()
model.run_with_hooks(tokens, fwd_hooks=[...])
问题: 令牌化陷阱
# 错误: 假设一致令牌化
model.to_tokens("Tim") # 单个令牌
model.to_tokens("Neel") # 变成"Ne" + "el"(两个令牌!)
# 正确: 显式检查令牌化
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens)) # ['Ne', 'el']
问题: 分析中忽略LayerNorm
# 错误: 忽略LayerNorm
pre_activation = residual @ model.W_in[layer]
# 正确: 包含LayerNorm
ln_scale = model.blocks[layer].ln2.w
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer]
问题: 大型模型内存爆炸
# 使用选择性缓存
logits, cache = model.run_with_cache(
tokens,
names_filter=lambda n: "resid_post" in n or "pattern" in n,
device="cpu" # 在CPU上缓存
)
关键类参考
| 类 | 目的 |
|---|---|
HookedTransformer |
带有钩子的主模型包装器 |
ActivationCache |
类似字典的激活缓存 |
HookedTransformerConfig |
模型配置 |
FactoredMatrix |
高效因式矩阵操作 |
与SAELens集成
TransformerLens与SAELens集成,用于稀疏自编码器分析:
from transformer_lens import HookedTransformer
from sae_lens import SAE
model = HookedTransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
# 用SAE运行
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
sae_acts = sae.encode(cache["resid_pre", 8])
参考文档
有关详细API文档、教程和高级用法,请参阅references/文件夹:
| 文件 | 内容 |
|---|---|
| references/README.md | 概述和快速入门指南 |
| references/api.md | HookedTransformer、ActivationCache、HookPoints的完整API参考 |
| references/tutorials.md | 激活修补、电路分析、logit镜头的逐步教程 |
外部资源
教程
- 主演示笔记本
- 激活修补演示
- ARENA机械解释课程 - 200+小时教程
论文
官方文档
版本说明
- v2.0: 移除了HookedSAE(移至SAELens)
- v3.0 (alpha): TransformerBridge用于加载任何nn.Module