Transformer 的强大实力已经在诸多大型语言模型(LLM)上得到了证明,但该架构远非完美,也有很多研究者致力于改进这一架构,比如机器之心曾报道过的Reformer 和Infini-Transformer。
今天我们又将介绍另一种新型Transformer 架构:Differential Transformer(差分Transformer,简称Diff Transformer)。该架构来自微软研究院和清华大学,有四位共一作者:Tianzhu Ye、Li Dong、Yuqing Xia、Yutao Sun。

在Hacker News 及Twitter 等社交网络上,该论文都反响热烈,有网友表示差分Transformer 提出的改进简单又美丽,而带来的提升又非常显着。
甚至已有开发者做出了差分Transformer 的轻量实现!

差分Transformer 的轻量实现,
那么差分Transformer 弥补了原生Transformer 的哪些问题呢?如下图所示,Transformer 往往会过度关注不相关的上下文,该团队将此称为注意力噪声(attention noise)。而差分Transformer 则能放大对答案范围的注意力并消除噪音,从而增强上下文建模的能力。这就要用到该团队新提出的差分注意力机制(differential attention mechanism)了。

差分注意力机制可以消除注意力噪声,鼓励模型重点关注关键信息。该方法有些类似于电气工程中的降噪耳机和差分放大器。
下面我们就来详细了解一下差分Transformer 的设计思路。
差分Transformer
差分Transformer 是一种用于序列建模的基础模型架构。为了方便说明,他们使用了仅解码器(decoder-only)模型作为示例来描述该架构。
该模型堆叠了L 个Diff Transformer 层。给定一个输入序列x,将输入嵌入打包成X^0。输入会被进一步上下文化来获得输出X^L。每一层都由两个模块组成:一个差分注意力模块和之后的前向网络模块。
相比于Transformer,差分Transformer 的主要差别在于使用差分注意力替换了传统的softmax 注意力,同时保持整体宏观布局不变。此外,他们也参考LLaMA 采用了pre-RMSNorm 和SwiGLU 这两项改进措施。
差分注意力
差分注意力机制的作用是将查询、键和值向量映射成输出。这里使用查询和键向量来计算注意力分数,然后计算值向量的加权和。
此处的关键设计是使用一对softmax 函数来消除注意力分数的噪声。具体来说,给定输入X,首先将它们投射成查询、键和值Q_1、Q_2、K_1、K_2、V。然后差分注意力算子DiffAttn (・) 通过以下方式计算输出:
其中W^Q、W^K 、W^V 是参数,λ 是可学习的标量。为了同步学习动态,将标量λ 重新参数化为:
其中λ_q1、λ_k1、λ_q2、λ_k2 是可学习的向量,λ_init ∈ (0, 1) 是用于初始化λ 的常数。该团队通过经验发现,设置λ_init = 0.8 − 0.6 × exp (−0.3・(l − 1)) 在实践中效果很好,其中l ∈ [1, L] 表示层索引。它在实验中被用作默认策略。
他们也探索了另一种初始化策略:对所有层使用相同的λ_init(例如0.8)。如后面消融研究所示,使用不同的初始化策略时,性能相对稳健。
差分注意力利用两个softmax 注意力函数之间的差来消除注意力噪声。这个想法类似于电气工程中提出的差分放大器,其中两个信号之间的差用作输出,这样就可以消除输入的共模噪声。此外,降噪耳机的设计也基于类似的想法。
该团队也为差分注意力使用了多头机制。令h 表示注意力头的数量。他们对各个头使用不同的投影矩阵W^Q_i 、W^K_i 、W^V_i ,i ∈ [1, h]。标量λ 在同一层内的头之间共享。然后对头输出执行归一化,并投射成最终结果,如下所示:

其中λ_init 是(2) 式中的常数标量,W^O 是可学习的投影矩阵,LN (・) 是对每个头使用RMSNorm,Concat (・) 的作用是沿通道维度将头连接在一起。这里使用一个固定乘数(1 − λ_init)作为LN (・) 的缩放尺度,以使梯度与Transformer 对齐。
图2 使用了GroupNorm (・) 来强调LN (・) 独立应用于每个head。由于差分注意力往往具有更稀疏的模式,因此头之间的统计信息更加多样化。为了改进梯度的统计情况,LN (・) 算子会在连接操作之前对每个头进行归一化。

整体架构
其整体架构会堆叠L 层,其中每层包含一个多头差分注意力模块和一个前向网络模块。如此,便可将差分Transformer 层描述为:
其中LN (・) 是RMSNorm,SwiGLU (X) = (swish (XW^G) ⊙ XW_1) W_2,且W^G、W_1、W_2 是可学习的矩阵。
实验
该团队从以下角度评估了差分Transformer 在LLM 中的应用,包括对比评估、应用评估和消融研究。这里我们仅关注实验结果,更多实验过程请访问原论文。
语言建模评估
该团队评估了差分Transformer 的语言建模能力。为此,他们使用1T token 训练了一个3B 大小的差分Transformer 语言模型,并与之前的Transformer 语言模型做了比较。
结果见表1,其中报告的是在LM eval Harness 基准上的零样本结果。

