405B大模型也能线性化!斯坦福MIT最新研究,0.2%训练量让线性注意力提分20+

  新智元报道

  编辑:alan

  近日,来自斯坦福、MIT 等机构的研究人员推出了低秩线性转换方法,让传统注意力无缝转移到线性注意力,仅需 0.2% 的参数更新即可恢复精度,405B 大模型两天搞定!

  生产级大模型应用线性注意力的方法,来了。

  线性 Attention(包括 RNN 系列),再也不用困在几B参数的范围内娱乐了。

  一套方法,即可线性化现有各种量级的 Transformer 模型,上至 Llama 3.1 405B,也只需要十来张显卡在两天内搞定!

  这就是斯坦福、MIT 等科研机构推出的低秩线性转换 LoLCATs(Low-rank Linear Conversion with Attention Transfer)。

  论文与代码:https://github.com/HazyResearch/lolcats

  应用 LoLCATs,可以实现传统注意力(softmax)到线性注意力的无缝转移,

  且转换后仅需开销很低的微调(LoRA),0.2% 的参数更新即可恢复精度,对比同类的线性注意力模型或方法, 5-shot MMLU 直接提高了 20 分左右!

  也就是说,在几乎不损失 Transformer 大模型语言能力的基础上,将 LLM 的计算复杂度从二次方降到了线性。

  线性 Attention 一事,前人之述备矣,然则,能够真正做大做强,还是第一次。

  尤其具有实用价值的是,LoLCATs 实现了极小的开销和接近原始模型的性能。

  LoLCATs 的线性化转换只需两个步骤:

  首先使用线性 Attention 的形式替换原始 Attention 部分,并利用简单的 MSE 损失训练新增的参数,以近似 softmax 注意力;

  然后通过低成本的微调(LoRA)来进一步提高模型的精度。

  为了实现可扩展性,作者采用更精细的「block by block」训练,将 LLM 的每k层看成一个 block,尽在块内联合训练注意力,以提高分层注意力匹配。

  就如上图所表示的那样,一个羊驼(Llama)可以看成多个小刺猬叠在一起,每个小刺猬拥有独特的用于线性化的参数,并且相互之间可以独立训练。

  LoLCATS 加速 LLM

  为了避免昂贵的训练成本,研究者们一直在不断探索两个方面:

  make models fast 与 create fast models

  诸如 Mamba、RWKV、TransNormer、Hawk、 Griffin 和 StripedHyena 等高效的 subquadratic models 不断出现,

  而关于将流行的 LLM 线性化的工作也让我们眼前一亮。

  但是线性化 LLM 往往伴随着模型质量的显著降低,你甚至能通过 MMLU 的测试分数猜出一个模型是不是传统的 Attention 架构,或者传统 Attention 块在模型中的占比。

  另外,从实用的角度讲,只有拿下了生产级别的大模型,线性化的道路才能真正与传统 Transformer 平分秋色。

  预备知识

  先打基础:为什么要线性化?

  正常的 softmax 注意力可以表示为下图上面的公式:

  由于 softmax 的缘故,只能先算Q乘K,导致中间缓存和计算量随序列长度的平方增长;

  线性化就是设计俩函数来近似 softmax,从而把公式转化成下面的形式。

  此时Q和K不需要绑在一起了,就可以先算K乘V,这个顺序的改变导致中间缓存和计算量随向量长度的平方增长,而相对于序列长度是线性关系。

  这就是线性化的意思,这样的 Attention 也就不惧怕长序列带来的压力了。

  开始线性化

  本文中,作者的主要想法是向线性化 Transformer 中添加三个简单的概念:

  1. Learnable (Linear) Attentions:可学习的(线性)注意力

  2. Low-rank Adaptation:低秩适配 3. Layer-wise Optimization:分层优化

  Learnable Attentions

  首先训练线性注意力来模拟和替换 softmax 注意力。这种「注意力转移」的灵感来自作者之前的一篇工作:Hedgehog。

  论文地址:https://arxiv.org/pdf/2402.04347

  如何设计设计精妙复杂的函数来近似 softmax 注意力?

  作者表示:与其让人类煞费苦心,不如交给 AI 自己去学!

  相比于 Hedgehog 中只使用可学习的线性注意力,作者在 LoLCATs 中,将其推广为可学习的线性注意力和 + 滑动窗口。

  研究人员将线性和 softmax 注意力统一在一个层中,训练一些新增的参数以从整体上近似 softmax 注意力。

  对于N个 token 的序列,前W个 token 用于计算 softmax 注意力,后N-W 个 token 用于计算线性注意力,然后将这些值组合。

  在 Hedgehog 中,作者通过 KL 散度来训练特征图以匹配注意力权重,而本文改为在注意力层的输出上使用 MSE 失。

  这绕过了 Hedgehog 的一个限制:需要将所有注意力权重实例化为监督目标。

  相反,LoLCATs 可以使用 FlashAttention 来计算 softmax 注意力输出,并将线性化注意力的内存消耗保持在O(N)。

  只需将这些特征图插入到每个现有的注意力中,即可创建线性化的 LLM。冻结所有其他权重,只训练这些特征图,对于 7B 的 LLM 来说,只需要调整 0.2% 的参数。

  Low-rank Adaptation

  之前的线性化工作,通常需要一个比较昂贵的端到端训练阶段。

  但在 LoLCATs 这里,可以通过简单地将低秩适应(LoRA)应用于注意力的 QKVO 权重来恢复模型的性能。

  冻结所有其他内容,只训练 LoRA 权重,在某些自然语言数据上,最大限度地减少 LLM 输出的 next-token 预测损失。

  Layer-wise Optimization

  大多数情况下,只需要以上两步就搞定了。但对于像 Llama 3.1 405B 这种规模的模型来说,还需要努力一下。

  通过简单地联合优化所有层,可以成功地线性化 7B 到 70B 参数范围的 LLM,但整体训练时,后面层的 MSE 会比前面的层更大。

  当模型变得更大更深时,MSE 升级为了微调 Llama 3.1 405B 的真正问题。

  为此,研究人员使用了更精细的逐块训练,将 Llama 3.1 405B 分成多个k层块,并仅在每个块内联合训练注意力。

  当使用一些线性化数据并行训练所有模块时,只需为每个块预先计算 LLM 的隐藏状态。

  可以调节k来平衡并行训练的速度与预计算的内存,并将隐藏状态保存到磁盘。不需要花哨的成本模型,对于 50M token 的线性化来说:

  k = 1 时,需要 2 字节 × 126 层 × 50M token × 16384(hidden size)= 200TB 的磁盘空间来存储隐藏状态。

  而 k = 9 时,磁盘空间的需求将减少为 22TB,这时仍然能在单个 GPU 上并行训练每个块(9 层)。

  ——后者显然更友好一点,所以作者将 Llama 3.1 405B 的 126 层拆分为 14 个 9 层块,在 14 个 GPU 上并行进行注意力的线性化,过程仅需 5 个小时。然后用 LoRA 将它们全部拼接在一起,就得到了最终模型。

  实验结果

  质量恢复

  下表给出了 6 个流行的 LLM 评估任务的结果。

  与最近的一些线性化方法相比,LoLCATs 显著提高了不同任务和不同 LLM 的质量和训练效率。

  尽管只训练了 0.2% 的模型参数(40M token),LoLCATs 将线性化与原始模型的性能差距平均缩小了 80% 以上,token to model 的效率提高了 500~2500 倍。

  在 7B 这个量级上,LoLCATs 优于所有的线性注意力(包括 RNN 系列)模型:Mamba、RWKV、TransNormer、Hawk、 Griffin 和 StripedHyena。

  挑战 405B 大模型

  最后,作者使用 LoLCATs 将线性化扩展到 Llama 3.1 70B 和更大的 405B 模型。

  与之前的线性化方法相比,首先是质量上的显著改进。通过控制相同的线性 + 滑动窗口层,对于 Llama 3.1 70B,在5-shot MMLU 上的精度实现了 39 点的提升,对于 Llama 3.1 405B,同样实现了 38.3 分的改进。

  其次是训练效率的提高,在单个 8x80GB H100 上线性化 Llama 3.1 70B 仅需 18 个小时,而线性化 Llama 3.1 405B 所花费的时间比之前用于 8B 模型的方法还要少。

  参考资料:

  https://x.com/simran_s_arora/status/1845909074774475125