稀疏自动编码器 (SAE) 是一种越来越常见的解释机器学习模型的工具(尽管 SAE 自 1997 年就已经存在)。
机器学习模型和 LLM 正变得越来越强大和有用,但它们仍然是黑匣子,我们不明白它们是如何完成任务的。了解它们的工作原理应该会有所帮助。
SAE 帮助我们将模型的计算分解为可理解的组件。最近,LLM 可解释性研究员 Adam 发表了一篇博客文章,直观地解释了 SAE 的工作原理。
可解释性的挑战
神经网络最自然的组成部分是单个神经元。遗憾的是,单个神经元并不容易与单个概念相对应,例如学术引文、英语对话、HTTP 请求和韩语文本。在神经网络中,概念由神经元的组合表示,这称为叠加。
发生这种情况是因为世界上的许多变量天生就很稀疏。
例如,一位名人的出生地可能出现在不到十亿分之一的训练标记中,但现代法学硕士仍然可以学习这一事实以及有关世界的许多其他知识。训练数据中单个事实和概念的数量大于模型中神经元的数量,这可能是叠加发生的原因。
稀疏自动编码器 (SAE) 是一种近年来越来越常用的技术,用于将神经网络分解为可理解的组件。SAE 的设计灵感来自神经科学中的稀疏编码假设。如今,SAE 已成为理解人工神经网络最有前途的工具之一。SAE 类似于标准自动编码器。
传统的自动编码器是一种压缩和重建输入数据的神经网络。
例如,如果输入是一个 100 维向量(100 个值的列表);自动编码器首先将输入通过编码器层,将其压缩为 50 维向量,然后将此压缩的编码表示输入到解码器中以获得 100 维输出向量。重建过程通常并不完美,因为压缩过程使重建任务非常困难。

标准自动编码器的示意图,具有 1x4 输入向量、1x2 中间状态向量和 1x4 输出向量。单元格的颜色表示激活值。输出是输入的不完美重建。
解释稀疏自动编码器
稀疏自动编码器的工作原理
稀疏自动编码器将输入向量转换为中间向量,该中间向量的维度可以高于、等于或低于输入。当用于 LLM 时,中间向量的维度通常高于输入。在这种情况下,如果没有额外的约束,任务就很容易,SAE 可以使用单位矩阵完美地重建输入,而不会出现任何意外。但是,我们将添加约束,其中之一是向训练损失添加稀疏性惩罚,这会迫使 SAE 创建稀疏中间向量。
例如,我们可以将 100 维输入扩展为 200 维编码表示向量,并且可以训练 SAE 使编码表示中只有大约 20 个非零元素。

