名称: pyvene-干预 描述: 提供使用pyvene的声明式干预框架对PyTorch模型进行因果干预的指导。用于进行因果追踪、激活修补、交换干预训练或测试模型行为的因果假设。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [因果干预, pyvene, 激活修补, 因果追踪, 可解释性] 依赖: [pyvene>=0.1.8, torch>=2.0.0, transformers>=4.30.0]
pyvene: 神经网络的因果干预
pyvene是斯坦福NLP的库,用于对PyTorch模型进行因果干预。它提供了一个声明式的、基于字典的框架,用于激活修补、因果追踪和交换干预训练——使干预实验可重现和可共享。
GitHub: stanfordnlp/pyvene (840+ stars) 论文: pyvene: 通过干预理解和改进PyTorch模型的库 (NAACL 2024)
何时使用pyvene
使用pyvene当您需要:
- 执行因果追踪(ROME风格定位)
- 运行激活修补实验
- 进行交换干预训练(IIT)
- 测试模型组件的因果假设
- 通过HuggingFace共享/重现干预实验
- 处理任何PyTorch架构(不仅仅是transformers)
考虑替代方案当:
- 您需要探索性激活分析 → 使用 TransformerLens
- 您想训练/分析SAEs → 使用 SAELens
- 您需要在大型模型上远程执行 → 使用 nnsight
- 您需要更低层次的控制 → 使用 nnsight
安装
pip install pyvene
标准导入:
import pyvene as pv
核心概念
IntervenableModel
主要类,包装任何PyTorch模型以启用干预能力:
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载基础模型
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 定义干预配置
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=8,
component="block_output",
intervention_type=pv.VanillaIntervention,
)
]
)
# 创建可干预模型
intervenable = pv.IntervenableModel(config, model)
干预类型
| 类型 | 描述 | 使用案例 |
|---|---|---|
VanillaIntervention |
在运行之间交换激活 | 激活修补 |
AdditionIntervention |
向基础运行添加激活 | 转向、消融 |
SubtractionIntervention |
减去激活 | 消融 |
ZeroIntervention |
将激活归零 | 组件敲除 |
RotatedSpaceIntervention |
DAS可训练干预 | 因果发现 |
CollectIntervention |
收集激活 | 探测、分析 |
组件目标
# 可干预的可用组件
components = [
"block_input", # transformer块输入
"block_output", # transformer块输出
"mlp_input", # MLP输入
"mlp_output", # MLP输出
"mlp_activation", # MLP隐藏激活
"attention_input", # 注意力输入
"attention_output", # 注意力输出
"attention_value_output", # 注意力值向量
"query_output", # 查询向量
"key_output", # 键向量
"value_output", # 值向量
"head_attention_value_output", # 每头值
]
工作流1: 因果追踪(ROME风格)
通过破坏输入和恢复激活来定位事实关联存储位置。
分步指南
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
# 1. 定义清洁和破坏的输入
clean_prompt = "The Space Needle is in downtown"
corrupted_prompt = "The ##### ###### ## ## ########" # 噪声
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")
# 2. 获取清洁激活(源)
with torch.no_grad():
clean_outputs = model(**clean_tokens, output_hidden_states=True)
clean_states = clean_outputs.hidden_states
# 3. 定义恢复干预
def run_causal_trace(layer, position):
"""在特定层和位置恢复清洁激活。"""
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=layer,
component="block_output",
intervention_type=pv.VanillaIntervention,
unit="pos",
max_number_of_units=1,
)
]
)
intervenable = pv.IntervenableModel(config, model)
# 运行干预
_, patched_outputs = intervenable(
base=corrupted_tokens,
sources=[clean_tokens],
unit_locations={"sources->base": ([[[position]]], [[[position]]])},
output_original_output=True,
)
# 返回正确token的概率
probs = torch.softmax(patched_outputs.logits[0, -1], dim=-1)
seattle_token = tokenizer.encode(" Seattle")[0]
return probs[seattle_token].item()
# 4. 扫描层和位置
n_layers = model.config.n_layer
seq_len = clean_tokens["input_ids"].shape[1]
results = torch.zeros(n_layers, seq_len)
for layer in range(n_layers):
for pos in range(seq_len):
results[layer, pos] = run_causal_trace(layer, pos)
# 5. 可视化(层 x 位置热图)
# 高值表示因果重要性
检查清单
- [ ] 准备带有目标事实关联的清洁提示
- [ ] 创建破坏版本(噪声或反事实)
- [ ] 为每个(层,位置)定义干预配置
- [ ] 运行修补扫描
- [ ] 在热图中识别因果热点
工作流2: 用于电路分析的激活修补
测试哪些组件对特定行为是必要的。
分步指南
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# IOI任务设置
clean_prompt = "When John and Mary went to the store, Mary gave a bottle to"
corrupted_prompt = "When John and Mary went to the store, John gave a bottle to"
clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")
john_token = tokenizer.encode(" John")[0]
mary_token = tokenizer.encode(" Mary")[0]
def logit_diff(logits):
"""IO - S logit差异。"""
return logits[0, -1, john_token] - logits[0, -1, mary_token]
# 在每层修补注意力输出
def patch_attention(layer):
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=layer,
component="attention_output",
intervention_type=pv.VanillaIntervention,
)
]
)
intervenable = pv.IntervenableModel(config, model)
_, patched_outputs = intervenable(
base=corrupted_tokens,
sources=[clean_tokens],
)
return logit_diff(patched_outputs.logits).item()
# 找出哪些层重要
results = []
for layer in range(model.config.n_layer):
diff = patch_attention(layer)
results.append(diff)
print(f"Layer {layer}: logit diff = {diff:.3f}")
工作流3: 交换干预训练(IIT)
训练干预以发现因果结构。
分步指南
import pyvene as pv
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("gpt2")
# 1. 定义可训练干预
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=6,
component="block_output",
intervention_type=pv.RotatedSpaceIntervention, # 可训练
low_rank_dimension=64, # 学习64维子空间
)
]
)
intervenable = pv.IntervenableModel(config, model)
# 2. 设置训练
optimizer = torch.optim.Adam(
intervenable.get_trainable_parameters(),
lr=1e-4
)
# 3. 训练循环(简化)
for base_input, source_input, target_output in dataloader:
optimizer.zero_grad()
_, outputs = intervenable(
base=base_input,
sources=[source_input],
)
loss = criterion(outputs.logits, target_output)
loss.backward()
optimizer.step()
# 4. 分析学习到的干预
# 旋转矩阵揭示因果子空间
rotation = intervenable.interventions["layer.6.block_output"][0].rotate_layer
DAS(分布式对齐搜索)
# 低秩旋转找到可解释的子空间
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=8,
component="block_output",
intervention_type=pv.LowRankRotatedSpaceIntervention,
low_rank_dimension=1, # 找到1D因果方向
)
]
)
工作流4: 模型转向(诚实LLaMA)
在生成期间转向模型行为。
import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 加载预训练的转向干预
intervenable = pv.IntervenableModel.load(
"zhengxuanzenwu/intervenable_honest_llama2_chat_7B",
model=model,
)
# 生成并转向
prompt = "Is the earth flat?"
inputs = tokenizer(prompt, return_tensors="pt")
# 干预在生成期间应用
outputs = intervenable.generate(
inputs,
max_new_tokens=100,
do_sample=False,
)
print(tokenizer.decode(outputs[0]))
保存和共享干预
# 本地保存
intervenable.save("./my_intervention")
# 从本地加载
intervenable = pv.IntervenableModel.load(
"./my_intervention",
model=model,
)
# 在HuggingFace上共享
intervenable.save_intervention("username/my-intervention")
# 从HuggingFace加载
intervenable = pv.IntervenableModel.load(
"username/my-intervention",
model=model,
)
常见问题及解决方案
问题: 错误的干预位置
# 错误: 不正确的组件名称
config = pv.RepresentationConfig(
component="mlp", # 无效!
)
# 正确: 使用确切的组件名称
config = pv.RepresentationConfig(
component="mlp_output", # 有效
)
问题: 维度不匹配
# 确保源和基有兼容的形状
# 对于位置特定的干预:
config = pv.RepresentationConfig(
unit="pos",
max_number_of_units=1, # 干预单个位置
)
# 明确指定位置
intervenable(
base=base_tokens,
sources=[source_tokens],
unit_locations={"sources->base": ([[[5]]], [[[5]]])}, # 位置5
)
问题: 大型模型的内存问题
# 使用梯度检查点
model.gradient_checkpointing_enable()
# 或在较少的组件上干预
config = pv.IntervenableConfig(
representations=[
pv.RepresentationConfig(
layer=8, # 单层而不是所有层
component="block_output",
)
]
)
问题: LoRA集成
# pyvene v0.1.8+ 支持LoRA作为干预
config = pv.RepresentationConfig(
intervention_type=pv.LoRAIntervention,
low_rank_dimension=16,
)
关键类参考
| 类 | 目的 |
|---|---|
IntervenableModel |
干预的主要包装器 |
IntervenableConfig |
配置容器 |
RepresentationConfig |
单个干预规范 |
VanillaIntervention |
激活交换 |
RotatedSpaceIntervention |
可训练的DAS干预 |
CollectIntervention |
激活收集 |
支持的模型
pyvene适用于任何PyTorch模型。测试于:
- GPT-2(所有尺寸)
- LLaMA / LLaMA-2
- Pythia
- Mistral / Mixtral
- OPT
- BLIP(视觉语言)
- ESM(蛋白质模型)
- Mamba(状态空间)
参考文档
详细的API文档、教程和高级用法,请参见references/文件夹:
| 文件 | 内容 |
|---|---|
| references/README.md | 概述和快速入门指南 |
| references/api.md | IntervenableModel、干预类型、配置的完整API参考 |
| references/tutorials.md | 因果追踪、激活修补、DAS的分步教程 |
外部资源
教程
论文
- 定位和编辑GPT中的事实关联 - Meng et al. (2022)
- 推理时间干预 - Li et al. (2023)
- 野外的可解释性 - Wang et al. (2022)
官方文档
与其他工具的比较
| 特性 | pyvene | TransformerLens | nnsight |
|---|---|---|---|
| 声明式配置 | 是 | 否 | 否 |
| HuggingFace共享 | 是 | 否 | 否 |
| 可训练干预 | 是 | 有限 | 是 |
| 任何PyTorch模型 | 是 | 仅transformers | 是 |
| 远程执行 | 否 | 否 | 是(NDIF) |