陈丹琦团队新作:Llama-2上下文扩至128k,10倍吞吐量仅需1/6内存

  丰色发自凹非寺

  量子位公众号 QbitAI

  陈丹琦团队刚刚发布了一种新的 LLM 上下文窗口扩展方法:

  它仅用 8k 大小的 token 文档进行训练,就能将 Llama-2 窗口扩展至 128k。

  最重要的是,在这个过程中,只需要原来1/6 的内存,模型就获得了 10 倍吞吐量。

  除此之外,它还能大大降低训练成本:

  用该方法对 7B 大小的羊驼 2 进行改造,只需要一块 A100 就能搞定。

  团队表示:希望这个方法有用、好用,为未来的 LLM 们提供廉价又有效的长上下文能力。

  目前,模型和代码都已在 HuggingFace 和 GitHub 上发布。

  只需添加两个组件

  这个方法名叫 CEPE,全称“并行编码上下文扩展(Context Expansion with Parallel Encoding)”。

  作为轻量级框架,它可用于扩展任何预训练和指令微调模型的上下文窗口。

  对于任何预训练的仅解码器语言模型,CEPE 通过添加两个小组件来实现扩展:

  一个是小型编码器,用于对长上下文进行块编码;

  一个是交叉注意力模块,插入到解码器的每一层,用于关注编码器表示。

  完整架构如下:

  在这个示意图中,编码器模型并行编码上下文的 3 个额外块,并与最终隐藏表示进行连接,然后作为解码器交叉注意力层的输入。

  在此,交叉注意力层主要关注解码器模型中自注意力层和前馈层之间的编码器表示。

  通过仔细选择无需标记的训练数据,CEPE 就帮助模型具备了长上下文能力,并且也擅长文档检索。

  作者介绍,这样的 CEPE 主要包含 3 大优势:

  (1)长度可泛化

  因为它不受位置编码的约束,相反,它的上下文是分段编码的,每一段都有自己的位置编码。

  (2)效率高

  使用小型编码器和并行编码来处理上下文可以降低计算成本。

  同时,由于交叉注意力仅关注编码器最后一层的表示,而仅使用解码器的语言模型则需要缓存每个层每个 token 的键-值对,所以对比起来,CEPE 需要的内存大大减少。

  (3)降低训练成本

  与完全微调方法不同,CEPE 只调整编码器和交叉注意力,同时保持大型解码器模型冻结。

  作者介绍,通过将 7B 解码器扩充为具有 400M 编码器和交叉注意力层的模型(总计 14 亿参数),用一块 80GB 的 A100 GPU 就可以完成。

  困惑度持续降低

  团队将 CEPE 应用于 Llama-2,并在 200 亿 token 的 RedPajama 过滤版本上进行训练(仅为 Llama-2 预训练预算的1%)。

  首先,与 LLAMA2-32K 和 YARN-64K 这两种完全微调的模型相比,CEPE 在所有数据集上都实现了更低或相当的困惑度,同时具有更低的内存使用率和更高的吞吐量。

  在将上下文提升到 128k 时(远超其 8k 训练长度),CEPE 的困惑度更是持续保持降低,同时保持低内存状态。

  相比之下,Llama-2-32K 和 YARN-64K 不仅不能推广到其训练长度之外,还伴随着内存成本显著增加。

  其次,检索能力增强。

  如下表所示:

  通过使用检索到的上下文,CEPE 可以有效改善模型困惑度,性能优于 RePlug。

  值得注意的是,即使让段落k=50(训练是 60),CEPE 仍会继续改善困惑度。

  这表明 CEPE 可以很好地转移到检索增强设置,而全上下文解码器模型在这个能力上却退化了。

  第三,开放域问答能力显著超越。

  如下图所示,CEPE 在所有数据集和段落k参数上都大幅优于其他模型,且不像别的模型那样,k值越来越大之后性能明显下降。

  这也表明,CEPE 对大量冗余或不相关的段落并不敏感。

  所以总结一下就是,与大多数其他解决办法相比,CEPE 在上述所有任务上都能以低得多的内存和计算成本胜出。

  最后,作者在这些基础上,提出了专门用于指令调优模型的 CEPE-Distilled(CEPED)。

  它仅使用未标记的数据来扩展模型的上下文窗口,通过辅助 KL 散度损失将原始指令调整模型的行为提炼为新架构,由此无需管理昂贵的长上下文指令跟踪数据。

  最终,CEPED 可以在保留指令理解能力的同时,扩展 Llama-2 的上下文窗口,提高模型长文本性能。

  团队介绍

  CEPE 一共 3 位作者。

  一作为颜和光(Howard Yen),普林斯顿大学计算机科学专业硕士生在读。

  二作为高天宇,同校博士生在读,清华本科毕业。

  他们都是通讯作者陈丹琦的学生。

  论文原文:

  https://arxiv.org/abs/2402.16617

  参考链接:

  https://twitter.com/HowardYen1/status/1762474556101661158