「专业智能体指导」让小模型学会数学推理!微调Mistral-7B实现86.81%准确率

  新智元报道

  编辑:LRS

  小模型也能解锁数学能力,无需多模型集成,7B 模型在 GSM 8 k 数据集上性能超越 70B!

  对于小型语言模型(SLM)来说,数学应用题求解是一项很复杂的任务。

  比如之前有研究结果显示,在 GSM 8K 基准测试中实现 80% 以上准确度所需的最小模型尺寸为 340 亿个参数。

  为了在较小的模型上达到这种性能水平,研究人员经常训练 SLM 来生成 Python 代码或使用外部工具作为辅助,以避免计算错误。

  或是基于集成(ensembling)技术,将 100 多个模型生成的输出组合在一起,以获得更准确的结果,最终结果的选择需要通过共识、多数表决或与 SLM 结合使用的单独的验证器模型来完成,可以显著提升准确率(Phi-GSM 使用 top-48 将性能从 68.2 提升到 81.5),不过代价是由于多次调用模型导致的成本显著增加。

  最近,微软的研究人员提出了一个基于 Mistral-7B、70 亿参数量的小型语言模型 Orca-Math,它在 GSM 8 k 上实现了 86.81%,不需要调用多个模型进行集成或使用验证器、代码执行或任何其他外部工具。

  论文链接:https://arxiv.org/abs/2402.14830

  Orca-Math 的关键特性为:

  1. 使用多个智能体(agent)创建出 20 万个数学问题的高质量合成数据集,其中智能体合作创建数据;

  2. 迭代学习技术,使 SLM 能够练习解决问题,接收对其解决方案的反馈,并从包含 SLM 解决方案和反馈的偏好数据中学习。

  当单独使用有监督微调训练时,Orca-Math 在 GSM 8 k pass@1 指标上达到 81.50%。通过迭代偏好学习,Orca-Math 实现了 86.81% 的 pass@1

  Orca-Math 超越了 LLAMA-2- 70B,WizardMath-70B,Gemini-Pro,ChatGPT-3.5 等更大型号的性能,在使用小得多的数据(数十万对数百万问题)时也显著优于其他较小的模型。

  数据集构造

  种子集合

  首先从现有的开源数据集中收集数学单词问题样本,即 NumGLUE、AddSub、ALGES、ASDiv、DRAW、GSM8k、MATHQA、MultiArith、SingeOP、SingleEQ 和 SVAMP。

  研究人员从 Lila 的训练和验证分裂中收集问题,以构建种子集,总共收集了 36217 个问题。

  智能体 - ask me anything

  通过从种子集中的问题创建多个单词问题来扩展种子集,利用后续提示来创建问题。

  智能体总共生成了 120445 个新问题,但所有生成的问题都表现出与种子词问题相似的叙述方式,具体解决方案是使用 GPT4-Trubo 生成的。

  智能体 - Suggester & Editor

  通过解决具有挑战性的问题进一步扩大种子集合。

  为了实现这一点,研究人员引入了两个新的智能体,即 Suggester 和 Editor,可以协同工作以创建一个面向预定义目标的数据集:修改现有问题以增加其难度。

  Suggester 研究一个特定的问题,并提出了几种在不产生实际问题的情况下提高其复杂性的方法。

  Editor 采用原始单词问题和 Suggester 的建议,生成一个更新的、更具挑战性的问题,迭代过程可以发生在多个回合中,每一回合都会进一步增加先前生成的问题的复杂性。

  眼人员利用 AutoGen 框架来实现多智能体工作流。

  对每个问题进行两轮迭代,并过滤 GPT4-Turbo 生成的答案超过 1800 个字符的问题,最终收集了 37157 个问题。

  训练

  有监督微调实验(第一次迭代)

  在 Orca-Math-200K 数据集上对 Mistral-7B 进行了微调,没有使用 packing,下面为具体的指令格式。

  损失函数只基于答案 token 来计算。

  正负信号的迭代学习

  数据集构建(第二次迭代)

  为了为每个问题生成额外的正样本和负样本,研究人员从第一次迭代的 SFT 调优模型中采样四个回复。

  具体来说,使用 top_p=0.95 和温度=0.7,过程产生了一个数据集,其中 200000 个问题中的每个问题都有一个 GPT4-Turbo 生成的解决方案和四个学生生成的解决方法。

  使用基于 GPT4 的精确匹配中定义的提示来评估教师(GPT4-Turbo)的答案和学生的答案之间的一致性。

  对于学生生成的答案与老师的答案不匹配的所有解决方案,将其标记为负样本。

  数据集构建(第三次迭代)

  为了从正反馈和负反馈中学习,研究人员评估了两种算法的性能:直接偏好优化(DPO)和 Kahneman-Tversky 优化(KTO),还探索了 KTO 的功能,其区别在于只需要二进制「是」或「否」的回复来评估输出的质量。

  评估方法

  研究人员使用精确匹配作为评估指标。

  给定一个模型生成的答案,提示 GPT-4 来提取最终的简短答案,并将其与金标准中的简短答案进行匹配,即基于 GPT4 的精确匹配(GPT4-based-Exact-Match)。

  实验结果

  研究人员测试了模型在包含 1319 个单词问题的 GSM8k 测试集上几个训练过程的性能,对 Mistral-7B 模型进行了三次迭代的微调

  在第一次迭代中,使用有监督微调来获得 M1;

  第二次迭代中,对比了 SFT、DPO 和 KTO,其中 KTO 训练的模型在这一组中表现更好,获得 M2 后,并使用 M2 生成迭代#3 的数据集;

  第三次迭代中,对比了 DPO 和 KTO 方法,使用 M2 作为模型起点。

  研究人员还将这些模型与 Orca-Math-200K 数据集上经过三个 epoch 的 SFT 训练进行了对比。

  消融实验

  Model Generated Positives

  通过将限制为仅包含教师生成的解决方案来研究影响模型生成的正向因素(positives),换言之,研究人员移除在为迭代#2 创建数据集时模型生成的所有

  结果显示,不管训练算法如何,都会看到显著的性能下降。

  Synthetic Negatives

  数据集的创建包括在 M1 或 M2 生成的所有四个回复都是 positive 的情况下的合成负样本(negative creation)。

  通过忽略问题 qi 来研究这些合成负样本的影响,结果将第二次迭代的问题数量减少了约 80k,将第三次迭代的问题数量增加了约 104k

  除 GSM8k 外的数学基准

  研究人员还使用 Orca Math 其他几个单词问题数据集上进行了实验,并且为了便于评估,最终选择了问题答案都是单个数字的数据集。

  评估指标为基于 GPT4 的精确匹配度量,并使用贪婪解码生成模型回复。

  沾染检查(Contamination Check)

  为了确保实验的公正性,研究人员在文中表示:在训练过程中,从未使用 GSM8K 或任何其他数据集的测试分割集,也从未将其用作合成问题生成的种子。

  尽管如此,研究人员还是采用以下方法来检测任何潜在的文本沾染(text contamination)问题:

  1. 对文本进行预处理,包括将所有字符转换为小写、删除标点符号、对文本进行分词,以及删除常见的英语停止词,以确保数据的一致性。

  2. 使用逆文档频率(TF-IDF)方法对文本语料库进行矢量化,并确定测试集和训练集之间的余弦相似性,从中为每个测试查询选择前k个(k=10)最相似的问题。

  3. 通过计算在预设阈值 0.5 以上具有最高n-gram 重叠的试题数量及其相应的训练集匹配来评估文本污染的程度。

  研究人员使用 Jaccard 相似度来计算文本对之间的n-gram 重叠,并且为了进行严格的污染检查,n设置为1。

  需要注意的是,当使用 Jaccard 相似性测量时,n-gram 重叠是n的非递增函数。

  4. 在执行算法时,确定表现出显著的n-gram 重叠的试题数量为8,因此根据定义的阈值,表明测试集中的文本污染可以忽略不计。

  当将训练集限制为仅包含种子问题时,表现出显著n-gram 重叠的测试问题的数量为7;并且在n≥2 的情况下,表现出显著的n-gram 重叠的试题数为零。

  参考资料:

  https://arxiv.org/abs/2402.14830