Pyvene因果干预技能Skill pyvene-interventions

Pyvene因果干预技能是一种用于对PyTorch神经网络模型进行因果分析和干预的技术,支持因果追踪、激活修补、交换干预训练等方法,用于理解和改进模型行为,适用于大模型微调和AI应用场景。关键词:因果干预、PyTorch、神经网络、模型解释、激活修补、深度学习、AI工具。

AI应用 0 次安装 0 次浏览 更新于 3/21/2026

名称: pyvene-干预 描述: 提供使用pyvene的声明式干预框架对PyTorch模型进行因果干预的指导。用于进行因果追踪、激活修补、交换干预训练或测试模型行为的因果假设。 版本: 1.0.0 作者: Orchestra Research 许可证: MIT 标签: [因果干预, pyvene, 激活修补, 因果追踪, 可解释性] 依赖: [pyvene>=0.1.8, torch>=2.0.0, transformers>=4.30.0]

pyvene: 神经网络的因果干预

pyvene是斯坦福NLP的库,用于对PyTorch模型进行因果干预。它提供了一个声明式的、基于字典的框架,用于激活修补、因果追踪和交换干预训练——使干预实验可重现和可共享。

GitHub: stanfordnlp/pyvene (840+ stars) 论文: pyvene: 通过干预理解和改进PyTorch模型的库 (NAACL 2024)

何时使用pyvene

使用pyvene当您需要:

  • 执行因果追踪(ROME风格定位)
  • 运行激活修补实验
  • 进行交换干预训练(IIT)
  • 测试模型组件的因果假设
  • 通过HuggingFace共享/重现干预实验
  • 处理任何PyTorch架构(不仅仅是transformers)

考虑替代方案当:

  • 您需要探索性激活分析 → 使用 TransformerLens
  • 您想训练/分析SAEs → 使用 SAELens
  • 您需要在大型模型上远程执行 → 使用 nnsight
  • 您需要更低层次的控制 → 使用 nnsight

安装

pip install pyvene

标准导入:

import pyvene as pv

核心概念

IntervenableModel

主要类,包装任何PyTorch模型以启用干预能力:

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载基础模型
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# 定义干预配置
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,
            component="block_output",
            intervention_type=pv.VanillaIntervention,
        )
    ]
)

# 创建可干预模型
intervenable = pv.IntervenableModel(config, model)

干预类型

类型 描述 使用案例
VanillaIntervention 在运行之间交换激活 激活修补
AdditionIntervention 向基础运行添加激活 转向、消融
SubtractionIntervention 减去激活 消融
ZeroIntervention 将激活归零 组件敲除
RotatedSpaceIntervention DAS可训练干预 因果发现
CollectIntervention 收集激活 探测、分析

组件目标

# 可干预的可用组件
components = [
    "block_input",      # transformer块输入
    "block_output",     # transformer块输出
    "mlp_input",        # MLP输入
    "mlp_output",       # MLP输出
    "mlp_activation",   # MLP隐藏激活
    "attention_input",  # 注意力输入
    "attention_output", # 注意力输出
    "attention_value_output",  # 注意力值向量
    "query_output",     # 查询向量
    "key_output",       # 键向量
    "value_output",     # 值向量
    "head_attention_value_output",  # 每头值
]

工作流1: 因果追踪(ROME风格)

通过破坏输入和恢复激活来定位事实关联存储位置。

分步指南

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")

# 1. 定义清洁和破坏的输入
clean_prompt = "The Space Needle is in downtown"
corrupted_prompt = "The ##### ###### ## ## ########"  # 噪声

clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")

# 2. 获取清洁激活(源)
with torch.no_grad():
    clean_outputs = model(**clean_tokens, output_hidden_states=True)
    clean_states = clean_outputs.hidden_states

# 3. 定义恢复干预
def run_causal_trace(layer, position):
    """在特定层和位置恢复清洁激活。"""
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(
                layer=layer,
                component="block_output",
                intervention_type=pv.VanillaIntervention,
                unit="pos",
                max_number_of_units=1,
            )
        ]
    )

    intervenable = pv.IntervenableModel(config, model)

    # 运行干预
    _, patched_outputs = intervenable(
        base=corrupted_tokens,
        sources=[clean_tokens],
        unit_locations={"sources->base": ([[[position]]], [[[position]]])},
        output_original_output=True,
    )

    # 返回正确token的概率
    probs = torch.softmax(patched_outputs.logits[0, -1], dim=-1)
    seattle_token = tokenizer.encode(" Seattle")[0]
    return probs[seattle_token].item()

# 4. 扫描层和位置
n_layers = model.config.n_layer
seq_len = clean_tokens["input_ids"].shape[1]

results = torch.zeros(n_layers, seq_len)
for layer in range(n_layers):
    for pos in range(seq_len):
        results[layer, pos] = run_causal_trace(layer, pos)

# 5. 可视化(层 x 位置热图)
# 高值表示因果重要性

检查清单

  • [ ] 准备带有目标事实关联的清洁提示
  • [ ] 创建破坏版本(噪声或反事实)
  • [ ] 为每个(层,位置)定义干预配置
  • [ ] 运行修补扫描
  • [ ] 在热图中识别因果热点

工作流2: 用于电路分析的激活修补

测试哪些组件对特定行为是必要的。

分步指南

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# IOI任务设置
clean_prompt = "When John and Mary went to the store, Mary gave a bottle to"
corrupted_prompt = "When John and Mary went to the store, John gave a bottle to"

