Meta发表的将系统2模型蒸馏至系统1模型
发布时间:2024-11-14 23:59:04点击:
一、结论写在前面
论文标题:Distilling System 2 into System 1
论文链接:
LLMs在推理过程中可以额外消耗计算资源来生成中间思维,这有助于产生更好的最终响应。自思维链以来,已经提出了许多此类系统2技术,例如重述与响应(Rephrase and Respond )、系统2注意力(System 2 Attention)和分支-解决-合并(Branch-Solve-Merge)。
论文研究了自监督方法(self-supervised),将系统2技术的高质量输出“编译”(蒸馏,distill)回LLM生成中,而不需要中间推理token序列,因为这种推理已经被蒸馏到系统1中。
论文进行了跨4种不同System 2 LLM方法和5种不同任务的实验。论文发现,论文的方法能够在多种环境下将System 2推理蒸馏为System 1,有时甚至能超越System 2教师模型的效果。此外,这些预测现在以极低的计算成本生成。例如,论文在处理偏见观点或无关信息的任务(System 2注意力)、澄清和改进某些推理任务的响应(重述与回应)以及对LLM进行细粒度评估(分支-解决-合并)方面看到了成功的蒸馏。
然而,论文也表明并非所有任务都能蒸馏到System 1,特别是需要链式思维的复杂数学推理任务。这一点在人类中也得到了体现,有些任务没有刻意的System 2推理是无法执行的。
二、论文的简单介绍
2.1 论文的背景
人类 System 1System 1推理被描述为能够识别模式、快速做出判断以及理解简单或熟悉的符号。例如,它用于识别常见的交通标志、识别人脸或关联基本符号与特定情绪或想法。
人类 System 2对于复杂的问题解决或例如抽象符号(如代数方程或逻辑陈述)的操作,System 2推理被认为是必要的。在心理学中,自动性概念描述了行为变得如此熟练以至于可以在几乎没有意识思考的情况下执行,例如驾驶熟悉的路线。一般来说,人类被认为使用程序记忆将特定任务整合到记忆中,通过实践学习,以便之后无需意识就能执行。无意识能力概念被归类为学习的后期阶段。最初,一个人认识到自己的无能,并有意学习一项技能,直到获得有意识的能力。最终目标是在无需意识思考的情况下使用它,这时它被称为,用通俗的话说,“第二天性”。
模型 System 1论文将直接输出响应而不产生中间输出的神经网络称为系统1模型。尽管如此,这类网络在其层中仍可计算中间的潜在表征,然后输出响应。由于这些状态以向量形式表示,它们通常编码分布式知识而非离散决策,并且难以直接处理复杂的符号推理任务,这与人类系统1推理存在的问题类似。尽管如此,许多任务可以直接通过这种方式成功解决,无需中间生成(Radford et al., 2019)。
模型 System 2同一个无法执行复杂多步骤计算的语言模型,在要求其通过少样本提示或监督训练生成中间步骤到“草稿板”上时,能够完成这些任务。链式思维推理已被证明可以通过零样本提示、监督训练或少量样本方法从大型语言模型中引发。大型语言模型的预训练使得这种推理能够融入模型中,因为训练语料库中包含了人类编写的离散符号(文本)的推理步骤。这类系统2模型方法输出离散的token,有利于进行连续正确的逻辑推理步骤——但显然,如果推理生成错误,则存在缺点。错误的离散决策难以恢复,与可能更容易建模分布的潜在向量推理不同。
生成中间思考过程允许模型进行推理和规划,以成功完成任务或响应指令。论文将这种深思熟虑的思考称为系统2推理,这一概念源自Sloman(1996)和Kahneman(2011)对人类的描述,后来也被应用于人工智能模型。在系统2推理中,消耗大量认知资源来处理复杂问题和重要决策。因此,在标准的大型语言模型(LLMs)中,论文将系统1定义为直接应用Transformer来根据输入生成响应,而不生成中间token。论文将系统2定义为任何生成中间token的方法,包括执行搜索或多次提示,然后最终生成响应的方法。
目前已提出了一系列这样的系统2技术,其中包括思维链(Chain-of-Thought)、思维树(Tree-of-Thoughts)、思维图(Graph-of-Thoughts)、分支-解决-合并(Branch-Solve-Merge)、系统2注意力(System 2 Attention)、重述和回应(Rephrase and Respond)等等。许多这些方法通过显式推理被证明能产生更准确的结果,但通常会以更高的推理成本和响应延迟为代价。由于后者的原因,许多这些方法并未在生产系统中使用,生产系统主要使用系统1生成。
图1:系统2蒸馏概览。通过在未token数据上运行系统2方法(如分支-求解-合并(BSM))收集过滤后的训练样本,这些方法利用额外计算产生更高质量的输出。然后将这些目标蒸馏到标准(系统1)语言模型中
对于人类而言,心理学中将技能从有意识(系统2)转移到自动(系统1)的过程被称为自动性,并利用程序性记忆。例如,首次驾车上班时,人们可能会耗费大量意识努力进行规划和决策以到达目的地。经过多次重复这条路线后,驾驶过程便“编译”为潜意识(Charlton and Starkey, 2013)。同样,像打网球这样的运动可以变得“习以为常”。
论文探索了一种类似的技术应用于AI模型。论文的方法以无监督方式进行这种编译,论文称之为系统2蒸馏,给定一组未token样本。对于每个样本,论文应用给定的系统2方法,然后以无监督方式衡量预测质量。例如,对于具有唯一答案的任务,论文采用自一致性(self-consistency),多次采样。对于系统2足够一致的样本,论文假设此结果应被蒸馏,并将其添加到蒸馏池中。随后,论文微调系统1以匹配系统2方法在收集的样本池上的预测,但不生成中间步骤。图1展示了将系统2蒸馏为系统1的整体过程。
2.2 将系统2蒸馏至系统1
2.2.1 设置:系统1与系统2模型
给定输入 论文x论文,本工作考虑单一模型的情景,即大型语言模型(LLM),该模型具备两种响应模式:
(i) 系统1:直接生成输出 论文y论文。这是通过前向传播底层自回归神经网络(Transformer)的各层以生成输出token来实现的。
(ii) 系统2:论文将系统2模型定义为利用底层Transformer在生成最终响应token之前生成任意类型的中间输出token 论文z论文 的方法。这可能包括多次调用(提示)。
更正式地,论文将一个System 2模型S视为一个函数,该函数接受一个LLM 和输入x,并可能多次调用LLM以使用特定算法生成中间token,然后返回一个输出论文y:
System 2方法可能涉及多个提示、分支、迭代和搜索,同时利用LLM生成中间结果以进行进一步处理。相比之下,一个System 1模型仅考虑原始输入x,并直接调用LLM生成输出y:
有许多现有的System 2模型实例。思维链提示仅需要单个LLM提示,但仍输出中间生成内容,然后给出最终响应,通常用于数学和其他推理任务)。
诸如System 2 Attention和Rephrase and Respond(等方法需要两次调用LLM,在前者中,第一次调用用于关注上下文并消除偏见,而在后者中用于扩展问题。第二次调用则用于根据中间生成内容最终回答问题。某些方法更为复杂,例如Branch-Solve-Merge(,它通过LLM生成计划,该计划分支成多个LLM调用,直到最终阶段合并结果。
论文将对上述四种方法进行实验,但还有许多其他System 2方法,例如Tree-of-Thoughts、Graph-of-Thoughts等。
2.2.2 方法:系统2蒸馏
许多系统2方法本质上在推理时由于多次提示调用和生成中间token而显著较慢。系统2蒸馏的目标是将所有推理从S_II蒸馏回S_I,以便语言模型的直接输出p_θ( x)得到改进。论文假设模型可以访问未token的输入t,从中它可以学习,类似于人类如何在无监督的情况下学习程序记忆。对于基于语言的任务,通常可以访问遵循指令的提示(输入),因为它们可以由人类收集,例如发布的1M Wild-Chat交互,其中提供了输入但正确标签未知。因此,这是一个现实的设置。
所提出方法的第一步是使用系统2模型在未token的输入t上生成响应:
这些响应可以直接用作微调系统1模型的系统2蒸馏目标。然而,它们受到噪声的影响:其中一些响应可能是高质量的,而其他可能是低质量或不正确的。对于涉及短响应且通常具有唯一正确(但未知)答案的短形式QA和推理任务,论文因此考虑一个无监督的筛选步骤,以尝试提高训练数据质量。论文考虑两种变体,两者都依赖于一致性标准:
•输出自一致性:论文总共采样S_II(x^ i ; p_θ) N次,并接受多数投票的响应;如果没有多数胜出者,论文丢弃该示例。
•输入扰动下的自一致性:论文以输出不应改变的方式扰动输入w,例如改变提示中多项选择项的顺序,并为每个扰动计算S_I;如果输出不一致,论文丢弃该示例。
随后,论文得到合成数据集(X_S_II , Y_S_II),其中 论文X_S_II是X的过滤子集,目标为Y_S_II)。最后一步是使用这个蒸馏的训练集对具有参数pθ的大型语言模型(LLM)进行有监督的微调。论文通常从当前状态pθ初始化模型,并继续使用新数据集进行训练。
微调后,论文获得一个 LLM p_θ,这是一个系统1模型,预计其输出和性能提升与评估的系统2模型相似。
2.3 实验
2.3.1 训练与评估设置
论文使用 Llama-2-70B-chat作为所有实验的基础模型。论文需要一个足够强大的基础模型,使其能作为系统2模型表现出色,同时具有可微调的开源权重,因此选择了此模型。论文考虑了几种系统2方法,包括重述与回应(RaR)、系统2注意力(S2A)、分支-解决-合并(BSM)和思维链(CoT),重点关注每种方法已展示出强大性能的任务。对于系统1,论文使用指令调优的基础模型进行零样本推理,作为标准基线。论文报告每个任务的特定指标,以及“#Tokens”指标,该指标衡量评估集中每个输入生成的平均token数量。对于系统2方法,这包括中间token生成和最终输出token生成。
2.3.2 重述与回应蒸馏(Rephrase and Respond Distillation)
重述与回应(RaR)是一种系统2方法,首先提示语言模型对原始问题进行进一步阐述的重述,然后基于重述的问题生成回应,旨在提供更优质的输出。作者介绍了两种方法,1步RaR和2步RaR,后者涉及两个单独的提示,而不是像前者那样的组合提示,具体提示见附录A.1。他们发现2步RaR在几个对基线LLM具有挑战性的推理任务上显著提高了性能。论文考虑了原文中表现良好的两个任务:最后一个字母连接任务和硬币翻转推理。然后评估是否可能蒸馏这种系统2方法。
蒸馏数据集论文为RaR构建了系统2蒸馏数据集,利用输出的自一致性。对于每个输入,论文对最后一个字母任务进行八次采样迭代,并对硬币翻转任务的每个阶段进行八次采样迭代。然后,论文通过多数表决来确定最终输出。
2.3.2.1 最后一个字母拼接任务(Last letter Concatenation Task)
此任务侧重于符号推理,要求模型拼接给定单词的最后一个字母。例如,指令:“取Edgar Bob中单词的最后一个字母并拼接它们。”正如Deng等人(2023a)所示,此任务从RaR方法的应用中获益显著。论文通过随机选择1200个独特的英语单词来编译数据集。利用这些单词,论文分别为训练、验证和测试构建了200个样本。
结果总体结果见表1。基准系统1模型(Llama-2-70B-chat)达到30.0%的准确率,被1步和2步RaR的系统2方法(分别为39.5%和44.5%)超越。通过论文的无监督技术将2步RaR方法蒸馏回系统1 Llama-2-70B-chat模型,论文实现了惊人的98.0%准确率。与零样本聊天模型相比,该模型能有效学习如何解决此任务。重述并回应的蒸馏有效继承了系统2和系统1的优势。它在保持系统2的准确性优势的同时,推理成本与系统1相当(见生成token数量)。
分析与消融实验为了评估论文利用输出自一致性的无监督筛选步骤的有效性和必要性,论文通过创建一个不应用自一致性过滤器的蒸馏数据集进行了消融研究。当论文在这个未经过滤的数据集上使用相同的设置对System 2模型进行了蒸馏,其精确匹配准确率达到了87.5%(过滤版本为98%)。这一比较突显了一致性过滤的关键作用。尽管如此,在两种情况下,构建训练数据确实比零样本性能有所提升。论文还尝试使用相同的过滤技术对System 1预测进行蒸馏,结果准确率较低,为69.5%。
表1:重述并回应的系统2蒸馏:硬币翻转和最后一个字母拼接任务。论文报告精确匹配(EM)测试准确率和生成(中间和输出)token数量
2.3.2.2 硬币翻转推理任务
这一符号推理任务在研究中经常被测试,包括在Wei等人(2022)和Deng等人(2023a)的研究中。它涉及从已知初始位置开始,经过一系列自然语言描述的翻转后,确定硬币的最终面(正面或反面),例如“一枚硬币正面朝上。Roxas没有翻转硬币。Schneiderman没有翻转硬币。硬币还是正面朝上吗?”Deng等人(2023a)表明,即使是强大的语言模型也无法成功完成这一任务,而应用RaR方法则能提高它们的性能。该任务有20k个训练示例(无标签,用于无监督学习),3.33k个验证示例和1.33k个测试示例。
结果总体结果见表1。Llama-2-70B-chat(零样本)在该任务上的成功率为56.1%,而1-Step和2-Step RaR的成功率分别为58.59%和77.2%。因此,论文仅在2-Step方法中看到了显著的改进。通过论文的无监督技术将2-Step RaR蒸馏回System 1 Llama-2-70B-chat,成功率为75.69%。因此,论文发现论文的蒸馏System 2模型提供了与System 2(2 Step RaR)相当的性能,但无需执行LLM程序。
表2:System 2注意力蒸馏:TriviaQA任务,报告有偏和无偏评估集的准确率
分析与消融实验Deng等(2023a)的RaR方法包含了提示工程技巧,例如在原始查询后附加"Flip意味着反转。回答是或否问题"等短语,这已被证明可以提高模型性能。遵循他们的方法,论文使用不同的提示评估了模型性能,见表6。当使用"Flip意味着反转"和"Flip意味着反转。回答是或否问题"等提示测试Llama-2-70B-chat模型(系统1)时,论文观察到性能显著提升,从56.11%提高到66.84%。这突显了提示选择在优化系统1模型性能中的关键作用。然而,这种对提示工程的依赖也代表了一个局限性,需要额外的人力投入。
论文还尝试对系统1模型进行蒸馏,但得到了较差的性能。在这种情况下,论文同样观察到不同提示下性能的波动。相比之下,蒸馏后的系统2模型在各种提示下表现出一致的性能,对提示变化的敏感度较低。这种一致性表明,对于蒸馏后的系统2模型,可能不需要进行大量的提示工程。
2.3.3 系统 2 注意力蒸馏
Weston和 Sukhbaatar 在 2023 年提出了系统 2 注意力(S2A),这是一种有助于减少模型推理缺陷的方法,如依赖输入中的偏见信息或关注无关上下文。S2A 是一种两阶段推理方法,第一阶段重写输入,使其不包含如偏见或无关上下文等不期望的信息,第二阶段关注重写后的较短上下文(与 Rak 扩展上下文相反),参见图 6。在本研究中,论文验证了将 S2A 蒸馏到系统 1 的可行性。特别地,论文关注了SycophancyEval问答任务(Sharma 等人,2023),该任务的输入中包含已知会损害大语言模型(LLM)性能的偏见信息。论文使用了来自 SycophancyEval 的 6668 个示例作为未token训练数据,以及 个示例用于评估,后者被分为偏见输入(350 个)和无偏见输入(50 个)。
蒸馏数据论文使用通用自一致性(USC)(Chen et al., 2023)来筛选高质量的目标。具体而言,论文采样20个生成结果,然后利用Llama-70B-chat模型配合USC提示(如图12所示)来组合一个自一致性(多数)的最终答案,该答案作为蒸馏目标。
结果结果如表2所示,报告了3个随机种子的平均准确率。基线(系统1)LLM在偏见部分的准确率较低,正如预期,因为其容易受到偏见输入的影响。S2A显著提升了偏见输入的性能。系统2蒸馏显示出与系统2方法相似的强劲性能。然而,与基线和S2A模型相比,平均使用的token数量有显著减少。这是因为偏见输入往往使基线LLM生成更多的输出token,而S2A还需要生成中间token。图11展示了一个代表性示例。最后,论文通过报告不使用USC的结果(最后一行),显示后者提供的结果较差,从而表明使用USC进行蒸馏对整体结果的重要性。这突出了在微调过程中使用的蒸馏数据质量的重要性。
2.3.4 分支-解决-合并蒸馏
分支-解决-合并(BSM)(Saha et al., 2023)由三个模块组成:分支、解决和合并。这些模块协同工作,将任务分解为多个并行子任务,每个子任务由特定提示引导。BSM在LLM作为评判者的情境中已被证明有效,如图14所示。该方法首先提示语言模型列出针对特定用户查询定制的评估指标(分支)。随后,LLM被查询以基于每个指标独立并行地评估响应(解决)。最后,来自每个分支的分数被平均以得出一个全面的评估决策(合并)。值得注意的是,这种方法的推理成本是传统(系统1)LLM评估方法的5-6倍,使其实用性大打折扣。论文评估了蒸馏BSM的可行性,旨在保留其优势的同时降低计算成本。
表3 系统 2 分支-解决-合并 (BSM) 的蒸馏:Open Assistant (OASST2) 和 MT-bench 对 LLM 作为判断者的评估。系统 2 BSM 的蒸馏优于 BSM 本身,甚至优于 GPT4 作为判断者,尽管使用的是 Llama-2-70B-chat。蒸馏后的 BSM 具有更高的人类一致性(一致性),更少的位置偏差,并且不一致样本的百分比为 9.1%
蒸馏数据遵循 Yuan 等人 (2024) 和 Li 等人 (2023b) 的方法,论文使用了 Open Assistant>评估论文在两个流行的基准上评估论文的模型,即 OASST2 验证集和 MT-bench (Zheng 等人, 2024)。OASST2 验证集包含 273 个样本,仅限于第一轮和英语语言。对响应对的评估在原始顺序和交换顺序下进行。由于论文的蒸馏模型是在 OASST2 训练集上训练的,OASST2 验证集作为分布内评估集,而 MT-bench 则更具分布外特性。MT-bench 是一个流行的基准,评估 LLM 作为有用 AI 助手对话时对其他 LLM 响应的判断。它包含来自 8 个不同领域的指令,例如写作、推理、数学、编码等。
遵循 Zheng 等人 (2024) 的方法,论文评估了模型投票与人类专家投票之间的一致性。LLM 作为判断者的一个已知局限是位置偏差,即语言模型 (LLM) 倾向于偏好某些位置而非其他位置。这种偏差在改变评估提示中响应的位置时,常常导致模型做出不同的决策。为了量化这一点,论文不仅测量一致性,还计算不一致样本的百分比以评估位置偏差。
OASST2评估结果 表3提供了在OASST2数据集上的结果。与基线(系统1)大型语言模型相比,思维链(CoT)方法通过提高一致性和降低不一致率来改善性能(参见附录中的提示)。虽然BSM表现优于CoT,但这是以增加推理时间(#To-kens)为代价的。值得注意的是,论文蒸馏的系统2 BSM模型仅需生成四个token,仍然优于CoT和BSM。此外,论文基于Llama-2-70B-chat的蒸馏模型超过了GPT-4-0125-preview,实现了更高的人类一致性和更大的连贯性。
MT-Bench评估结果表3也提供了在MT-bench上的结果,该测试作为分布外测试。结果与OASST2评估的结果相呼应。思维链(CoT)和BSM都提高了模型性能,但代价是显著增加的推理成本。论文的蒸馏BSM模型不仅实现了更高的人类一致性和更低的不一致率,而且需要的计算资源更少。尽管论文的模型在一致性上略逊于最先进的GPT-4-0125-preview模型,但它仅基于Llama-2-70B-chat在OASST2上的未标注数据进行训练。尽管如此,它在连贯性上更优,且在输出token方面推理成本低廉。
图2:MT-bench上LM评判与人类偏好之间的一致性,按评估类别划分
表3:GSM8k测试集准确率。多数投票中的投票数k表示为收集预测答案的投票而采样的候选数量。在这种情况下,系统2的CoT蒸馏效果不佳
按类别分析在此,论文进一步按类别分析MT-Bench结果中的一致性。图2展示了按类别的一致性。论文观察到,与基础模型(Llama-2-70B-Chat)相比,CoT在所有类别上提高了一致性。BSM优于CoT,而论文的蒸馏BSM甚至优于BSM。尽管蒸馏BSM在所有类别上相较于基线取得了优越的性能,但在推理、编码和提取方面仍落后于GPT-4-0125-preview。然而,在写作、数学和STEM方面,它超过了GPT-4-0125-preview。
2.3.5 思维链蒸馏
思维链(CoT)已被证明是提高LLM推理能力的有效方法,例如解决研究生数学问题。LLM生成中间token,这些token是推理(思维)的步骤(链),然后产生最终答案。论文考虑了该方法的两个变体:(i)少样本CoT,即从训练集中提供多个[问题,CoT,答案]示例作为上下文,随后是问题;(ii)零样本,即在提示中除了问题外还添加了“一步一步”思考的明确指令,详见附录图10。
蒸馏数据论文使用CoT为GSM8k训练集中的问题(论文认为这些是无标签的,由Cobbe等人,2021年提出)生成答案,采用K=10的多数投票方法。由此产生的蒸馏训练集包含7461个[问题, 答案]对,即不包含任何中间推理步骤。为了分析目的计算的自监督目标准确率为56.81%。
评估论文在GSM8k测试集上使用不同K值的多数投票方法计算并报告评估准确率。与之前的实验类似,论文报告每种方法预测的平均token数。请注意,论文在进行多数投票时计算所有生成token的平均值,以观察K值的增加如何影响推理成本。论文考虑了几个基线:系统1和系统2(CoT)方法在零样本或8样本输入上下文中进行评估。需要注意的是,系统2在8样本情况下意味着在少量样本输入中提供了CoT,而系统1则意味着少量样本示例包含问题和答案,但没有CoT。
结果评估结果如表3所示。首先,正如预期,使用CoT方法带来了改进:将其作为少样本上下文的一部分或作为提示模板中的指令的一部分时,这种方法有所帮助。这些改进伴随着推理成本的增加:与System 1方法相比,使用CoT方法预测的序列长度显著增加。其次,论文的System 2蒸馏方法在各种解码超参数下表现不佳。GSM8k任务(数学问题)所需的推理类型与论文在此工作中考虑的其他任务截然不同。这突显了System 2蒸馏的非平凡性:所提出的蒸馏算法在许多情况下有效,但并非总是如此。这为未来的研究留下了空间,以阐明在何种具体情况下应用蒸馏,以及何时不应应用,或许可以采用类似于人类的方法。
本文转载自,作者: