Sparse Feature Circuits: 发现并编辑语言模型中可解释的因果图
摘要
我们介绍了发现和应用稀疏特征电路的方法。这些电路是因果相关的人类可解释特征的子网络,用于解释语言模型的行为。之前工作中识别的电路由多义性和难以解释的单位(如注意力头或神经元)组成,因此不适合许多下游应用。相比之下,稀疏特征电路能够深入理解神经网络中未预料的机制。由于它们基于细粒度单元,稀疏特征电路对下游任务非常有用:我们提出了 SHIFT,通过去掉人类认为与任务无关的特征来改善分类器的泛化能力。最后,我们通过为自动发现的模型行为发现数千个稀疏特征电路,展示了一个完全无监督和可扩展的可解释性流程。
1 引言
可解释性研究的关键挑战在于以可扩展的方式解释神经网络(NNs)许多意想不到的行为。最近的许多研究是通过粗粒度模型组件来解释神经网络的行为,例如通过涉及某些归纳头在上下文学习中的作用或在事实回忆中的MLP模块。然而,这些组件通常是多义的,且难以解释,使得将机制上的见解应用于下游任务变得困难。另一方面,先前基于细粒度单元分析行为的方法试图使用研究者设计的数据将模型内部与研究者指定的机制假说相匹配。这些方法并不适合许多情况下,研究者无法预先预测模型如何内部实现其惊人行为。
我们提出使用扮演狭窄、可解释角色的细粒度组件来解释模型行为。为此,我们需要解决两个挑战:首先,我们必须识别适当的细粒度分析单元,因为像神经元这样的明显选择往往难以解释,而通过线性探测等监督方法发现的单元则需预先存在的假设。其次,我们必须解决在大量细粒度单元中搜索因果电路所带来的可扩展性问题。
我们利用在神经网络可解释性方面的词典学习的最近进展来解决第一个挑战。具体而言,我们使用稀疏自编码器(SAEs)来识别表示人类可解释概念的语言模型潜在空间的方向。然后,为了应对可扩展性挑战,我们采用线性近似,以高效识别与模型行为最具因果关联的 SAE 特征,以及这些特征之间的连接。最终,我们得到了一个稀疏特征电路,以解释模型行为是如何通过细粒度的人类可解释单元之间的交互而产生的。
稀疏特征电路可以在下游应用中有效使用。我们介绍了一种技术,称为稀疏人类可解释特征修剪(SHIFT; §4),通过外科手术般地去除对非预期信号的敏感性,从而改变语言模型分类器的泛化能力。与之前关于虚假线索去除的工作不同——该工作利用消歧的数据隔离出虚假信号——SHIFT 利用可解释性和人类判断来识别非预期信号。因此,我们通过在最坏情况下去偏见分类器来展示 SHIFT,在这种情况下,非预期信号(性别)与目标标签(职业)有着完美的预测关系。
最后,我们通过 Michaud 等人 (2023) 提出的聚类方法,自动发现数千种狭窄的语言模型行为(例如,将“to”预测为不定式宾语或在日期中预测逗号),并随后自动发现这些行为的特征电路,展示了我们方法的可扩展性(§5)。
我们的贡献总结如下(图 1):
- 一种可扩展的方法来发现稀疏特征电路。我们通过在一系列主谓一致任务上发现和评估特征电路来验证我们的方法。
- SHIFT,一种去除语言模型分类器对非预期信号敏感性的技术,即使没有数据来隔离这些信号。
- 一个完全无监督的流程,用于计算数千种自动发现的语言模型行为的特征电路,相关信息可在 feature-circuits.xyz 查看。
本文将与源代码、数据以及训练好的自编码器一同发布。
2 形式化
使用稀疏自编码器进行特征解耦。在神经网络的可解释性研究中,一个基本挑战是单个神经元通常难以解释。因此,许多可解释性研究者最近转向了稀疏自编码器(SAEs),这是一种无监督技术,用于识别大量可解释的神经网络潜变量。对于具有潜空间 和激活 的模型组件,SAE 计算一个分解
为将输入近似重建 分解为特征 的稀疏和,以及一个 SAE 误差项 。这里, 表示 SAE 的宽度,特征 是单位向量,特征激活 是一组稀疏系数, 是偏置。SAE 在一个目标上进行训练,该目标促进小的重建误差 ,同时仅使用一组稀疏的特征激活 。我们的方法并不会为了电路发现舍弃误差项 ,而是通过将其优雅地纳入我们的稀疏特征电路来处理这些误差项;这为模型行为提供了一种有原则的分解,将可解释特征的贡献与尚未被 SAE 捕获的误差成分区分开来。
在本研究中,我们利用以下一系列 SAEs:
- 我们为 Pythia-70M 的每个子层(注意力层、MLP、残差流和嵌入)训练的一套 SAEs。我们紧密遵循 Bricken 等人 (2023) 的工作,使用 ReLU 线性编码器 和稀疏维度 ,并训练 SAEs 以最小化 L2 重构损失和 L1 正则化项的组合,后者促进稀疏性。有关我们的 Pythia SAEs 及其训练的详细信息,请参见附录 B.1。
- 开源的 Gemma Scope SAEs,适用于开源权重 Gemma-2-2B 模型的所有子层(不包括嵌入)。这些 SAEs 使用 Jump-ReLU 线性编码器和 。有关 Gemma Scope SAEs 的详细信息,请参见附录 B.2。
可扩展地训练更好的 SAEs 是一个活跃的研究领域,开源 SAEs 的现成可用性便是明证。因此,我们的目标是——在给定一套已训练的 SAEs 的情况下——可扩展地将它们应用于理解神经网络的行为;我们将扩展 SAEs 本身视为超出范围。
利用线性近似归因因果效应。设 为通过计算图(例如,神经网络)计算出的实值度量;设 表示图中的一个节点。遵循先前的研究,我们通过其对 的间接效应(IE)量化 在一对输入 上的重要性:
在这里, 是计算 时 所取的值,而 表示在计算 时,通过手动将 设置为 对 的计算进行干预时, 的值。例如,给定输入 "The teacher"和 "The teachers",我们有度量 ,这是从语言模型输出的对数概率差。如果a是某个特定神经元的激活值,那么 的较大值表明该神经元对模型在这一对输入中决定输出“is”与“are”具有很高的影响力。
我们通常希望计算大量模型组件 的 IE,但使用 (2) 进行高效计算是不可行的。因此,我们采用线性近似 (2),以便能够并行计算多个 。最简单的这种近似是归因补丁(attribution patching),采用了一阶泰勒展开
其仅用两次前向和一次反向传播就并行地每个 的 (2)。
为改进估计的质量,我们可以使用一个基于综合梯度的更昂贵但更精确的近似:
其中 (4) 中的求和在 等分的 。当一个节点位于另一个节点下游时,无法并行处理这两个节点,但对于那些相互独立的任意多个节点,可以进行并行处理。因此,计算 相对于 的额外成本在 和 的计算图的串行深度上呈线性增长。
上述讨论适用于我们有一对干净输入和补丁输入的情况,我们希望理解将某个特定节点从其干净值更改为补丁值的影响。但是在某些情况下(例如,§4, 5),我们只有一个输入 。在这种情况下,我们使用零消融,利用间接效应 ,通过将 设置为 。我们从 (3) 和 (4) 得到 的修改公式,方法是将 替换为 。
3 稀疏特征电路发现
3.1 方法
假设我们给定一个语言模型 、不同 的子模块(例如注意力输出、MLP 输出和残差流向量等,如 §2 中)的 SAEs、一个由对比样本对 或单个输入 组成的数据集 ,以及一个依赖于 处理来自 的数据时输出的度量 。例如,图 2 显示了 由数量不同的输入对组成,而 则是 输出的动词形式在补丁输入和干净输入上正确性的对数概率差的情况。
将 SAE 特征视为模型的一部分。我们方法的一个关键思想是,通过对语言模型中的各种隐藏状态 应用分解 (1),我们可以将特征激活 和 SAE 误差 视为语言模型计算的一部分。因此,我们可以将模型表示为一个计算图 ,其中节点对应于特定 token 位置的特征激活或 SAE 误差。
近似每个节点的 IE。设 为 或 (见 §2)。对于 中的每个节点 和输入 ,我们计算 ;然后我们在 上取平均,生成一个分数 ,并过滤出 的节点,其中 为某个节点阈值的选择。
与之前的工作一致,我们发现 准确地估计了 SAE 特征和 SAE 误差的 IEs,除了第 0 层 MLP 和早期残差流层的节点,此时 低估了真实 IE。我们发现 显著改善了这些组件相比 的准确性,因此我们在下面的实验中使用它。有关线性近似质量的更多信息,请参见附录 H。
近似边的 IE。使用类似的线性近似,我们还计算计算图中边的平均 IE。尽管这个想法很简单,但数学上涉及的内容相对复杂,因此我们将细节留给附录 A.1。在计算完这些 IEs 后,我们过滤出绝对 IE 超过某个边阈值 的边。
跨 token 位置和示例的聚合。对于模板数据,其中相同位置的 tokens 起着一致的作用(见 §3.2, 3.3),我们在示例之间计算节点/边的平均效应。对于非模板数据(§4, 5),我们首先跨 token 位置对对应的节点/边的效果进行求和,然后计算示例间的平均值。有关更多信息,请参见附录 A.2。
实际考虑。为了高效计算我们方法所需的梯度,出现了各种实际困难。我们通过结合停止梯度、传递梯度和高效雅可比向量乘积计算技巧来解决这些问题;请参见附录 A.3。
3.2 发现并评估主谓一致的稀疏特征电路
为了评估我们的方法,我们在 Pythia-70M 和 Gemma-2-2B 上发现用于四种主谓一致任务的变体(见表 1)的稀疏特征电路(以下简称“特征电路”)。具体来说,我们根据 Finlayson 等人 (2021) 的研究调整数据,生成的数据集由仅在主语的语法数上有所不同的对比输入对组成;模型的任务是选择合适的动词变形。我们对电路进行可解释性、忠实性和完整性评估。对于每一标准,我们与通过在稀疏特征上应用我们的方法发现的神经元电路进行比较;在这种设置中,没有误差项 ε。在评估特征电路的忠实性和完整性时,我们使用数据集的测试集,该集由未用于发现电路的对比对组成。
可解释性。对于 Pythia SAEs,我们请人力众包工作者评估随机特征、随机神经元、来自我们特征电路的特征以及来自我们神经元电路的神经元的可解释性。众包工作者认为稀疏特征的可解释性显著高于神经元,并且参与我们电路的特征可解释性也高于随机抽样的特征(附录 F)。这一结果验证了先前的研究发现,即 SAE 特征的可解释性显著高于神经元(Bricken 等,2023)。对于 Gemma-2 SAEs,读者可以参阅 Lieberum 等(2024),该研究发现这些 SAEs 的特征可解释性与通过其他最先进技术训练的特征相当。
忠实性。给定一个电路 C 和一个指标 m,令 m(C) 表示在运行模型时,所有不在 C 中的节点均为均值消融状态下,从 D 中输入的平均值。我们通过 m(C)−m(∅)m(M )−m(∅) 来衡量电路的忠实性,其中 ∅ 表示空电路,M 表示完整模型。直观上,这一指标捕捉了我们的电路解释模型性能的比例,相对于均值消融完整模型(表示模型在得到任务信息但没有具体输入时的“先验”性能)。我们发现早期模型层中的组件通常参与处理特定的标记。在实践中,数据集中训练集(用于发现电路)和测试集(用于评估)中的输入不包含相同的标记,这使得评估我们电路的早期部分质量变得困难。因此,我们忽略电路的前 1/3,只评估后 2/3。我们绘制特征电路和神经元电路的忠实性图,参数范围遍及节点阈值 TN(见图 3)。我们发现小型特征电路解释了模型行为的很大一部分:Pythia-70M 和 Gemma-2-2B 的大多数性能分别仅由 100 和 500 个节点解释。相比之下,解释一半性能大约需要 1500 和 50000 个神经元。然而,由于 SAE 误差节点是高维且粗粒度的,因此无法与神经元进行公平比较;因此我们也绘制了移除所有 SAE 误差节点后的特征电路的忠实性,或者移除所有注意力和 MLP 误差节点后的情况。不出所料,我们发现移除残差流 SAE 误差节点会严重干扰模型并限制其最大性能;移除 MLP 和注意力错误节点的影响较小。
完整性。我们的电路是否无法捕捉模型行为的一部分?我们通过电路补集 M \ C 的忠实性来衡量这一点(见图 3)。我们观察到,仅通过消融我们特征电路中的几个节点就可以消除模型的任务性能,即使在我们保留所有 SAE 误差的情况下亦然。相对而言,做到这一点需要数百(对于 Pythia)或数千(对于 Gemma)个神经元。