clean_tokens = tokenizer(clean_prompt, return_tensors="pt")
corrupted_tokens = tokenizer(corrupted_prompt, return_tensors="pt")

john_token = tokenizer.encode(" John")[0]
mary_token = tokenizer.encode(" Mary")[0]

def logit_diff(logits):
    """IO - S logit差异。"""
    return logits[0, -1, john_token] - logits[0, -1, mary_token]

# 在每层修补注意力输出
def patch_attention(layer):
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(
                layer=layer,
                component="attention_output",
                intervention_type=pv.VanillaIntervention,
            )
        ]
    )

    intervenable = pv.IntervenableModel(config, model)

    _, patched_outputs = intervenable(
        base=corrupted_tokens,
        sources=[clean_tokens],
    )

    return logit_diff(patched_outputs.logits).item()

# 找出哪些层重要
results = []
for layer in range(model.config.n_layer):
    diff = patch_attention(layer)
    results.append(diff)
    print(f"Layer {layer}: logit diff = {diff:.3f}")

工作流3: 交换干预训练(IIT)

训练干预以发现因果结构。

分步指南

import pyvene as pv
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")

# 1. 定义可训练干预
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=6,
            component="block_output",
            intervention_type=pv.RotatedSpaceIntervention,  # 可训练
            low_rank_dimension=64,  # 学习64维子空间
        )
    ]
)

intervenable = pv.IntervenableModel(config, model)

# 2. 设置训练
optimizer = torch.optim.Adam(
    intervenable.get_trainable_parameters(),
    lr=1e-4
)

# 3. 训练循环(简化)
for base_input, source_input, target_output in dataloader:
    optimizer.zero_grad()

    _, outputs = intervenable(
        base=base_input,
        sources=[source_input],
    )

    loss = criterion(outputs.logits, target_output)
    loss.backward()
    optimizer.step()

# 4. 分析学习到的干预
# 旋转矩阵揭示因果子空间
rotation = intervenable.interventions["layer.6.block_output"][0].rotate_layer

DAS(分布式对齐搜索)

# 低秩旋转找到可解释的子空间
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,
            component="block_output",
            intervention_type=pv.LowRankRotatedSpaceIntervention,
            low_rank_dimension=1,  # 找到1D因果方向
        )
    ]
)

工作流4: 模型转向(诚实LLaMA)

在生成期间转向模型行为。

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# 加载预训练的转向干预
intervenable = pv.IntervenableModel.load(
    "zhengxuanzenwu/intervenable_honest_llama2_chat_7B",
    model=model,
)

# 生成并转向
prompt = "Is the earth flat?"
inputs = tokenizer(prompt, return_tensors="pt")

# 干预在生成期间应用
outputs = intervenable.generate(
    inputs,
    max_new_tokens=100,
    do_sample=False,
)

print(tokenizer.decode(outputs[0]))

保存和共享干预

# 本地保存
intervenable.save("./my_intervention")

# 从本地加载
intervenable = pv.IntervenableModel.load(
    "./my_intervention",
    model=model,
)

# 在HuggingFace上共享
intervenable.save_intervention("username/my-intervention")

# 从HuggingFace加载
intervenable = pv.IntervenableModel.load(
    "username/my-intervention",
    model=model,
)

常见问题及解决方案

问题: 错误的干预位置

# 错误: 不正确的组件名称
config = pv.RepresentationConfig(
    component="mlp",  # 无效!
)

# 正确: 使用确切的组件名称
config = pv.RepresentationConfig(
    component="mlp_output",  # 有效
)

问题: 维度不匹配

# 确保源和基有兼容的形状
# 对于位置特定的干预:
config = pv.RepresentationConfig(
    unit="pos",
    max_number_of_units=1,  # 干预单个位置
)

# 明确指定位置
intervenable(
    base=base_tokens,
    sources=[source_tokens],
    unit_locations={"sources->base": ([[[5]]], [[[5]]])},  # 位置5
)

问题: 大型模型的内存问题

# 使用梯度检查点
model.gradient_checkpointing_enable()

# 或在较少的组件上干预
config = pv.IntervenableConfig(
    representations=[
        pv.RepresentationConfig(
            layer=8,  # 单层而不是所有层
            component="block_output",
        )
    ]
)

问题: LoRA集成

# pyvene v0.1.8+ 支持LoRA作为干预
config = pv.RepresentationConfig(
    intervention_type=pv.LoRAIntervention,
    low_rank_dimension=16,
)

关键类参考

目的
IntervenableModel 干预的主要包装器
IntervenableConfig 配置容器
RepresentationConfig 单个干预规范
VanillaIntervention 激活交换
RotatedSpaceIntervention 可训练的DAS干预
CollectIntervention 激活收集

支持的模型

pyvene适用于任何PyTorch模型。测试于:

  • GPT-2(所有尺寸)
  • LLaMA / LLaMA-2
  • Pythia
  • Mistral / Mixtral
  • OPT
  • BLIP(视觉语言)
  • ESM(蛋白质模型)
  • Mamba(状态空间)

参考文档

详细的API文档、教程和高级用法,请参见references/文件夹:

文件 内容
references/README.md 概述和快速入门指南
references/api.md IntervenableModel、干预类型、配置的完整API参考
references/tutorials.md 因果追踪、激活修补、DAS的分步教程

外部资源

教程

论文

官方文档

与其他工具的比较

特性 pyvene TransformerLens nnsight
声明式配置
HuggingFace共享
可训练干预 有限
任何PyTorch模型 仅transformers
远程执行 是(NDIF)