TransformerLens机制解释性库Skill transformer-lens-interpretability

TransformerLens 是一个用于Transformer模型机制解释性研究的Python库,允许用户通过钩子点和激活缓存来检查和操作模型内部状态,支持激活修补、电路分析、注意力模式研究等功能,常用于AI模型的解释性分析,特别是在自然语言处理领域。关键词:机械解释性、Transformer模型、激活缓存、钩子点、电路分析、AI解释性、深度学习、NLP。

NLP 0 次安装 0 次浏览 更新于 3/21/2026

名称: transformer-lens-interpretability 描述: 提供使用TransformerLens进行机制解释性研究的指南,通过钩子点和激活缓存来检查和操作Transformer内部。当需要逆向工程模型算法、研究注意力模式或执行激活修补实验时使用。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [机制解释性, TransformerLens, 激活修补, 电路分析] 依赖: [transformer-lens>=2.0.0, torch>=2.0.0]

TransformerLens: Transformers的机制解释性

TransformerLens 是用于GPT风格语言模型机制解释性研究的事实上的标准库。由Neel Nanda创建,Bryce Meyer维护,它提供了干净的接口,通过每个激活上的钩子点来检查和操作模型内部。

GitHub: TransformerLensOrg/TransformerLens (2,900+ 星)

何时使用TransformerLens

当您需要时使用TransformerLens:

  • 逆向工程训练期间学习的算法
  • 执行激活修补/因果追踪实验
  • 研究注意力模式和信息流
  • 分析电路(例如,归纳头、IOI电路)
  • 缓存和检查中间激活
  • 应用直接logit归因

考虑替代方案当:

  • 您需要处理非Transformer架构 → 使用 nnsightpyvene
  • 您想训练/分析稀疏自编码器 → 使用 SAELens
  • 您需要在大型模型上远程执行 → 使用带有NDIF的 nnsight
  • 您想要更高级的因果干预抽象 → 使用 pyvene

安装

pip install transformer-lens

对于开发版本:

pip install git+https://github.com/TransformerLensOrg/TransformerLens

核心概念

HookedTransformer

主要的类,包装Transformer模型,在每个激活上都有钩子点:

from transformer_lens import HookedTransformer

# 加载一个模型
model = HookedTransformer.from_pretrained("gpt2-small")

# 对于门控模型(LLaMA、Mistral)
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")

支持模型(50+)

家族 模型
GPT-2 gpt2, gpt2-medium, gpt2-large, gpt2-xl
LLaMA llama-7b, llama-13b, llama-2-7b, llama-2-13b
EleutherAI pythia-70m 到 pythia-12b, gpt-neo, gpt-j-6b
Mistral mistral-7b, mixtral-8x7b
其他 phi, qwen, opt, gemma

激活缓存

运行模型并缓存所有中间激活:

# 获取所有激活
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)

# 访问特定激活
residual = cache["resid_post", 5]  # 第5层残差流
attn_pattern = cache["pattern", 3]  # 第3层注意力模式
mlp_out = cache["mlp_out", 7]  # 第7层MLP输出

# 过滤要缓存的激活(节省内存)
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: "resid_post" in name
)

ActivationCache 键

键模式 形状 描述
resid_pre, layer [batch, pos, d_model] 注意力前的残差
resid_mid, layer [batch, pos, d_model] 注意力后的残差
resid_post, layer [batch, pos, d_model] MLP后的残差
attn_out, layer [batch, pos, d_model] 注意力输出
mlp_out, layer [batch, pos, d_model] MLP输出
pattern, layer [batch, head, q_pos, k_pos] 注意力模式(后softmax)
q, layer [batch, pos, head, d_head] 查询向量
k, layer [batch, pos, head, d_head] 键向量
v, layer [batch, pos, head, d_head] 值向量

工作流1: 激活修补(因果追踪)

通过将干净激活修补到损坏运行中,识别哪些激活因果影响模型输出。

步骤

from transformer_lens import HookedTransformer, patching
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# 1. 定义干净和损坏提示
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# 2. 获取干净激活
_, clean_cache = model.run_with_cache(clean_tokens)

# 3. 定义指标(例如,logit差异)
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")

def metric(logits):
    return logits[0, -1, paris_token] - logits[0, -1, rome_token]

# 4. 修补每个位置和层
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])

for layer in range(model.cfg.n_layers):
    for pos in range(clean_tokens.shape[1]):
        def patch_hook(activation, hook):
            activation[0, pos] = clean_cache[hook.name][0, pos]
            return activation

        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
        )
        results[layer, pos] = metric(patched_logits)

# 5. 可视化结果(层 x 位置热图)

清单

  • [ ] 定义最小差异的干净和损坏输入
  • [ ] 选择捕捉行为差异的指标
  • [ ] 缓存干净激活
  • [ ] 系统修补每个(层,位置)组合
  • [ ] 将结果可视化为热图
  • [ ] 识别因果热点

工作流2: 电路分析(间接对象识别)

复制“Interpretability in the Wild”中的IOI电路发现。

步骤

from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# IOI任务:"When John and Mary went to the store, Mary gave a bottle to"
# 模型应预测"John"(间接对象)

prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)

# 1. 获取基准logits
logits, cache = model.run_with_cache(tokens)

john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")

# 2. 计算logit差异(IO - S)
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit差异: {logit_diff.item():.3f}")

# 3. 通过头直接logit归因
def get_head_contribution(layer, head):
    # 将头输出投影到logits
    head_out = cache["z", layer][0, :, head, :]  # [pos, d_head]
    W_O = model.W_O[layer, head]  # [d_head, d_model]
    W_U = model.W_U  # [d_model, vocab]

    # 头对最终位置logits的贡献
    contribution = head_out[-1] @ W_O @ W_U
    return contribution[john_token] - contribution[mary_token]

# 4. 映射所有头
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        head_contributions[layer, head] = get_head_contribution(layer, head)

# 5. 识别顶级贡献头(名称移动器、备份名称移动器)

清单

  • [ ] 设置具有清晰IO/S令牌的任务
  • [ ] 计算基准logit差异
  • [ ] 按注意力头贡献分解
  • [ ] 识别关键电路组件(名称移动器、S抑制、归纳)
  • [ ] 用消融实验验证

工作流3: 归纳头检测

找到实现[A][B]…[A] → [B]模式的归纳头。

from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# 创建重复序列:[A][B][A] 应预测 [B]
repeated_tokens = torch.tensor([[1000, 2000, 1000]])  # 任意令牌

_, cache = model.run_with_cache(repeated_tokens)

# 归纳头从最后一个[A]关注到第一个[B]
# 检查从位置2到位置1的注意力
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]  # [head, q_pos, k_pos]
    # 从位置2到位置1的注意力
    induction_scores[layer] = pattern[:, 2, 1]

# 高分头是归纳头
top_heads = torch.topk(induction_scores.flatten(), k=5)

常见问题与解决方案

问题: 调试后钩子持久

# 错误: 旧钩子仍活跃
model.run_with_hooks(tokens, fwd_hooks=[...])  # 调试,添加新钩子
model.run_with_hooks(tokens, fwd_hooks=[...])  # 旧钩子还在!

# 正确: 始终重置钩子
model.reset_hooks()
model.run_with_hooks(tokens, fwd_hooks=[...])

问题: 令牌化陷阱

# 错误: 假设一致令牌化
model.to_tokens("Tim")  # 单个令牌
model.to_tokens("Neel")  # 变成"Ne" + "el"(两个令牌!)

# 正确: 显式检查令牌化
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens))  # ['Ne', 'el']

问题: 分析中忽略LayerNorm

# 错误: 忽略LayerNorm
pre_activation = residual @ model.W_in[layer]

# 正确: 包含LayerNorm
ln_scale = model.blocks[layer].ln2.w
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer]

问题: 大型模型内存爆炸

# 使用选择性缓存
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda n: "resid_post" in n or "pattern" in n,
    device="cpu"  # 在CPU上缓存
)

关键类参考

目的
HookedTransformer 带有钩子的主模型包装器
ActivationCache 类似字典的激活缓存
HookedTransformerConfig 模型配置
FactoredMatrix 高效因式矩阵操作

与SAELens集成

TransformerLens与SAELens集成,用于稀疏自编码器分析:

from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# 用SAE运行
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
sae_acts = sae.encode(cache["resid_pre", 8])

参考文档

有关详细API文档、教程和高级用法,请参阅references/文件夹:

文件 内容
references/README.md 概述和快速入门指南
references/api.md HookedTransformer、ActivationCache、HookPoints的完整API参考
references/tutorials.md 激活修补、电路分析、logit镜头的逐步教程

外部资源

教程

论文

官方文档

版本说明

  • v2.0: 移除了HookedSAE(移至SAELens)
  • v3.0 (alpha): TransformerBridge用于加载任何nn.Module