稀疏自动编码器的图表。请注意,中间激活是稀疏的,只有 2 个非零值。
我们使用 SAE 作为神经网络中的中间激活,该网络可能包含多层。在前向传递过程中,每层中以及每层之间都有中间激活。
例如,GPT-3 有 96 层。在正向传递过程中,输入中的每个 token 都有一个 12,288 维向量(一个包含 12,288 个数字的列表)。这个向量积累了模型在处理每一层时用于预测下一个 token 的所有信息,但它并不透明,很难理解它包含的信息。
我们可以使用 SAE 来理解这个中间激活。SAE 基本上是“矩阵 → ReLU 激活 → 矩阵”。
例如,如果 GPT-3 SAE 的扩展因子为 4,并且其输入激活为 12,288 维,则其 SAE 编码表示为 49,512 维(12,288 x 4)。第一个矩阵是形状为 (12,288, 49,512) 的编码器矩阵,第二个矩阵是形状为 (49,512, 12,288) 的解码器矩阵。通过将 GPT 激活与编码器相乘并使用 ReLU,我们可以得到 49,512 维的 SAE 编码稀疏表示,因为 SAE 损失函数鼓励稀疏性。
一般来说,我们的目标是在 SAE 表示中拥有少于 100 个非零值。通过将 SAE 表示与解码器相乘,我们得到了 12,288 维的重建模型激活。这种重建并不完美地匹配原始 GPT 激活,因为稀疏性约束使得很难实现完美匹配。
通常,SAE 仅用于模型中的一个位置。例如,我们可以在第 26 层和第 27 层之间的中间激活上训练 SAE。为了分析 GPT-3 所有 96 层输出中包含的信息,我们可以训练 96 个单独的 SAE - 每个层输出一个。如果我们还想分析每层内的各种中间激活,我们将需要数百个 SAE。为了获取这些 SAE 的训练数据,我们将向 GPT 模型输入大量不同的文本并收集每个选定位置的中间激活。
下面提供了 SAE 的参考实现。变量用形状注释,这是 Noam 的想法,参见:@/shape--good--style- 。请注意,不同的 SAE 实现通常具有不同的偏差、规范化方案或初始化方案,以最大限度地提高性能。最常见的添加是对解码器向量范数的某种约束。有关更多详细信息,请访问以下实现:
火炬
torch.nn 作为 nn
#D =,F =
# 例如,如果 = 12288 且 = 49152
# 然后 .shape = (12288,) 和 ..shape = (12288, 49152)
类别(nn.):
单层。
def(自身,:int,:int):
极好的 ()。()
自我。=
自我。=
自我。= nn。(,,偏见=真)
自我。= nn。(,,偏见=真)
def (self, : torch.) -> torch.:
nn.ReLU())(self.())
def (self, on_F: torch.) -> torch.:
自我。(on_F)
def (self, : torch.) -> 元组 [torch., torch.]:
on_F = 自身.()
= 自身。(on_F)
, on_F
标准自动编码器的损失函数基于输入重建的准确性。引入稀疏性的最直接方法是向 SAE 的损失函数添加稀疏性惩罚。计算此惩罚的最常见方法是取 SAE 编码表示(而非 SAE 权重)的 L1 损失并将其乘以 L1 系数。此 L1 系数是 SAE 训练中的关键超参数,因为它决定了实现稀疏性和保持重建准确性之间的权衡。
请注意,我们这里不是针对可解释性进行优化。相反,可解释的 SAE 特征是优化稀疏性和重构的副作用。下面是一个参考损失函数。
# B = 批量大小, D = , F =
def (:,:torch。,:float) -> torch。:
, 在_BF = . ()
_BD = (-).pow (2)
_B = . (_BD, 'BD -> B', '总和')
= _B.平均值()
= * on_BF.sum()
损失 = +
损失

稀疏自动编码器前向传递的示意图。
这是通过稀疏自动编码器的单次前向传递。首先,我们有一个 1x4 模型向量。然后我们将其乘以 4x8 编码器矩阵以获得 1x8 编码向量,然后应用 ReLU 将负值变为零。这个编码向量是稀疏的。之后,我们将其乘以 8x4 解码器矩阵以获得 1x4 不完美重建的模型激活。
假设的 SAE 功能演示
理想情况下,SAE 表示中的每个有效值都对应一个可理解的组件。
我们用一个例子来说明这一点。假设一个 12,288 维向量 [1.5, 0.2, -1.2, ...] 在 GPT-3 看来代表“ ”(金毛猎犬)。SAE 是一个形状为 (49,512, 12,288) 的矩阵,但我们也可以将其视为 49,512 个向量的集合,每个向量的形状为 (1, 12,288)。如果 SAE 解码器的 317 个向量学习到了与 GPT-3 相同的“ ”概念,那么解码器向量也大致等于 [1.5, 0.2, -1.2, ...]。
每当 SAE 激活函数的元素 317 非零时,就会将与“ ”对应的向量(取决于 317 元素的大小)添加到重构激活函数中。从机械可解释性角度来看,这可以简洁地描述为“解码器向量对应于残差流空间中特征的线性表示”。
也就是说,具有 49,512 维编码表示的 SAE 有 49,512 个特征。这些特征由相应的编码器和解码器向量组成。编码器向量用于检测模型内部概念,同时尽量减少其他概念的干扰,而解码器向量用于表示“真实”的特征方向。研究人员的实验发现,每个特征的编码器和解码器特征都不同,中位数余弦相似度为 0.5。下图中,三个红色框对应单个特征。

