Mamba正式被ICLR拒收!“年度最佳技术原理解读”却火了

  丰色发自凹非寺

  量子位公众号 QbitAI

  悬着的心终于死了:

  被尊为Transformer 挑战者的 Mamba,已正式被 ICLR 拒绝。

  (之前被“初拒”后在学术圈引起轩然大波,转为“待定(Decision Pending)”状态)

  但这位“顶流”的热度岂受影响?

  这不,一篇关于它的最新通俗解读(作者:Jack Cook,牛津互联网研究院研究员,曾在 MIT、英伟达、微软工作),刚刚诞生,还在被网友们疯狂点赞收藏。

  有人甚至称它为:

到目前为止的年度最佳(解读)。

  咱也不能错过。

  以下为原文精华传送:

  背景:S4 架构

  Mamba 的架构主要基于 S4,一种最新的状态空间模型(SSM,state space model)架构。

  其主要思想如下:

  在较高层次上,S4 学习如何通过中间状态 h (t) 将输入x(t) 映射到输出 y (t) 上。

  在此,由于 SSM 被设计于很好地处理连续数据,例如音频、传感器数据和图像,因此x、y、t 是x的函数。

  S4 通过三个连续参数矩阵A、B和C将它们互联,具体形式表现为以下两个方程(Mamba 论文中的 1a 和 1b):

  由于在实践中,我们一般都是处理离散数据比如文本,这就需要我们对 SSM 进行离散化,通过使用特殊的第四个参数Δ,将连续参数A、B和C转换为离散参数。

  离散化后,我们可以通过这两个方程(Mamba 论文中的 2a 和 2b)来表示 SSM:

  这些方程形成一个递归,情况类似于咱在 RNN 网络中看到的一样。在每个步骤t中,我们将前一个时间步 ht−1 的隐藏状态与当前输入 xt 相结合,以创建新的隐藏状态 ht。

  下图展示了它在预测句子中的下一个单词时是如何工作的(我们预测“and”跟在“My name is Jack”之后)。

  依据以此,我们本质上就可以使用 S4 作为递归神经网 RNN 来一次生成一个 token

  然而,S4 真正酷的地方在于,你也可以将它用作卷积神经网络 CNN。

  在上面的示例中,当我们扩展之前的离散方程来尝试计算 h3 时,会发生什么?

  为了简单起见,我们假设x−1=0。

  计算出 h3 后,我们可以将其代入 y3 的等式中来预测下一个单词:

  现在,请注意 y3 实际上可以计算为点积,其中右侧向量是我们的输入x:

  由于其中三个离散参数A、B和C都是常数,因此我们可以预先计算左侧向量并将其保存为卷积核。这为我们提供了一种使用卷积计算y的简单方法,如以下两个方程所示(Mamba 论文中的 3a 和 3b):

  划重点:这些循环和卷积形式(作者称之为“RNN 模式”和“CNN 模式”)在数学上是等效的。

  因此 S4 可以根据你需要执行的操作进行变形,同时输出没有任何差异。

  当然,CNN 模式更适合训练,RNN 模式更适合推理。

  第一个主要思想:可选性

  这部分我们讨论 Mamba 引入的第一个主要思想:可选性。让我们回想一下定义 S4 离散形式的两个方程:

  注意,在 S4 中,我们的离散参数 AB 和C是恒定的。然而,Mamba 使这些参数根据输入而变化。因此我们最终会得到这样的结果:

  Mamba 作者(Gu 和 Dao)认为,选择性或输入依赖性对于许多任务都很重要。

  而本文的科普作者则认为:因为 S4 没有选择性,所以它被迫以完全相同的方式处理输入的所有部分。

  然而,当我们面对一句话时,其中有些单词不可避免地比其他单词更重要。

  就比如 “I want to order a hamburger.”这句。

  如果没有选择性,S4 会花费相同的“精力”来处理每个单词:

  但如果是一个试图对这句话的意图进行分类的模型,它可能会想更多地“关注”order、hamburger,而不是 want、to。

  如下图所示,而通过使模型参数成为输入的函数,Mamba 就可以做到“专注于”输入中对于当前任务更重要的部分。

  然而,选择性给我们带来了一个问题。让我们回想一下之前计算的卷积核。

  在 S4 中,我们可以预先计算该内核、保存,并将其与输入x相乘。

  这很好,因为离散参数 AB 和C是恒定的。但同样,在 Mamba 中,这些矩阵会根据输入而变化!因此,我们无法预计算K,也无法使用 CNN 模式来训练我们的模型。如果我们想要选择性,我们得用 RNN 模式进行训练。方法是删除方程 3b 以获得“戏剧性的效果”。

  但这给 Mamba 的作者带来了一个问题:RNN 模式的训练速度非常慢。

  假如我们正在使用 1000 个 token 的序列训练我们的模型:

  CNN 本质上会计算其内核和输入向量之间的点积,并且可以并行执行这些计算。相比之下,RNN 需要按顺序更新其隐藏状态 1000 次。

  这便导致 Mamba 的作者提出了他们的第二个伟大思想。

  第二个主要思想:无需卷积的快速训练

  Mamba 可以在 RNN 模式下进行非常非常快速的训练。

  在某个时刻,它们的递归与扫描算法(也称为前缀和,prefix sum)非常相似。

  要计算前缀和,我们需要获取一个输入数组 [x1,x2,… ,xn] ,并返回一个输出数组,其中每个元素都是该项目及其之前项目的总和。

  换句话说,输出的第一个元素将为 x1 ,第二个元素将为[x1+[x2 ,依此类推。一个例子:

  现在我们画出 RNN 模式下更新 Mamba 隐藏状态的流程。

  等等……,如果我们必须形式化前缀和,我们可以将其写成以下等式:

  该方程形成一个递归:在每一步,我们通过将先前存储的值添加到当前输入来计算新值。现在,让我们再次看看更新之后 Mamba 隐藏状态的循环。

  这两个等式真的非常非常相似有么有!

  而最酷的地方又来了:虽然计算前缀和本质上看起来似乎是顺序的,但我们实际上拥有用于此任务的高效并行算法!

  在下图中,我们可以看到正在运行的并行前缀和算法,其中每条垂直线代表数组中的一项。

  花一点时间捋一下这个算法:

  选择任何垂直线,从顶部开始,然后向下移动,将每个加法追溯到数组的前几个项目。当你到达底部时,应该在行的左侧看到所有项目的总和。

  例如,在第一个元素添加到开头的第二个元素之后,数组的第三个元素在末尾接收了第二个元素的添加值。结果,当并行扫描完成时,第三个元素包含第一、第二和第三元素的总和。

  如果我们在没有并行性的单线程中运行该算法,则比仅按顺序将值相加所需的时间要长。但 GPU 拥有大量处理器,可以进行高度并行计算。因此,我们可以在大约O(logn) 时间内计算此前缀和(或扫描)操作!

  因此,Mamba 的作者意识到,如果他们想在 RNN 模式下高效训练,他们可能可以用并行扫描。

  但由于 PyTorch 目前没有扫描实现,Mamba 的作者自己编写了一个——但,结果并不好

  在上图中,大家可以看到他们基于 PyTorch 的扫描实现(绿色)总是慢于 FlashAttention-2(蓝色),FlashAttention-2 是可用“精确注意力”的最快实现。

  尽管当序列长度为 128000 个 token 时,扫描似乎在运行时赶上,但还是耗尽了内存。

  为了让 Mamba 变得实用,它需要更快。这让 Mamba 的作者看到了 Dao 之前关于FlashAttention的工作,从而解决了问题。

  由于篇幅所限,在此我们省略了原文中 FlashAttention 的原理介绍部分(Review: FlashAttention),感兴趣的朋友可以查看原博/FlashAttention 原论文,或者我们之前的一篇原理介绍文章。

  Back to Mamba

  还是基于上一张对比图。

  事实证明,如果在计算扫描时采用相同的内存感知平铺方法,则可以大大加快速度。

  通过这种优化,Mamba(红色)现在在所有序列长度上都比 FlashAttention-2(蓝色)更快。

  这些结果表明,就速度而言,Mamba 是实用的,其运行速度比最快的 Transformer 还要快。但它在语言建模方面有什么擅长的地方吗?

  Mamba 作者在涉及语言、基因组学和音频的许多序列建模任务上对 Mamba 进行了评估。

  结果看起来很酷:Mamba 在对人类基因组项目的 DNA 和钢琴音乐数据集的音频进行建模时建立了最先进的性能。

  然而,让很多人兴奋的是语言任务上的结果。许多关于 Mamba 的在线讨论都集中在下图中:

  我们可以看到,模型大小向右增加,语言建模性能则随着进一步向下而提高。

  这意味着最好的模型应该位于左侧:体积小(因此速度快),并且非常擅长建模语言。

  由于 Mamba 作者都是学者,搞不来数千个 GPU 来训练 GPT-4 大小的模型,因此实验是通过训练一堆较小的模型(大约 125M 到 1.3B 参数)来进行比较的。

  如上图所示,结果看起来非常有希望。与其他类似尺寸的模型相比,Mamba 似乎是最擅长建模语言的

  为什么被“二连拒”

  写到最后,本文作者再次表达了对 Mamba 被拒的惋惜:

我真的认为 Mamba 以一种非常独特和有趣的方式在语言建模上进行了创新。但很不幸,一些审稿人并不同意。

  从最新的驳回意见来看,其中一位审稿人的拒绝理由与“两个重大基准评估”有关。

  一是缺少 LRA(Long Range Arena)评估,公认的长序列建模基准。

  二是仅将困惑度评估作为主要评价指标不行,理由是低困惑度与生成性能不一定正相关。

  最终的总体意见是:再增加额外的实验。

  对此结果,有网友也再次评价道:

这只能说明一篇论文被会议接收与否与它对社区的价值贡献并不挂钩。因为前者很容易依赖于极少数人的判断。

  其实说到公认的好论文被顶会 pass 一事,Mamba 还真不是头一个。

  大约十年前,Word2vec也曾被 ICLR“丑拒”,然而去年,它还捧回了 NeurIPS 的时间检验奖。

  你觉得时间会为 Mamba“正名”吗?

  解读原文:

  https://jackcook.com/2024/02/23/mamba.html

  参考链接:

  [1]https://twitter.com/srush_nlp/status/1761094139544838275

  [2]https://twitter.com/volokulesho