新智元报道
编辑:alan
近日,来自斯坦福、MIT 等机构的研究人员推出了低秩线性转换方法,让传统注意力无缝转移到线性注意力,仅需 0.2% 的参数更新即可恢复精度,405B 大模型两天搞定!
生产级大模型应用线性注意力的方法,来了。
线性 Attention(包括 RNN 系列),再也不用困在几B参数的范围内娱乐了。
一套方法,即可线性化现有各种量级的 Transformer 模型,上至 Llama 3.1 405B,也只需要十来张显卡在两天内搞定!
这就是斯坦福、MIT 等科研机构推出的低秩线性转换 LoLCATs(Low-rank Linear Conversion with Attention Transfer)。
论文与代码:https://github.com/HazyResearch/lolcats
应用 LoLCATs,可以实现传统注意力(softmax)到线性注意力的无缝转移,
且转换后仅需开销很低的微调(LoRA),0.2% 的参数更新即可恢复精度,对比同类的线性注意力模型或方法, 5-shot MMLU 直接提高了 20 分左右!
也就是说,在几乎不损失 Transformer 大模型语言能力的基础上,将 LLM 的计算复杂度从二次方降到了线性。
线性 Attention 一事,前人之述备矣,然则,能够真正做大做强,还是第一次。
尤其具有实用价值的是,LoLCATs 实现了极小的开销和接近原始模型的性能。
LoLCATs 的线性化转换只需两个步骤:
首先使用线性 Attention 的形式替换原始 Attention 部分,并利用简单的 MSE 损失训练新增的参数,以近似 softmax 注意力;
然后通过低成本的微调(LoRA)来进一步提高模型的精度。
为了实现可扩展性,作者采用更精细的「block by block」训练,将 LLM 的每k层看成一个 block,尽在块内联合训练注意力,以提高分层注意力匹配。
就如上图所表示的那样,一个羊驼(Llama)可以看成多个小刺猬叠在一起,每个小刺猬拥有独特的用于线性化的参数,并且相互之间可以独立训练。
LoLCATS 加速 LLM
为了避免昂贵的训练成本,研究者们一直在不断探索两个方面:
make models fast 与 create fast models
诸如 Mamba、RWKV、TransNormer、Hawk、 Griffin 和 StripedHyena 等高效的 subquadratic models 不断出现,
而关于将流行的 LLM 线性化的工作也让我们眼前一亮。
但是线性化 LLM 往往伴随着模型质量的显著降低,你甚至能通过 MMLU 的测试分数猜出一个模型是不是传统的 Attention 架构,或者传统 Attention 块在模型中的占比。
另外,从实用的角度讲,只有拿下了生产级别的大模型,线性化的道路才能真正与传统 Transformer 平分秋色。
预备知识
先打基础:为什么要线性化?
正常的 softmax 注意力可以表示为下图上面的公式:
由于 softmax 的缘故,只能先算Q乘K,导致中间缓存和计算量随序列长度的平方增长;
线性化就是设计俩函数来近似 softmax,从而把公式转化成下面的形式。
此时Q和K不需要绑在一起了,就可以先算K乘V,这个顺序的改变导致中间缓存和计算量随向量长度的平方增长,而相对于序列长度是线性关系。
这就是线性化的意思,这样的 Attention 也就不惧怕长序列带来的压力了。
开始线性化
本文中,作者的主要想法是向线性化 Transformer 中添加三个简单的概念:
1. Learnable (Linear) Attentions:可学习的(线性)注意力
2. Low-rank Adaptation:低秩适配 3. Layer-wise Optimization:分层优化
Learnable Attentions
首先训练线性注意力来模拟和替换 softmax 注意力。这种「注意力转移」的灵感来自作者之前的一篇工作:Hedgehog。
论文地址:https://arxiv.org/pdf/2402.04347
如何设计设计精妙复杂的函数来近似 softmax 注意力?
作者表示:与其让人类煞费苦心,不如交给 AI 自己去学!
相比于 Hedgehog 中只使用可学习的线性注意力,作者在 LoLCATs 中,将其推广为可学习的线性注意力和 + 滑动窗口。
研究人员将线性和 softmax 注意力统一在一个层中,训练一些新增的参数以从整体上近似 softmax 注意力。
对于N个 token 的序列,前W个 token 用于计算 softmax 注意力,后N-W 个 token 用于计算线性注意力,然后将这些值组合。
在 Hedgehog 中,作者通过 KL 散度来训练特征图以匹配注意力权重,而本文改为在注意力层的输出上使用 MSE 失。
这绕过了 Hedgehog 的一个限制:需要将所有注意力权重实例化为监督目标。
相反,LoLCATs 可以使用 FlashAttention 来计算 softmax 注意力输出,并将线性化注意力的内存消耗保持在O(N)。
只需将这些特征图插入到每个现有的注意力中,即可创建线性化的 LLM。冻结所有其他权重,只训练这些特征图,对于 7B 的 LLM 来说,只需要调整 0.2% 的参数。
Low-rank Adaptation
之前的线性化工作,通常需要一个比较昂贵的端到端训练阶段。
但在 LoLCATs 这里,可以通过简单地将低秩适应(LoRA)应用于注意力的 QKVO 权重来恢复模型的性能。
冻结所有其他内容,只训练 LoRA 权重,在某些自然语言数据上,最大限度地减少 LLM 输出的 next-token 预测损失。
Layer-wise Optimization
大多数情况下,只需要以上两步就搞定了。但对于像 Llama 3.1 405B 这种规模的模型来说,还需要努力一下。
通过简单地联合优化所有层,可以成功地线性化 7B 到 70B 参数范围的 LLM,但整体训练时,后面层的 MSE 会比前面的层更大。
当模型变得更大更深时,MSE 升级为了微调 Llama 3.1 405B 的真正问题。
为此,研究人员使用了更精细的逐块训练,将 Llama 3.1 405B 分成多个k层块,并仅在每个块内联合训练注意力。
当使用一些线性化数据并行训练所有模块时,只需为每个块预先计算 LLM 的隐藏状态。
可以调节k来平衡并行训练的速度与预计算的内存,并将隐藏状态保存到磁盘。不需要花哨的成本模型,对于 50M token 的线性化来说:
k = 1 时,需要 2 字节 × 126 层 × 50M token × 16384(hidden size)= 200TB 的磁盘空间来存储隐藏状态。
而 k = 9 时,磁盘空间的需求将减少为 22TB,这时仍然能在单个 GPU 上并行训练每个块(9 层)。
——后者显然更友好一点,所以作者将 Llama 3.1 405B 的 126 层拆分为 14 个 9 层块,在 14 个 GPU 上并行进行注意力的线性化,过程仅需 5 个小时。然后用 LoRA 将它们全部拼接在一起,就得到了最终模型。
实验结果
质量恢复
下表给出了 6 个流行的 LLM 评估任务的结果。
与最近的一些线性化方法相比,LoLCATs 显著提高了不同任务和不同 LLM 的质量和训练效率。
尽管只训练了 0.2% 的模型参数(40M token),LoLCATs 将线性化与原始模型的性能差距平均缩小了 80% 以上,token to model 的效率提高了 500~2500 倍。
在 7B 这个量级上,LoLCATs 优于所有的线性注意力(包括 RNN 系列)模型:Mamba、RWKV、TransNormer、Hawk、 Griffin 和 StripedHyena。
挑战 405B 大模型
最后,作者使用 LoLCATs 将线性化扩展到 Llama 3.1 70B 和更大的 405B 模型。
与之前的线性化方法相比,首先是质量上的显著改进。通过控制相同的线性 + 滑动窗口层,对于 Llama 3.1 70B,在5-shot MMLU 上的精度实现了 39 点的提升,对于 Llama 3.1 405B,同样实现了 38.3 分的改进。
其次是训练效率的提高,在单个 8x80GB H100 上线性化 Llama 3.1 70B 仅需 18 个小时,而线性化 Llama 3.1 405B 所花费的时间比之前用于 8B 模型的方法还要少。
参考资料: