ICML2024高分!魔改注意力,让小模型能打两倍大的模型

  彩云科技团队投稿

  量子位公众号 QbitAI

  改进 Transformer 核心机制注意力,让小模型能打两倍大的模型!

  ICML 2024 高分论文,彩云科技团队构建 DCFormer 框架,替换 Transformer 核心组件多头注意力模块(MHA),提出可动态组合的多头注意力(DCMHA)。

  DCMHA 解除了 MHA 注意力头的查找选择回路和变换回路的固定绑定,让它们可以根据输入动态组合,从根本上提升了模型的表达能力。

  可以近似理解为,原来每层有固定的H个注意力头,现在用几乎同样的参数量和算力,可按需动态组合出多至 HxH 个注意力头。

  DCMHA 即插即用,可在任何 Transformer 架构中替换 MHA,得到通用、高效和可扩展的新架构 DCFormer。

  这项工作由来自北京邮电大学、AI 创业公司彩云科技的研究人员共同完成。

  研究人员用在 DCFormer 基础上打造的模型 DCPythia-6.9B,在预训练困惑度和下游任务评估上都优于开源 Pythia-12B。

  DCFormer 模型在性能上与那些计算量是其 1.7-2 倍的 Transformer 模型相当。

  多头注意力模块有何局限?

  大模型的 scaling law 告诉我们,随着算力的提升,模型更大、数据更多,模型效果会越来越好。虽然还没有人能明确说明这条路的天花板有多高,能否达到 AGI,但这确实是目前大家最普遍的做法。

  但除此以外,另一个问题同样值得思考:目前绝大多数大模型都基于 Transformer,它们都是用一个一个 Transformer 块像搭积木一样搭起来的,那作为积木块的 Transformer 本身,还有多大的改进提升空间?

  这是模型结构研究要回答的基本问题,也正是彩云科技和北京邮电大学联合完成的 DCFormer 这项工作的出发点。

  在 Transformer 的多头注意力模块(MHA)中,各个注意力头彼此完全独立的工作。

  这个设计因其简单易实现的优点已在实践中大获成功,但同时也带来注意力分数矩阵的低秩化削弱了表达能力、注意力头功能的重复冗余浪费了参数和计算资源等一些弊端。基于此,近年来有一些研究工作试图引入某种形式的注意力头间的交互。

  根据 Transformer 回路理论,在 MHA 中 ,每个注意力头的行为由 WQ、WK、WV、WO 四个权重矩阵刻画(其中 WO 由 MHA 的输出投影矩阵切分得到)。

  其中,WQWK 叫做 QK 回路(或叫查找选择回路),决定从当前 token 关注上下文中的哪个(些)token,例如:

  WOWV 叫做 OV 回路(或叫投影变换回路),决定从关注到的 token 取回什么信息(或投影什么属性)写入当前位置的残差流,进而预测下一个 token。例如:

  研究人员注意到,查找(从哪拿)和变换(拿什么)本来是独立的两件事,理应可以分别指定并按需自由组合(就像在 SQL 查询中,WHERE 后的选择条件和 SELECT 后的属性投影是分开写的一样),MHA 硬把它们放到一个注意力头的 QKOV 里“捆绑销售”,限制了灵活性和表达能力。

  例如,假设有个模型存在注意力头A、B、C其 QK 和 OV 回路能够完成上面的例子=,那换成:

  需要交叉组合现有注意力头的 QK 和 OV 回路,模型就可能“转不过弯儿”了(经研究人员系统构造的合成测试集验证,<=6B 的中小尺寸模型在这类看似简单的任务上确实表现不佳)。

  动态组合多头注意力长啥样?

  以此为出发点,本文研究团队在 MHA 中引入compose 操作

  如下图所示,得到 DCMHA:

  △图 1. DCMHA 总体结构

  将 QWQ 和 KWK 算出的注意力分数矩阵 AS 和注意力权重矩阵 AW,与 VWV 相乘之前,对其在 num_heads 维上做线性映射得到新的矩阵A’,通过不同的线性映射矩阵(composition map),以实现各种注意力头组合的效果。

  例如图2(c)中将 head 3 和 7 的 QK 回路与 head 1 的 OV 回路组合在一起,形成一个“新的”注意力头。

  △图 2. 8 个注意力头的简化的典型 composition map 的功能,浅色表示大值

  为了最大限度的增强表达能力,研究人员希望映射矩阵由输入动态生成,即动态决定注意力头怎样组合。

  但他们要生成的映射矩阵不是一个,而是对序列中每对源位置的 query Qi 和目的位置的 key Kj,都要生成这样一个矩阵,计算开销和显存占用都将难以接受。

  为此,他们进一步将映射矩阵分解为一个输入无关的静态矩阵 Wb、一个低秩矩阵 w1w2 和一个对角矩阵 Diag (wg)之和,分别负责基础组合、注意力头间的有限方式(即秩R<=2)的动态组合和头自身的动态门控(见图2(d)和图3(b))。其中后两个矩阵由Q矩阵和K矩阵动态生成。

  在不牺牲效果的前提下,将计算和参数复杂度降低到几乎可以忽略的程度(详见论文中复杂度分析)。再结合 JAX 和 PyTorch 实现层面的优化,让 DCFormer 可以高效训练和推理。

  △图 3. Compose 的计算

  效果如何?

  规模扩展

  评估一个架构的好坏,研究人员关注的最核心指标是算力转化为智能的效率(或叫性能算力比),即投入单位算力能带来的模型性能提升——花更少的算力,得到更好的模型。

  从图 4 和图 5 的 scaling law 曲线(在对数坐标下,每个模型架构的损失随算力的变化可画出一条近似直线,损失越低,模型越好)可以看出,DCFormer 可以达到 1.7~2 倍算力的 Transformer 模型的效果,即算力智能转化率提升了 1.7~2 倍。

  △图 4. Transformer 和 DCFormer 的规模扩展效果

  △图 5. Pythia 和 DCPythia 的规模扩展效果

  怎么理解这个提升幅度呢?

  自 2017 年 Transformer 诞生至今,从改进性能算力比的角度,GLU MLP 和旋转位置编码 RoPE 是经大量实践验证普适有效且被广泛采用的为数不多的两项架构改进。

  在原始 Transformer 中加入这两项改进的架构也叫 Transformer++,Llama、Mistral 等最强开源模型均采用该架构。无论 Transformer 还是 Transformer++ 架构,都可通过 DCMHA 获得显著改进。

  在 1.4B 模型规模下,DCMHA 的改进幅度大于 Transformer++ 的两项改进之和,且扩展性更好(图 4 下蓝绿线和黑线的对比,DCMHA 的改进幅度随算力增加衰减的更慢,以及图 4 和图 5 的对比)。

  可以说,DCFormer 让 Transformer 的能力又跃上一个新台阶。

  下游任务评测

  研究团队训练了 DCPythia-2.8B 和 DCPythia-6.9B 两个模型在主流 NLP 下游任务上进行测评并和同规模的开源模型 Pythia 进行比较(训练采用和 Pythia 完全相同超参数设置)。

  △表 1. DCFormer 和 Pythia 在下游任务中的表现

  从表 1 中可以看出,DCPythia-2.8B 和 6.9B 不仅在 Pile 验证集上的 ppl 更低,而且在大部分下游任务上都显著超过了 Pythia,DCPythia6.9B 在 ppl 和下游任务上的平均准确率甚至超过了 Pythia-12B

  DCFormer++2.8B 相对于 DCPythia-2.8B 有进一步的提升,验证了 DCMHA 和 Lllama 架构结合的有效性。

  训练和推理速度

  虽然引入 DCMHA 会带来额外的训练和推理开销,但是从表 2 中可以看出 DCFormer++ 的训练速度是 Transformer++ 的 74.5%-89.2%,推理速度则是 81.1%-89.7%,而且随着模型参数的增长,额外的计算开销会逐渐降低。

  △表 2. Transformer++ 和 DCFormer++ 的训练和推理速度对比

  训练速度是在 TPU v3 pod,序列长度为 2048,batch_size 为 1k 的情况下对比得到的;推理速度是在 A100 80G GPU 上进行评测的,输入长度 1024,生成长度 128。

  消融实验

  结果如下:

  △表 3. DCMHA 的消融实验

  从表 3 中可以看出以下几点:

  • 虽然加入静态的组合权重就可以降低 ppl,但引入动态的组合权重可以进一步降低 ppl,说明了动态组合的必要性。
  • 低秩动态组合比动态门控的效果更好。
  • 只用 query-wise 或者 key-wise 的动态组合得到的 ppl 相当,与 DCFormer++ 的差距很小。
  • 在 softmax 后做注意力头组合比在 softmax 前做更有效,可能是因为 softmax 后的概率能更直接影响输出。
  • 动态组合权重的秩无需设置过大,也说明了组合权重的低秩性。

  此外,研究人员还通过增加局部注意力层的比例和只用 query-wise 动态组合的方式去进一步减少训练和推理开销,详见论文 Table 10。

  总的来说,研究团队有两点总结。

  关于动态权重:近期 Mamba,GLA,RWKV6,HGRN 等 SSM 和线性注意力/RNN 的工作,通过引入动态(input-dependent)权重的方式,追赶上了 Transformer++,但 DCFormer 用动态组合注意力头的方式说明了在使用 softmax 注意力的情况下,通过引入动态权重也可以大幅提升 Transformer++ 的效果。

  关于模型架构创新:这项工作表明,如果存在一个具有极限算力智能转化效率的“理想模型架构”,当前的 Transformer 架构虽已非常强大,但距离这个理想架构很可能还存在很大的差距,仍有广阔的提升空间。因此,除了堆算力堆数据的大力出奇迹路线,模型架构创新同样大有可为。

  研究团队还表示,彩云科技会率先在旗下产品彩云天气、彩云小译、彩云小梦上应用 DCformer。

  有关更多研究细节,可参阅原始论文。

  ICML2024 论文链接:https://icml.cc/virtual/2024/poster/34047

  Arxiv 论文链接:https://arxiv.org/abs/2405.08553

  代码链接:https://github.com/Caiyun-AI/DCFormer