Hungry Hungry Hippos: Towards Language Modeling with State Space Models
ICLR2023,notable-top-25%
T Dao, D Y. Fu, K K. Saab, A W. Thomas, A Rudra, C Ré
[Stanford University]
https://arxiv.org/abs/2212.14052
https://github.com/HazyResearch/H3
https://openreview.net/forum?id=COZDy0WYGg
摘要:状态空间模型(SSM)在某些模态中显示了最先进的序列建模性能,但在语言建模中表现不佳。此外,尽管在序列长度上几乎线性缩放,而不是二次缩放,但由于硬件利用率低,SSM仍然比Transformer慢。在本文中,我们在理解语言建模中SSM和注意力之间的表达能力差距以及减少SSM和注意之间的硬件障碍方面取得了进展。首先,我们使用合成语言建模任务来理解SSM和注意力之间的差距。我们发现,现有的SSM在两种能力上存在困难:调用序列中较早的令牌和比较整个序列中的令牌。为了理解对语言建模的影响,我们提出了一个新的SSM层H3,它专门为这些能力设计。在合成语言上,H3与注意立相匹配,并且在OpenWebText上的Transformer的0.4 PPL范围内。此外,保留两个注意力层的H3-attention混合模型(125M参数)出人意料地超过OpenWebText上的Transformer by1.0 PPL。接下来,为了提高在现代硬件上训练SSM的效率,我们提出了FlashConv。FlashConv使用融合块FFT算法来提高up to 8K序列的效率,并引入了一种新的状态传递算法,该算法利用SSM的递归属性来扩展到更长的序列。FlashConv在long-range arena基准测试上的速度提高了2倍,并且允许混合语言模型生成文本的速度比Transformer快2.4倍。使用FlashConv,我们将混合H3注意力语言模型缩放到Pile上的2.7B参数,并找到了有希望的初始结果,实现了比Transformer更低的困惑,在SuperGLUE基准测试中的大多数任务上,在零和少样本学习方面优于Transformer。
1 引言
状态空间模型(SSM)在从时间序列分析[25]到音频生成[22]的领域中实现了最先进的序列建模性能。然而,它们在语言建模方面的表现还没有达到Transformer的水平,常常在困惑中落后Transformer多个点[25]。一个自然的问题是,这种性能差距是否是由于固有的归纳偏置和注意力能力[17,49]所致,或者这是否是花费了大量组织资源来训练和调整大型基于注意力的语言模型[10,32,66],以及专门的注意力硬件支持的结果,范围从张量核[45]到Transformer芯片[34、48]。
我们在本文中采取第一步来回答这些问题。首先,我们使用合成语言建模任务来表明SSM和注意力之间存在表达能力差距。利用我们的见解,我们设计了一个新的SSM层,它几乎与语言建模中的注意力相匹配。第二,我们为SSM提出了更好的硬件感知算法,使其能够利用现代加速器并运行得比注意力更快。
理解表达能力差距。 为了理解SSM和注意力之间的差距,我们借鉴了合成语言建模任务,这些任务在Transformer[49]中被提出作为上下文学习的机制基础。我们发现现有的SSM难以对这些合成语言进行建模。为了探究这些技能对语言建模有多重要,我们提出了H3(Hungry Hungry Hippo),这是一个新的基于SSM的层,旨在解决这些语言建模任务。H3堆叠两个SSM,其输出和输入投影之间具有乘法交互作用(multiplicative interactions)。SSM允许H3保留a log of tokens(以便稍后调用),而乘法交互允许在序列中进行比较。
H3与合成语言的注意度相匹配,几乎填补了与Transformers在语言建模方面的差距,在OpenWebText上的Transformers的0.4困惑范围内(相比之下,现有SSM的3.4 ppl,即使是专门为语言建模设计的SSM[42])。此外,一个保留了两个注意力层的简单混合H3注意力模型令人惊讶地超过OpenWebText上的Transformer1.0困惑。为了进一步评估H3对语言建模的影响,我们使用GPT-3[7]的超参数,在Pile[21]上训练125M-、355M-和1.3B-和2.7B-参数的混合H3注意力语言模型。这些混合模型在困惑程度上优于相同大小的基于Transformer的语言模型,并且在SuperGLUE基准测试中的大多数任务中,在零和少样本学习中与它们匹配或优于它们。由于这些混合模型中的SSM层允许recurrent view,因此它们也可以执行比Transformer快2.4倍的推理。
缩放SSM。 接下来,我们提高了SSM在现代硬件上的效率,以减少注意力和SSM之间的硬件障碍。SSM在序列长度上几乎呈线性扩展,而不是像注意一样呈二次方扩展,但由于硬件利用率低,SSM在现代硬件上的运行速度仍然较慢。为了弥补这一差距,我们提出了FlashConv,这是一种计算SSM的分层算法,灵感来自IO感知注意[15]。技术挑战是SSM需要对输入序列进行基于FFT的卷积,这需要FFT、逐点乘法和逆FFT。当在cuFFT[47]中实现时,此操作会导致昂贵的GPU内存读取/写入,并且无法利用现代硬件上可用的专用矩阵乘法单元为了使用专门的矩阵乘法单元(原文注释),我们求助于将FFT分割成块并使用一系列矩阵乘法进行计算的经典技术。结合核融合,这种“分块”FFT解决方案提高了硬件效率,但前提是序列长度可以适合GPU SRAM(片上存储器,类似于CPU上的L1缓存)-在现代A100上,序列长度可达8K。
(原文注释:An A100 GPU has a maximum of 312 TFLOPs/s of FP16 with tensor cores, but only 20 TFLOPs/s of FP32 (and 40 TFLOPs/s of FP16) without tensor cores [46]. This trend started with the V100 GPUs [45] and has continued with the H100 GPUs [48])。
为了扩展到大于8K的序列,我们提出了一种状态传递算法(图1右侧),专门用于SSM。关键的见解是,只要我们跟踪额外的状态向量,就可以使用SSM的递归属性来以chunks的方式来处理输入。状态传递算法将输入分割成可装入GPU SRAM的最大块,使用分块FFT高效地计算基于FFT的卷积,并更新中间状态以开始下一块。使用这种状态传递算法,FlashConv可以将SSM缩放到任何序列长度,甚至比GPU SRAM一次可以容纳的长度还要长,同时保持近乎线性的计算复杂性。FlashConv使用S4[25]在long range arena上创下了最先进的速度,比Transformer高5.8倍,比之前的S4型号高2倍。FlashConv训练H3的速度是长序列注意力的4-8倍,是扩展到十亿参数模型的关键组件2。
2 背景
我们介绍了状态空间模型和线性注意力的一些背景,这启发了我们的H3层。
2.1 状态空间模型
连续时间状态空间表示[6]定义了从输入信号(作为时间的函数)通过一个状态变量到输出信号的线性映射,对于一些矩阵状态变量、、、,通过以下微分方程:、。
类似地,离散时间状态空间表示定义了从离散输入信号()通过一个状态变量到离散输出信号的线性映射:
状态空间模型(SSM)使用这些表示作为深度学习流水线中的一层,其中矩阵、、、是从数据中学习的(例如,使用基于梯度的优化)。一个通常有个并行的SSM,每个SSM对应一个隐藏维度。为了保存序列历史,HiPPO[24]在正交多项式的基础上投影历史,这意味着具有将、矩阵初始化为某些特殊矩阵的SSM。
SSM的这种循环形式允许有效的推断(即生成):为了生成下一个时间步长的输出,只需要当前时间步长的状态,而不需要整个输入历史。此外,SSM可以自由推断出比训练期间更长的序列。
SSM作为卷积。 为了有效地训练,给定整个输入序列,输出序列也可以写成输入与滤波器的卷积[27]:
也就是说,从初始条件,我们得到,其中表示和之间的线性卷积。如果我们将初始条件设置为零,那么正好是的线性卷积,具有残差连接。更一般地,任何线性时不变系统(其中SSM是一种特例)都可以写成卷积。
给定长度为的1D输入序列,我们将一个SSM(由矩阵、、、参数化)的1D输出序列表示为:
如果上下文清楚,为了简化表示法,我们省略了对、、、的引用,则写。当是维度的多维时,我们将这些SSM的堆叠在一起,使用相同的符号定义了从到的映射。
为了有效地从、、构造滤波器,通常被约束为对角[26、29],或对角加低秩[25]。
SSM通过FFT。 通过传统的矩阵运算简单地计算卷积对于长核(long kernels)来说是昂贵的,缩放为。相反,我们可以使用FFT:对和进行FFT,逐点将它们相乘,然后进行逆FFT。这产生了一个算法。
2.2 线性注意
我们描述了线性注意力[35]及其与RNN的联系,这启发了我们的模型设计(第3节)。
在标准注意[62]中,我们有个查询/键/值标记,, 其中是序列长度,是注意力头的维度。对于某些相似度量, 我们希望计算输出:
对于标准的softmax注意,(通常点积按缩放)。线性注意假设对于某些(非线性)函数,具有的形式。则输出为Oi=φ(Qi)>Pi j=1φ(Kj)V>jφ(Qi)>Pij=1φ(Kj)。设Si=Pij=1φ(Kj)V>j∈Rd×d,zi=Pij=1φ(Kj)∈RD,di=φ(Qi)>zi∈R,则Oi=φ(Qi)>Sidi。这将线性注意力与RNN联系起来:输出Oi是Si和zi的函数,这两个函数都是递增更新的(作为累积和)。
3饥饿饥饿河马层建模离散序列
为了理解SSM和对语言建模的注意之间的差距,我们考察了两个合成语言建模任务。这些任务激励我们的H3层添加离散SSM(基于移位矩阵)和乘法交互,以有效地建模离散序列。然后,我们展示了H3层的表达能力足以解决这些合成任务,并且这种理解可以在真实的语言建模基准上获得更好的性能。
3.1动机:合成语言建模任务
我们描述了两个密切相关的合成任务,总结在表1中。Olsson等人[49]认为,解决这些任务(变体)的能力占Transformer上下文学习能力的大部分,附录E给出了更多直觉。
“诱导头”任务测试模型在特殊令牌(例如表1中的“”)之后如何召回内容。在序列结束时,模型必须调用序列中较早的特殊令牌之后立即出现的令牌。联想回忆[1]类似于诱导头任务,但需要模型记住多个键值对。在序列结束时,模型必须调用属于特定键的特定值。
表2(对于两层模型)显示,S4D[26]和门控状态空间[42]都未能对这些合成语言进行建模,这表明它们可能不具备通用语言的表达能力。我们认为,这些失败表明了两个缺失的功能:(i)记住在特定事件之后出现的令牌(例如,诱导头任务中的特殊令牌),以及(ii)在整个序列中比较令牌(例如比较密钥以决定要调用的值)。注意力具有这两种能力:它可以通过构造二次注意力矩阵QK>来比较令牌,并且可以通过直接复制(将softmax(QK>)乘以V)来召回令牌。在第3.2节中,我们设计了新的H3层,以在SSM中实现这些功能,从而缩小SSM和注意力之间的表现力差距。
3.2 H3层
H3使用带有移位和对角矩阵的SSM,以及针对输入投影的乘法运算,以捕获合成物识别的缺失能力。
高级直觉。
(i) 为了记住过去的标记,我们希望状态xi从输入ui复制,然后将该信息传递给下一个状态xi+1。由于xi+1与xiby Axi相关,我们使用带有移位矩阵a(如下所述)的离散SSM来移位状态向量的元素(例如映射[a,b,c])→ [0,a,b])。(ii)为了在序列中比较令牌,我们使用乘法交互:SSM的输出(包含来自先前时间步骤的信息)与当前时间步骤的输入相乘,从而测量令牌之间的相似性。
H3受到线性注意力的启发(第2节):我们投影输入u以获得三个信号Q、K、V。然后我们用SSM替换非线性φ(K),其中A是移位矩阵(SSMshift),我们用带对角线A的SSM替换求和Si(SSMdiag)。对于头部尺寸dh=1的情况,输出为:
其中o表示逐点乘法。我们可以将这种形式视为叠加两个具有乘法交互作用的SSM(每个SSM都是一只“饥饿的河马”,因此我们的层的名称)。线性注意力、时变系统和H3之间更正式的联系见附录B。
记住关键标记:移位和对角SSM。
移位和对角SSM旨在解决特定事件后记录令牌的能力。在移位SSM中,我们将A∈Rm×m约束为移位矩阵Ai,j=(1代表i−1=j 0否则。此矩阵对隐藏状态xi的作用是将每个坐标下移一,从而创建以前状态的“记忆”。例如,如果B=e1,第一个基向量,则xi=[ui,ui−1,…,ui‐m+1]包含前m个时间步的输入。我们学习B和C(为了简单起见,B也可以固定为e1,在这种情况下,输出是具有内核大小m的1D转换器)。
对角SSM将A约束为对角,并根据HiPPO的对角版本对其进行初始化(S4D[26])。该参数化允许模型记住整个序列的状态。移位SSM可以检测特定事件发生的时间,对角SSM可以在随后的序列的其余部分中记住令牌。
用于比较的乘法交互。
我们从线性注意力中获取乘法交互,但当与移位矩阵结合时,它们提供了另一种缺失的功能:在序列中比较令牌。移位SSM的输出和V投影之间的乘法交互模拟了线性注意中的局部乘法交互(取决于隐藏状态的大小)。类似地,与Q投影和对角SSM输出的乘法交互允许在整个序列上对令牌进行比较。
H3层。
整个层在算法1中给出,并在图1(左侧)中示意性地显示。我们使用H3层以与Transformers相同的方式构建模型,方法是将其与MLP交错,通过剩余连接和层规范(即,预规范架构[2])进行连接。我们还将考虑混合H3注意力模型(两个注意力层,其余为H3,第3.3节和第5节)。
效率
我们表明,H3以O(N log N)的形式缩放,序列长度N比注意力更有效,这通常需要O(N2d)时间和O(N2)空间3(证据见附录D.3)。
提案1。
设N为序列长度,d为隐藏维度,并假设头部维度dh为O(1)阶。然后H3层需要O(d 2N+dN log N)时间和O(dN)空间来计算。
3.3表达能力
我们表明H3可以对我们的合成语言以及OpenWebText上的自然语言进行建模[23]。我们还提出了一种混合H3注意力扩展,其性能优于OpenWebText上的Transformer。
H3.解决联想回忆的机制。
H3的表达能力足以解决我们的合成语言建模任务,如表2所示。图1(中间)显示了单个H3层解决特定键值对(a,3)关联召回任务的机制。移位SSM和随后的乘法交互作为一个门,根据先前的令牌是否为密钥a,是否让一个值通过对角SSM。对角SSM将值3存储在存储器中,并持续输出。最终的乘法交互根据当前输入令牌是否为键a,来选择是否让对角SSM的输出通过。我们正式构建H3层的权重,以解决附录D.1中的任务。
更好的合成语言建模转化为更好的自然语言建模。
我们验证了当H3能够解决这些合成任务时,它还提高了自然语言(例如OpenWebText数据集)的建模能力。如表3所示,当在OpenWebText上训练50B令牌时,H3在Transformer的0.4困惑点内,并且表现比现有SSM变体(S4D、GSS)好3−3.9分。
扩展:H3注意力混合模型。
一个简单的混合H3注意力语言模型令人惊讶地超过OpenWebText上的Transformer1.0个百分点。我们的混合模型只保留了两个自注意层:一个在第二层,另一个在中间(N层模型的层2+N/2,甚至N层)。H3注意力混合型也优于GSS注意力混合型[42]。
4 FlashConv:有效培训SSM
为了提高SSM在现代硬件上的效率,我们提出了FlashConv。FlashConv融合了FFT、逐点乘法和逆FFT,以减少内存读取/写入。它还使用块FFT算法来使用专门的矩阵乘法单元(例如,A100上的张量核),序列长度可达8K。对于长度超过8K的序列,计算不再适用于GPU SRAM4,因此我们提出了一种新的状态传递算法,该算法将序列分割成块,以一次计算一个块的FFT卷积。FlashConv可以加速任何SSM(不仅仅是H3)。
4.1保险丝盒FFTConv
我们部署了两种技术来加快短于8K序列的基于FFT的卷积:核融合和块FFT。内核融合解决了由于读取和写入中间结果导致的IO瓶颈,而块FFT允许基于FFT的卷积利用专门的矩阵乘法单元。这些技术使我们能够将短于8k的序列的FFTConv速度提高2倍(第6节)。
内核融合。
由于重复读取和写入中间结果,使用标准库(如cuFFT)的FFTConv的简单实现是IObound的。具有输入u和滤波器f的SSM中的FFT卷积的形式为iF f T(f f T(u)o f T(f))(其中o表示逐点乘法)。
它需要将中间结果读取和写入GPU内存,而GPU内存可以控制运行时。在FlashAttention[15]之后,我们首先将整个FFTConv融合到单个内核中,并在SRAM中进行计算,以避免这种开销。
块FFT。
为了进一步加快基于FFT的卷积的计算,我们在现代GPU上开发了专门的矩阵乘法硬件(例如,Nvidia GPU上的Tensor核心执行快速16×16矩阵乘法)。我们引用经典的结果,这些结果表明FFT可以写成一系列交错排列的块对角矩阵乘法。我们注意到这样的算法并不新鲜,但我们的设置(GPU上的融合FFTConv)通过消除IO瓶颈引入了新的瓶颈,计算成为瓶颈(注意GPU上的单个FFT通常受IO限制)。
假设我们要执行N点FFT,这相当于乘以DFT矩阵FN。假设对于一些整数N1、N2,N=N1N2。通过DFT的Cooley-Tukey分解[3,11](也称为四步FFT算法),我们可以写FN=P(IN2⊗FN1)P>D(IN1 \8855;FN2)P,其中P表示将输入重新整形为N1×N2阵列然后转置的固定排列,\8855表示Kroneker乘积,D是N×N对角矩阵(称为旋转因子)[14],INi和FNi是尺寸为Ni×Ni的恒等式和DFT矩阵。由于IN2⊗FN1和IN1 \8855;FN2只是块对角矩阵,我们可以使用专门的matmul单元来执行这些乘法。类似地,如果N=N1N2N3,那么我们可以将N点FFT分解为大小为N1、N2和N3的一系列(块)FFT,通过置换进行交织。
块FFT算法对于序列长度N产生O(Nr log N/log r)FLOP,如果N可以被写成两个整数r,p的r p。这比标准FFT(O(N log N))产生更多FLOP,但当我们使用专用的矩阵乘法硬件时,可以运行得更快。
4.2状态通过
然而,如果序列太长,无法装入GPU SRAM(A100上的长度超过8K),则融合内核无法运行。我们展示了如何在SSM中利用FFT的特殊形式来加快长序列的速度。
SSM的递归性质允许我们将长度为N的序列的FFTConv分割成大小为N0的块(N0是我们可以装入SRAM的最长FFT),假设N是N0的倍数)。我们使用FFTConv来计算每个块,并使用递归来连接块。特别地,我们将输入u分成C=N/N0个块u(C)∈R N0,其中C=1,C、 类似地,如果i=1,C、 我们只需要每个块C的结束状态x(C)N0。
设f=[CB,CAB,CA2B,…,CAN0−1B]为SSM滤波器。回想第2节,对于每个块c,y(c)i=CAiBx(c−1)N0+(f*u(c))i+Du(c)i,由于x(c–1)N0,前一块(c−2)的结束状态是当前块c的初始条件。在向量表示法中,对于某些矩阵Mxy∈R N0×m,y(c)=Mxyx(c)N0+f*u。此外,对于某些矩阵Mm×N0 ux(附录c.2中的推导),我们需要使用x c N0=AN0 x(c−1)N0+Muxu(c)更新每个块的结束状态。本质上,只要我们记住前一块的结束状况,我们就可以使用基于FFT的卷积计算每个块的输出,并且每个块的终止状态可以重复更新。这产生了长序列的状态传递算法,其中我们只计算长度为N0的FFT,并在每次迭代时更新一些隐藏状态。
让BlockFFTConv引用我们的融合块FFTConv内核。然后,1D输入的状态传递算法由算法2给出。对于我们堆叠d个SSM的维度d的输入,我们简单地沿着d维度对算法2进行批处理。
我们证明,算法2产生的输出与使用大小为N的大型FFT计算SSM的输出相同(证据见附录D.4):
提案2。
对于输入u∈RN和矩阵A、B、C、D,算法2返回的输出y∈RN与由A、B,C、D参数化的SSM定义的输出相同。
5 H3评估
为了了解如何很好地捕捉第3.1节中的合成词转化为语言建模,我们训练了两个尺寸为125M、355M和1.3B的混合H3注意力语言模型,并评估了它们与Transformer的性能。混合模型在困惑和零/少样本学习方面达到或超过Transformer的质量。我们还验证了H3模型在非文本序列建模上保持了强大的性能。附录F包含更多数据集、长度外推和数据缩放的额外实验。
5.1语言建模
我们将混合H3注意力语言模型与基于Transformer的语言模型进行了比较。我们使用困惑、零样本学习和少样本学习来评估语言建模性能。混合H3模型优于Transformer,这表明缩小SSM和对合成语言的注意之间的差距可以转化为真正的语言建模能力。我们还报告了混合H3模型与Transformer相比的发电速度;由于SSM是循环模型,因此其生成令牌的速度比Transformer快1.6倍。附录F显示了纯H3语言模型在这些相同评估指标上的性能。
安装程序
我们在桩[21]上训练尺寸为125M、355M和1.3B的混合模型,以获得400B代币。我们与HuggingFace[65]的Open-AI[53]和GPT-Noo5[4]的同等大小的检查点进行了比较。
困惑
表4显示了堆[21]、OpenWebText[23]和WikiText-103[43]上的困惑。在桩上,我们的125M混合模型优于同样在桩上训练的GPT Neo。我们的混合模型在向OpenWebText和WikiText103的零样本传输上也优于GPT Neo模型和GPT-2模型。我们报告了GPT-2模型的PPL,但由于它们是在不同的数据上训练的,因此性能无法直接比较。
零炮和少炮性能
我们比较了混合H3语言模型与OPT[66]、GPT Neo和GPT-2模型的零样本和少样本性能,其中公共检查点可用。我们报告了可能选择的逻辑的性能和等级分类(参见附录F.6的原始生成)。表5报告了SuperGLUE基准测试的零炮性能,表6报告了三炮性能。在困惑的结果之后,混合语言模型在超过一半的任务上优于或匹配Transformer基线。
语言建模推理
最后,由于SSM是循环模型,它们承认比Transformer更快的文本生成。表7显示了与Transformer相比,125M参数混合模型的推理吞吐量。混合模型具有1.6倍的吞吐量。
6 FlashConv评估
我们评估FlashConv加速SSM的效果。FlashConv使用S4[25]在远程竞技场基准[59]上设置了最先进的性能。我们报告了使用FlashConv训练H3模块在256到32K的不同序列长度下的性能,并展示了几乎线性的缩放。
远程竞技场
远程竞技场(LRA)基准[59]是远程序列建模的基准。最先进的方法S4[28]是SSM。表8显示FlashConv将S4加速了2倍,比Transformers快了5.8倍。
针对注意力的基准H3
我们用FlashConv测试H3向前和向后传球的时间。FlashConv保持几乎线性的缩放,即使是非常长的序列长度。图2显示了使用我们的技术(块FFT、状态传递)的cuFFT在FFTConv上的总体2-3倍加速。简单的内核融合(即使没有块FFT)可以为短序列产生比cuFFT更快的速度,因为内存读/写是短序列的瓶颈。对于长序列,使用状态传递的SSM可以比最快的注意力实现快几十倍。
7结论
我们的主要目标是在建模能力和硬件效率方面理解并缩小语言建模中注意力和SSM之间的差距。我们基于合成语言任务的探索激励我们设计H3层,这令人惊讶地具有注意度。我们的BlockFFTConv算法利用了矩阵乘法单元和SSM的双循环卷积视图,大大加快了SSM的速度,减少了注意力和SSM之间的硬件障碍。我们对未来的几个方向感到兴奋。我们的H3层是两个SSM的简单组合,更复杂的设计可能更具表现力。我们在多达1.3B个参数的语言模型上取得的令人鼓舞的结果表明,将SSM扩展到更大的规模是一个有前途的途径。由于简单地为H3模型添加两个注意力层已经超过了纯H3模型和Transformer,我们对未来结合SSM和注意力的互补优势持乐观态度。
A相关工作
状态空间模型在建模时序数据方面显示出了前景,包括时间序列数据[25]、音频[22]和视觉数据[44]。我们的模型建立在简化和参数化S4对角线版本的基础上[26,28,29]。门控状态空间[42]也旨在使SSM适应语言建模,但我们的结果表明,GSS模型的性能不如H3(甚至不如S4D等早期SSM)。在混合模型中结合SSM和注意力的想法并不新鲜;Mehta等人[42]还表明,将注意力与GSS层交织可以提高性能,我们也在OpenWebText实验中验证了这一点。这些积极的结果表明,注意力和SSM是互补的,混合模型可能是未来工作的一个有希望的方向。
大型语言基础模型[5]已经证明了将基于注意力的网络扩展到数十亿个参数并在数万亿个令牌上训练它们的能力[32]。了解这些模型背后的机制基础[18]可能会对未来模型的更好设计选择产生见解。这些和类似的探索为H3的设计和合成语言的选择提供了信息。最近的一些工作还探讨了如何通过近似注意力计算来解决注意力不足[9,16,35,39,59,63]。我们相信这些努力是SSM的补充,我们很高兴看到它们如何在未来的工作中结合起来。
线性注意力[35]和RNN等经典序列模型是H3的灵感来源。附录B描绘了线性注意力和LTI系统之间的直接联系。Luo等人[40]还提出了一种线性注意的变体,可以实现序列长度的O(n log n)缩放。附录F评估了语言建模上的线性注意力,发现它不如精确注意力,而H3优于注意力。H3中的乘法交互让人联想到LSTM[31]和GRU[8]中的门控机制,这表明从这些序列模型中获得的架构教训可能有助于将SSM适应语言建模。还提出了许多将注意力缩放到较长序列的算法,如Transformer XL[13]、Reformer[39]、Performer[9]和Perceiver AR[30]。其中一些方法在语言建模方面表现不佳,而且可能在墙上的时钟速度上更慢[15]。对这些替代方案进行彻底的比较,以获得准确的注意,以及它们在模型大小和训练数据量方面的扩展程度,是未来富有成效的工作。
FFT算法应用广泛,包括信号处理[50]、控制理论[6]等。用于计算FFT的各种算法已经存在了几十年[51]。我们希望我们在吸引这些经典算法来加速新应用(如学习的SSM)方面的工作将激发未来的算法探索,即使硬件不是为它们设计的[33]。
B线性注意与时变系统
我们将线性注意力与LTI系统和SSM联系起来。
我们首先将线性注意力呈现为线性时变系统,并展示了将其转换为线性时不变系统如何匹配H3。
线性时变系统与线性注意
通常,序列模型中的层接收序列并输出序列。其中许多采用线性时变系统的形式(得益于皮卡德·林德洛夫定理,非线性系统可以由一系列线性系统近似):
这与SSM(第2节)的形式相同,只是矩阵可以取决于时间步长。
回想第2节中线性注意力的形式。为了近似,我们忽略了线性注意第2节中的分母(即,假设di=1)。我们看到Si只是一个累积和,满足Si+1=Si+φ(Ki+1)VTi+1的递推。类似地,Oi满足递推Oi+1=φ(Qi+1)T Si+1。这是一个形式为xi+1=Axi+Bui+1和yi+1=Ci+1xi+1的线性时变系统(其中a=I,B=I,ui=φ(Ki)V T I,Ci=φ(Qi)T)。也就是说,A和B是常数,但C是时变的。
为了将其转换为线性时不变版本,我们将时变Ci视为后处理步骤。
我们将LTI改为固定C。这产生LTI:
对于学习的一些矩阵A、B、C。然后,我们通过将yi+1与φ(Qi)T相乘来应用后处理。用移位SSM代替φ(Ki)产生与H3类似的结果。
C方法详情
由于我们在第3节中描述了前向传球,因此我们在这里详细描述了后向传球。
C、 1向后传球
我们展示了如何计算融合内核中的反向传递。
设y=f*u+Du。在我们的例子中,f和u具有相同的长度,因此就卷积而言,它们是对称的。
假设我们得到dy=ύlύy(其中l是一些损失函数)。我们希望计算du、df和dD(它们分别为:。
最具挑战性的部分是通过卷积算子计算梯度,但我们将看到我们可以重用FFT基础设施;我们有dD=dyuT。
卷积梯度
在这里,我们将讨论如何通过对卷积算子*的w.r.t进行积分来计算df。作为一个直接的结果,我们也可以计算du。
由于f和u具有相同的长度L,因此f*u和u*f具有相同的结果。因此,我们将从这里的u*f开始。
对于某些符号,设O=u*f。然后,dO=dy。回想一下O[i]=Pi−1j=0u[i−j]f[j]。
我们将首先用零扩展u和f,使其长度为2L。设u0=[u[0],u[1],u[L−1],0,0]和f0类似地扩展。设O0=u0*f0,且O=O0[:N]。假设我们有dO0的所有值(我们只有前半部分的值,但我们会看到这最终无关紧要)。
让我们构造一个Toeplitz矩阵Hu0,使得u 0*f 0=Hu0f 0:
由于对于i≥L,我们有u0[i]=f0[i]=0,因此我们实际上也可以填充上述矩阵的零:
然后,我们可以使用矩阵乘法链规则来发现:
其中我们使用u0[−i]表示u0[2L−i]。请注意,此矩阵的格式与Hu0!设u0*=[u0[0],u0[-1],u0[−(2N−1)]]。然后:
那么我们如何有效地计算u0*?天真地,我们可能会遇到一些严重的内存访问问题。但是DFT的一个好特性拯救了我们!设U[i]为信号U的DFT的第i个元素。注意,U[i]是复数。我们有:
其中*表示复合共轭物。我们可以使用此属性有效地计算df0:
其中F F T*表示取FFT的复共轭,dy0表示dy,用零填充。
计算du
我们可以使用相同的技巧来计算du,除了我们需要添加原始du项的贡献。我们最终得到:
C、 2状态传递矩阵
我们展示了如何在状态传递算法中导出状态更新的多路复用器。我们希望构造一个矩阵vMux∈Rm×N0,使得Muxu=PN0 i=1 AN0−1Bui。注意AiB∈Rd×1是列向量。我们可以简单地堆叠这些列向量以形成矩阵:Mux=[AN0−1B,AN0−2B,…,B]。
D校对
我们展示了H3的参数化和解决联想回忆任务的注意力。我们证明了命题1和命题2。
D、 1 H3表达
本节正式描述了解决关联召回任务的H3参数化。
D、 1.1示例语言∧
考虑一种具有4个键和4个值的简单语言。为了具体,我们将使用键{k1,k2,k3,k4}=LK和值{v1,v2,v3,v4}=LV,即我们的语言L=LKûLV。给定一系列键-值对,末尾有一个键,我们需要一个模型来生成与末尾的键关联的值。假设末尾的键出现在序列中。
更正式地说,让N+1是序列的长度,N偶数。语言∧由序列x∈LN+1组成。每个序列都有一个相关的映射fx:LK→ 低压。对于每个序列,奇数索引从LK中随机采样,对于x1、x3、,xN−1。偶数指数由fx:x2*i=fx(x2*i−1)定义,其中1≤i≤N/2。序列xN+1中的最后一个项是从已经出现在x中的键中随机抽取的,即xN+1∈∈{x1,x3,…,xN−1}。该语言建模任务的目标是在序列末尾生成fx(xN+1)。
D、 1.2求解∧的H3模型
我们描述了一个可以求解∧的玩具H3模型。考虑由嵌入层、H3模型和具有softmax的输出投影组成的模型。回想一下,d是H3模型的维度,m是其隐藏状态的维度,H是头部的数量。设d=8,m=2,H=4。让嵌入层将每个键ki映射到ei基向量,并将每个值vi映射到e4+i基向量。
设Bshif t和Cshif t表示移位SSM的参数,Adiag、Bdiag和Cdiag表示对角SSM的值(设D为零)。设Bshif t=Bdiag=Cdiag=e1。设Cshif t=[01]。设Adiag是对角矩阵,每个H3的对角线上有1。
评论
由Adiag、Bdiag和Cdiag参数化的对角SSM的作用是作为其所有输入的累积和。由Bshif t和Cshif t参数化的移位SSM的作用是将其输入移位一个时间步长。
回想H3层通过应用uWQ、uWK和uWV将其输入映射到Q、K和V。设WQ和WK如下:
回想一下,Q和K被分成H个头部(对于i∈{1,2,3,4},Q(i),K(i)),每个头部被发送到独立的H3。
评论
WQ和WK的作用是将每个键“分配”给不同的H3磁头,即,当xt=ki时,Q(i)t仅为非零。类似地,当xt−1=ki时K(i)t仅为非零时(由于移位SSM的时间延迟,Kt=Kt−1)。
设WV如下:
评论
该矩阵的作用是对输入值进行编码(作为“二进制”),并将其发送给所有H3头。E、 例如,V(1)t=V(2)t=V(3)t=所有i的V(4)t,V(i)t=[0,0]⇔ xt=v1,V(i)t=[0,1]⇔ xt=v2等。
我们声称,对于xN+1=ki,O(i)N+1将是fx(ki)的二进制编码的倍数,并且输出O(j)N+1,1≤j≤4,j 6=i的所有其他头部将为零。假设输出投影WO是这样的,在之后具有非线性,它反转二进制编码以产生期望输出fx(ki)的嵌入。我们将假设这样的投影存在,证据留给读者。
提案3。
上述模型解决了∧语言的联想回忆问题。证据校样草图。WLOG,设xN+1=ki。那么Q(i)=[1,1],但对于j=6i,Q(j)=[0,0]。因此,由于乘法相互作用,对于j6=i,O(j)=[0,1]。
由于Q(i)=[1,1],O(i)是H3磁头中对应于ki的诊断SSM的输出(请记住,每个磁头有两个独立的移位SSM和两个独立诊断SSM)。诊断SSM的输出是他们在序列中看到的所有输入的累积和。
要使一个诊断SSM看到非零输入,其前一个移位SSM必须具有非零输出。序列中只有当xt−1=ki时才会发生这种情况。但随后xt=fx(ki)。因此,诊断SSM的输入精确地是fx(ki)的二进制编码。然后输出O(i)是fx(ki)的二进制编码的倍数,WO将该输出解码为fx(i)的嵌入
D、 2注意力表达
我们提供了一个两层注意力模型的非正式草图,该模型可以解决联想回忆任务,灵感来自[49]的构建。注意力模型的第一层输出序列中先前令牌的嵌入,并将其与序列中的当前令牌连接。第二层将当前令牌与先前的令牌嵌入进行比较,并在存在正好是键值查找的匹配时输出成对嵌入。
施工过程如下:
•在第一层中,将Qi映射到标记xi−1的位置嵌入(例如,如果pi表示标记xi的位置嵌入,则为pi−1),将Ki映射到标记xi的位置嵌入注意力矩阵A被计算为QKT,具有因果掩码(即,如果j>i,则Ai,j=0)。
•然后,softmax(A)近似于移位矩阵(见第3节)。
•让Vi是标记xi的编码,限制在隐藏维度的前半部分。
•然后,对于输出O=softmax(QKT)V,矢量Oi的前半部分是标记xi−1的编码。
•在第二层中,假设您有一个跳过连接,它将输入标记xi的编码映射到向量Oi的后半部分然后,第二层的输入对xi−1和xi进行编码在第二层中,让Qi提取xi的编码,让Ki提取xi−1的编码。
•在QKT上涂上因果面具。然后,如果xi=xj−1,且i>j−1.,则softmax(QKT)i,j的值较大。
•让Vi提取xi的编码。•然后,输出Oi是值xj的总和,例如xj−1=xi。
但Oi恰恰是对xi之后出现的标记的查找,而它之前出现在序列中,正好解决了联想回忆。
我们注意到,上述构造要求位置编码能够基于点积和softmax选择先前的令牌,并通过点积和softmax进行令牌比较。
D、 3 H3复杂性
我们证明了命题1,该命题指出H3层对于序列长度N和隐藏维度d需要O(d2N+dN log N)时间和O(dN)空间。
证据
我们首先分析时间复杂性。考虑H3中的矩阵乘法,其中输入u∈R N×d乘以三个大小为d×d的权重矩阵。这需要时间O(d 2N)。输出O还与大小为d×d的输出投影权重矩阵相乘,同样需要时间O(d 2N)。因此,矩阵乘法需要时间O(d 2N)。
现在考虑H3中的两个SSM。第一个SSM涉及K∈R N×d(在N维)与大小为N×d的核的卷积。这简化为FFT、逐点乘法和逆FFT(在N维度)。这需要时间O(dN log N)。第二个SSM涉及H个卷积,沿N维输入大小为N×dh×dh。这需要时间:
其中我们使用dh=d/H和dh=O(1)的事实。因此,两个SSM需要总时间O(dN log N)。因此,H3层需要时间:
现在我们分析空间复杂性。矩阵乘以所有取空间O(dN)。两个SSM的FFT、逐点乘法和逆FFT取O(dN)空间,O(Hd2-hN)=O(ddhN)=O(dN)空间。
因此,总体空间复杂度为O(dN)。
D、 4状态传递正确性
我们证明了命题2。我们假设BlockFFTConv算法是正确的,即输出y=BlockFFTCon v(f,u)等于具有卷积核f和输入u的SSM的输出。证明。C上的归纳证明
基本情况:
C=1。WTS y=[y(1)]、Mxxx(0)N0+Muxu(1)=xN。在这种情况下,请注意N=N0。那么y(1)=Mxyx(0)N0+BlockFFTConv(f,u1)=BlockFFTCon(f,u1)。但是u=u1,所以y=y(1)=[y(1(1)]。
另外通过状态空间的递归定义,
感应步骤:
C>1。假设[y(1),…,y(C−1)]=y[:N0(C−2)],x(C−3)N0=x(C–1)N0。WTS为y(C)=y[N0(C−1):N0C],Mxxx(C−2)N0+Muxu(C)=xN。设t表示N0(C−1)。对于i>(C−1)N0,我们有:
因此,y(C)=y[N0(C−1):N0C]。类似地,
E实验细节
E、 1合成物
我们的合成任务受到[49]的启发,旨在模仿大型语言模型的上下文学习能力,即从输入序列中的示例学习的能力,并将输入中的信息用于
为输出生成正确答案。例如,诱导头任务需要记忆输入序列中出现在特殊“标记”之后的标记,关联回忆任务需要学习输入序列中从键到标记的映射。
我们通过训练GPT模型的两层版本来评估合成,不同的模块取代了注意力。我们训练具有内部维度32和MLP维度128的模型。对于所有合成物,我们使用5e-4的学习率和0.1的权重衰减。我们从同一分布中抽取了5000个训练示例和500个测试示例,我们训练了200个时期。同样,我们使用0.1的嵌入丢失和0.0的残余丢失。
E、 2模型架构
对于125M模型,我们使用了12层,隐藏维度为1024,MLP维度为4096。对于355M型号,我们使用24层,具有相同的隐藏维度和MLP维度。对于混合模型,我们使用16个注意力头部作为注意力层。125M混合模型在第1层和第7层具有注意层,而355M混合模型则在第1和第13层有注意层。对于我们的混合型和H3型,我们使用SSM状态大小64。我们的混合模型使用H3的头部尺寸1,而我们的纯H3模型使用头部尺寸8。我们使用混合精度训练运行模型,bf16用于MLP和注意力。在训练语言模型时,我们将fp32用于FFTConv。
E、 3 OpenWebText培训
对于在OpenWebText上训练的125M模型,我们遵循Megatron LM repo的训练配方。
我们使用512的有效批处理大小,并使用梯度累积来适应可用的GPU内存。我们使用AdamW优化器,GPT-2小样本的学习率为6e-4,而GPT-2中等样本的学习速度为1.5e-4,权重衰减为0.1。所有模型都用相同的超参数训练100K步。我们使用混合精度训练(PyTorch AMP)运行所有实现。我们训练序列长度为1024的模型。
我们使用Openwebtext数据集和GPT-2 BPE标记器。我们随机选择数据集的0.5%作为验证集,其余的用作训练集。验证集的随机选择只进行一次,所有模型都在同一个验证集上进行评估。
E、 4桩训练
对于在桩上训练的125M和355M模型,我们遵循GPT-3的训练配方。我们使用批量大小256,序列长度2048。我们为800K步数训练模型。我们使用残差丢弃0.0和嵌入丢弃0.1。我们使用AdamW优化器,125M模型的学习率为6e-4,355M模型的为3e-4,权重衰减为0.1。我们使用具有8000步的余弦计划进行线性预热,并通过300B令牌将学习率衰减到10%,然后以10%的学习率继续训练另一个100B令牌。我们怀疑H3语言模型存在更好的超参数,但我们没有资源来调整它们。
对于1.3B模型,我们将批量大小加倍到512(序列长度为2048),同样遵循GPT-3的训练配方。训练步骤的数量减半,以便我们使用相同数量的令牌进行训练。
对于Pile数据集,我们再次使用GPT-2 BPE标记器,类似于GPT-3和GPT-Neo。
E、 5超级胶水
我们遵循GPT-3论文中使用的提示[7]。对于二进制分类任务的秩分类,我们对WSC、WIC、MultiRC和BoolQ使用yes/no,对RTE使用true/false。对于CB,我们使用true/false/nother作为三个选项。对于COPA和ReCoRD,我们使用任务提供的延续。
E、 6硬件
所有模型都在单个16xA100-40GB节点或8xA100-80GB节点集群上训练。
F附加实验
F、 1维基文本103
我们在WikiText103[43]上训练125M大小的模型,并将其测试PPL与Transformer以及其他有效或长期注意的变体进行比较。我们使用与OpenWebText训练相同的超参数和设置。我们还为上下文提供了Transformer XL和Perceiver AR的结果,尽管由于模型大小和标记器的差异,结果可能无法直接比较
表9显示,混合H3型号与相同尺寸的Transformer以及358M Perceiver AR和285M Transformer XL等更大型号相比具有竞争力。混合H3模型在性能、改革者和线性注意力方面也显著优于Transformer。
我们注意到Transformer XL和Perceiver AR PPl数字来自原始论文,可能无法直接与我们的结果进行比较。特别是,他们使用了不同语音大小的标记器,这意味着PPL不能直接进行比较。此外,较大的人声大小需要改变可能影响性能的模型(自适应softmax)。表9中排名前五的数字采用相同的设置进行训练,并且可以直接相互比较。
F、 第19页第2页
我们评估了在PG-19数据集[54]上训练的模型,PG-19是一个由书籍文本组成的自然语言数据集。我们比较了混合H3与Transformer和线性注意力的性能。我们使用与OpenWebText评估相同的设置。
F、 3长度外推
SSM的一个特性是,它们可以自然推断出比训练期间看到的序列长度更长的序列长度。我们使用合成联想回忆任务来证明H3保持这种能力。为此,我们在从联想回忆合成语言中提取的长度为20的序列上训练两层H3模型。然后,我们评估长度为20和40的序列上最后一个令牌预测的准确性。
F、 4代币数量的缩放
与Transformer相比,我们评估了混合H3模型与训练期间看到的令牌数量的比例。对于这些实验,我们为5B、10B和15B代币在桩上训练了125M混合H3模型和125MTransformer。我们使用6e-4的学习率,将热身调整为总训练时间的1%,并调整衰减率以在训练结束时将学习率衰减至6e-5。
表12显示了结果。混合H3模型和Transformer模型都随着训练令牌数量的增加而改进。
F、 5 H3语言模型
我们报告了纯H3语言模型对NLP评估的结果。我们在Pile上训练了一个125M模型,用于400B代币。表13和表14分别显示了SuperGLUE的零射和少射性能。
F、 6发电性能
我们报告了生成SuperGLUE的结果。我们不采用等级分类,而是让模型生成响应,并在输出中搜索金色标签(即,“是”或“否”表示是/否问题)。表15和16报告了结果。少样本学习的趋势与logit结果相匹配,但混合和H3模型在某些任务上的零样本表现非常差。在这些情况下,模型倾向于生成与答案无关的长文本响应。几个样本学习示例帮助模型以可解析的格式生成答案。
F、 7非文本序列建模
我们表明,H3在两个非文本序列建模任务上优于Transformers:原始语音分类和癫痫发作分类。H3在癫痫分类方面具有最先进的性能,并在语音分类方面与S4相匹配,这表明H3或其混合体可能是多模态基础模型的有力候选。附录E给出了实验细节,附录F给出了大脑fMRI数据的附加实验。
脑电图癫痫发作分类
癫痫是最常见的神经疾病之一,其特征是大脑活动失控[20]。慢性癫痫或癫痫会导致一系列精神和心理社会障碍,并影响全球约1%人口的生活[37]。治疗癫痫的第一步是由董事会认证的神经学家手动分析头皮脑电图。然而,每个患者产生的大量EEG数据(可能长达数天的数据)使得人工EEG分析成为一个昂贵且耗时的过程。
为了降低与EEG监测相关的成本,最近的深度学习技术开始显示出在标记潜在癫痫事件的异常EEG片段方面的前景[57]。对EEG数据进行分类的一个挑战是在增加输入序列长度和在长序列上训练深度学习模型的难度增加(例如,以200Hz采样的EEG信号每分钟产生12000个时间步长)之间进行权衡,其中,更多的上下文已被证明可以提高癫痫分类性能[55]。因此,许多技术涉及领域专用模型和预处理步骤,如FFT变换和图形表示[58]。
我们使用了最大的公开可用EEG语料库TUSZ v1.5.2[56],其中包括636名患者的5612个EEG信号,以及3050个带注释的癫痫发作。信号被分割成60秒的片段,并由患者分成训练/评估/测试。训练集包含39765个剪辑,val集包含4351个剪辑,测试集包含10001个剪辑
我们在TUSZ v1.5.2[56]语料库上评估了以200Hz采样的60秒EEG片段的二值发作分类,其中19个电极:x∈R12000×19和y∈{0,1}。Transformer无法在GPU内存耗尽的情况下处理长序列的EEG信号,而H3可以并设置最先进的性能。
原始语音分类
SC10语音命令任务[64]包含长度为1秒的原始音频信号,采样频率为16kHz。与EEG信号类似,Transformer无法处理长序列长度。表18显示,H3在S4的半个点以内,这是最先进的方法。
功能磁共振成像数据
功能磁共振成像(fMRI)数据通常以四个维度表示,表示在三维体积V∈R x×y×z的时间序列S={V1,…,Vt}中测量的血氧水平依赖(BOLD)信号,每个都表示大脑所有空间位置的BOLD信号(由三个空间维度x,y和z定义)。fMRI数据分析的一个关键挑战是其数据集的高维度和低样本量,其通常包含数十到数百个序列S中数百个体积V中的每一个的数十万维(即体素),标准的机器学习方法容易过度拟合
尽管单个数据集的样本量较低,但神经成像研究可以被视为最近进入了大数据时代,因为研究人员更频繁地公开分享他们收集的数据集[41]。这些数据的可用性为大规模神经成像的预训练提供了机会,如最近的[60]所示,使模型能够利用从公共功能神经成像数据中学习到的知识来分析单个数据集。具体而言,[60]通过首先在广泛的fMRI数据集上预训练模型来评估功能性神经成像数据的几个自我监督学习框架的性能,该数据集跨越了11980个fMRI数据,来自1726个个体,跨越了34个数据集,随后将预训练模型调整为两个下游的精神状态解码数据集(即HCP[61]和MDTB[38]数据集)。在心理状态解码中,预测模型的任务是从测量的大脑活动中识别(即解码)一些心理状态(例如,回答关于散文故事或数学问题的问题)。作者发现,在因果学习框架中预先训练的基于GPT的模型在解码两个下游数据集的20(HCP)和26(MDTB)心理状态方面表现最好。
为了评估H3在fMRI数据上的表现,我们使用[60]发布的上下游fMRI数据集复制了这一分析,将H3作为GPT模型的替代品。为了缓解fMRI数据的高维度挑战,并且由于大脑活动的空间相关性通常很高,原作者将体积时间序列S缩减为一组Θ∈θ1。。。,n=1024个功能独立的脑网络θ的θn(如功能模式字典(DiFuMo)脑图谱[12]所定义),每个都描述了体素vx,y,z∈V的某个子集的BOLD信号,使得生成的序列X∈R t×n描述了每个脑网络n在时间点t的活动模式。
根据[60],我们首先对模型f(·)进行预训练,以预测输入序列X的下一个时间点j的大脑活动分布,使用平均绝对误差(Lrec)训练目标,给定模型的输出Xû∈R t×n:Lrec=1 n Pn i=1|Xj,i−Xûj,i|;XÜt,n=bn+P n f(EX)t,母羊,n;EX t,e=ET R+Epos+be+P n Xt,nwn,e。这里,ET R∈R e和Epos∈R e表示输入序列的每个可能时间点和位置的可学习嵌入(有关详细信息,请参见[60])。由于fMRI的采样频率是可变的,输入序列的相同位置可以对应于不同的时间点。注意,f(·)在低维嵌入表示EX∈R t×e(e=768维)中处理输入。
我们评估了f(·)的两种模型架构,即[60]中使用的GPT变体,具有4个隐藏层和12个注意力头部,以及相应的H3变体,具有四个隐藏层(H=64,m=1;参见第3节)。对于这两个模型,使用最后一个模型层的隐藏状态输出序列来确定Xû。
正如[60],我们通过随机指定34个上游数据集中每一个的fMRI运行的5%作为验证数据(每个数据集至少运行2次),并将其余运行用于训练,将上游数据随机划分为不同的训练和验证数据集。在上游学习过程中,我们从fMRI运行中随机抽取10到100个时间点的序列,并使用ADAM优化器(β1=0.9,β2=0.999,c=1e−8)训练模型5000个步骤,最小批量为512,学习率为5e−4。我们应用线性学习率衰减计划(热身阶段为训练步骤总数的1%)、1.0的梯度范数限幅和L2正则化(权重为0.1)。我们还对基于GPT的模型(基于[60])应用0.1的辍学率,并评估H3的三个辍学率:0.1、0.2和0.3。
我们发现,在平均绝对误差方面,用0.2脱落训练的H3变体与GPT模型表现相同(图3),因此继续使用该模型变体进行所有进一步分析。我们还发现,这两个模型在整个大脑中显示出几乎相同的Lrec误差分布,在顶叶后部、枕部和扣带回皮质以及边缘系统的部分中误差相对较高(图4)。
为了使预训练的模型适应心理状态解码,我们将嵌入Ecls∈Rn的可学习分类添加到输入序列X的末尾,并将模型的预测f(EX)转发到解码头p(·),包括一个具有e个模型单元的密集隐藏层(每个嵌入维度一个,具有tanh激活),以及一个具有一个模型单元i的softmax输出层,用于数据中每个考虑的心理状态。因此,我们通过优化标准交叉熵损失目标来调整模型:Lcls=−P i yi log P(f(EX))i,其中yi表示二进制变量,如果i是正确的精神状态,则为1,否则为0。
在下游适应期间,我们开始使用各自的预训练模型参数进行训练,然后允许所有参数自由改变。类似于[60],我们将两个下游数据集中的每一个随机分成不同的训练和测试数据集,每个数据集包括40个(HCP)或10个(MDTB)不同的个体。我们以256的小批量大小和5e−5的学习率(否则使用与上游训练相同的学习参数)调整750个步骤的模型。重要的是,我们使用不同的随机种子重复每个下游训练运行20次,导致数据的不同随机分割和其他非确定性训练因素的可变性(例如随机初始化和数据洗牌)。
对于上游数据,基于H3和GPT的模型在两个遗漏的测试数据集中的精神状态解码性能通常表现相同(表19)。