可以看到,3B 规模下,差分Transformer 语言模型的表现优于之前的Transformer 语言模型。此外,实验也表明差分Transformer 在多种任务上都胜过Transformer,详见原论文附录。
与Transformer 的可扩展性比较
该团队也比较了新旧Transformer 的可扩展性。结果见图3,其中a 比较了模型规模方面的可扩展性,而b 则是训练token 数量方面的可扩展性。

可以看到,在这两个方面,差分Transformer 的可扩展性均优于常规Transformer:仅需后者65% 左右的模型大小或训练token 数量就能达到相媲美的性能。
长上下文评估
当3B 模型上下文长度增长至64K,模型的表现又如何呢?又使用另外1.5B token 训练了3B 版本的检查点模型之后,该团队发现随着上下文长度的增加,累积平均负对数似然(NLL)持续下降。差分Transformer 得到的NLL 值低于常规Transformer。见图4,这样的结果表明,差分Transformer 可以有效地利用不断增加的上下文。

关键信息检索
为了检验差分Transformer 检索关键信息的能力,该团队执行了Needle-In-A-Haystack(草堆找针)测试。
表2 给出了4K 上下文长度的情况,其中N 是针的数量,R 是查询引用的数量。可以看到,差分Transformer 的多针检索准确度高于常规Transformer,尤其是当针数量较多时,差分Transformer 的优势会更加明显。

那么当上下文长度提升至64K 时,又会如何呢?结果见图5,这里使用的上下文长度在8K 到64K 之间,使用了N = 8 和R = 1 的设置。

可以看到,在不同的上下文长度下,差分Transformer 能够保持相对稳定的性能。而当上下文长度越来越大时,常规Transformer 的性能会逐渐下降。
另外,表3 展示了分配给关键信息检索任务的答案范围和噪声上下文的注意力分数。该分数可代表模型保留有用信息、抵抗注意力噪声的能力。

可以看到,相比于常规Transformer,差分Transformer 能为答案范围分配更高的注意力分数,同时为注意力噪声分配更低的注意力分数。
上下文学习能力评估
该团队从两个角度评估模型的上下文学习能力,包括多样本分类和上下文学习的稳健性。
图6 展示了新旧Transformer 模型的多样本分类结果。结果表明,在不同的数据集和不同的演示样本数量上,差分Transformer 均稳定地优于Transformer。此外,差分Transformer 的平均准确度优势也很明显,从5.2% 到21.6% 不等。

图7 则展示了两种模型的上下文学习稳健性结果。该分析基于TREC 数据集,并且采用了两种提示词格式:示例随机排列(图7a)和按类别交替排列(图7b)。

在这两种设置下,差分Transformer 的性能方差要小得多。结果表明,新方法在上下文学习任务中更为稳健。相比之下,Transformer 容易受到顺序排列的影响,导致最佳结果与最差结果之间差距巨大。
上下文幻觉评估
该团队基于文本摘要和问答任务评估了模型的上下文幻觉现象。结果见表4。

可以看到,相比于常规Transformer,差分Transformer 在摘要和问答任务上的上下文幻觉更低。该团队表示,原因可能是差分Transformer 能更好地关注任务所需的基本信息,而不是无关上下文。
激活异常值分析
在LLM 中,一部分激活值明显大于大多数激活值的现象被称为激活异常值(activation outliers)。异常值导致训练和推理过程中模型量化困难。实验表明差分Transformer 可以降低激活异常值的幅度,从而可能实现更低的量化位宽。
表5 展示了两个训练得到Transformer 和差分Transformer 模型的激活值统计情况。这里分析了两种类型的激活,包括注意力logit(即pre-softmax 激活)和隐藏状态(即层输出)。可以看到,尽管中位数相似,但与Transformer 相比,差分Transformer 的较大激活值要低得多。这表明新方法产生的激活异常值较少。

图8 则展示了将注意力logit 量化到更低位的情况。这里使用的方案是:使用absmax 量化的动态后训练量化。其中,16 位配置表示未经量化的原始结果。模型逐步量化为8 位、6 位和4 位。这里报告的是在HellaSwag 上的零样本准确度,但该团队也指出在其它数据集上也有类似表现。

从图中可知,即使降低位宽,差分Transformer 也能保持较高性能。相较之下,常规Transformer 的准确度在6 位和4 位量化时会显着下降。这一结果表明,差分Transformer 本身就能缓解注意力分数中的激活异常值问题,从而可为低位FlashAttention 的实现提供新机会。
最后,该团队也进行了消融实验,证明了各个新设计的有效性。