稀疏自动编码器的示意图,其中三个红色框对应于 SAE 特征 1,绿色框对应于特征 4。每个特征都有一个 1x4 编码器向量、1x1 特征激活和 1x4 解码器向量。重建的激活仅使用来自 SAE 特征 1 和 4 的解码器向量构建。如果红色框表示“红色”,绿色框表示“球”,则模型很可能表示“红球”。
那么我们如何知道假设的特征 317 代表什么呢?目前的做法是找到能够最大化特征激活的输入,并对其可解释性做出直观的响应。能够激活每个特征的输入通常是可解释的。
例如,使用 训练的 SAE 发现,与金门大桥、神经科学和热门旅游景点相关的文本和图像激活了不同的 SAE 特征。其他特征由不太明显的概念激活,例如,使用 训练的 SAE 的特征由“用于修饰句子主语的相关从句或介词短语的最后一个标记”的概念激活。
由于 SAE 解码器向量具有与 LLM 中间激活相同的形状,因此只需将解码器向量添加到模型激活中即可进行因果干预。可以通过将解码器向量乘以扩展因子来调整此干预的强度。当研究人员将“金门大桥”SAE 解码器向量添加到 的激活中时,它被迫在每个响应中都提到“金门大桥”。
下面是使用假设特征 317 的因果干预的参考实现。与“金门大桥”类似,这种非常简单的干预迫使 GPT-3 模型在每次回应中都提到“金毛猎犬”。
def (: torch., : torch., scale: float) -> torch.:
D = [317, :]
= D * 比例
= +
稀疏自动编码器的评估挑战
使用 SAE 的主要挑战之一是评估。我们可以训练稀疏自动编码器来解释语言模型,但我们没有可测量的自然语言底层真实表示。目前,评估是主观的,基本上是“我们查看一系列特征的激活输入,然后对这些特征的可解释性做出直觉判断。”这是可解释性领域的一个主要限制。
研究人员发现了一些常见的代理指标,它们似乎与特征可解释性相对应。最常用的是 L0 和 Loss。L0 是 SAE 编码中间表示中非零元素的平均数量。Loss 将 GPT 的原始激活替换为重建的激活,并测量不完美重建结果的额外损失。这两个指标通常是相互权衡的,因为 SAE 可能会选择降低重建精度的解决方案来提高稀疏性。
在比较 SAE 时,一种常见的方法是绘制两个变量的图,然后检查它们之间的权衡。为了实现更好的权衡,许多新的 SAE 方法(例如 Gated SAE 和 TopK SAE)修改了稀疏性惩罚。下图来自 Gated SAE 论文。Gated SAE 用红线表示,位于图的左上角,表明它在这个权衡中表现更好。

门控 SAE L0 和损失
SAE 指标有多种难度等级。L0 和 Loss 是两个代理指标。但是,我们在训练期间不使用它们,因为 L0 不可微,而且在 SAE 训练期间计算 Loss 的计算成本很高。相反,我们的训练损失由 L1 惩罚项和重建内部激活的准确性决定,而不是它们对下游损失的影响。
训练损失函数并不直接对应于代理指标,代理指标只是特征可解释性主观评估的一个代理。由于我们真正的目标是“理解模型是如何工作的”,而主观可解释性评估只是一个代理,因此会存在另一层不匹配。LLM 中的一些重要概念可能不容易解释,我们在盲目优化可解释性时可能会忽略这些概念。
总结
在可解释性领域还有很长的路要走,但 SAE 确实向前迈出了一步。SAE 支持有趣的新应用,例如用于查找引导向量(如金门大桥的引导向量)的无监督方法。SAE 还可以帮助我们更轻松地找到语言模型中的循环,这可用于消除模型中不必要的偏差。
SAE 能够找到可解释的特征(即使目标只是识别激活模式),这一事实表明它们可以揭示一些有意义的东西。还有证据表明 LLM 实际上正在学习一些有意义的东西,而不仅仅是记忆表面的统计规律。
SAE 还代表了 等公司一直追求的“机器学习模型的 MRI”的早期里程碑。SAE 尚未提供完美的理解,但它可用于检测不良行为。SAE 和 SAE 评估的主要挑战并非不可克服,许多研究人员正在研究这一主题。
有关稀疏自动编码器的进一步介绍,请参阅 Colab 笔记本:


