摘要
大型Transformer模型在许多任务上通常能够取得最先进的结果,但训练这些模型的成本可能非常高,尤其是在处理长序列时。我们引入了两种技术来提高Transformer的效率。首先,我们使用局部敏感哈希(Locality-Sensitive Hashing, LSH)替代点积注意力机制,将其复杂度从O(L²)降低到O(L log L),其中L是序列长度。此外,我们使用可逆残差层替代标准残差层,这使得在训练过程中只需存储一次激活值,而不是N次(N为层数)。由此产生的模型——Reformer,在性能上与Transformer模型相当,但在长序列上具有更高的内存效率和处理速度。
1 引言
Transformer架构(Vaswani et al., 2017)被广泛应用于自然语言处理领域,并在许多任务中取得了最先进的结果。为了获得这些结果,研究人员不断训练更大的Transformer模型。在最大配置中,每层的参数数量超过0.5B(Shazeer et al., 2018),而层数则高达64层(Al-Rfou et al., 2018)。Transformer模型也被用于处理越来越长的序列。例如,在(Liu et al., 2018)中,单个样本处理的文本长度高达11,000个标记,而在处理其他模态(如音乐(Huang et al., 2018)和图像(Parmar et al., 2018))时,更长的序列更是司空见惯。这些大规模长序列模型虽然取得了优异的结果,但也对资源造成了巨大压力,以至于有人认为这种趋势正在破坏NLP研究¹。许多大型Transformer模型只能在大型工业研究实验室中训练,而使用模型并行训练的模型甚至无法在单个GPU上进行微调,因为它们的内存需求需要多加速器硬件设置,即使只进行单步训练。
大型Transformer模型是否从根本上需要如此庞大的资源,还是它们只是效率低下?考虑以下计算:最大Transformer层中使用的0.5B参数占用了2GB内存。对于64K个标记、嵌入大小为1024、批量大小为8的激活值,占用了64K × 1K × 8 = 0.5B个浮点数,需要另外2GB内存。如果我们的内存使用仅按层计算,那么我们完全可以轻松地在单个加速器上处理长度为64K的序列。此外,用于训练BERT的整个语料库仅需要17GB存储空间。那么,为什么我们甚至无法在单台机器上微调这些模型?
上述估计仅包括每层内存和输入激活值的成本,并未考虑Transformer中以下主要内存使用来源:
- 在具有N层的模型中,由于需要存储激活值以进行反向传播,内存使用是单层模型的N倍。
- 由于中间前馈层的深度dff通常远大于注意力激活值的深度dmodel,因此它占据了内存使用的大部分。
- 对于长度为L的序列,注意力的计算和内存复杂度均为O(L²),因此即使是单个64K标记的序列也可能耗尽加速器内存。
我们引入了Reformer模型,通过以下技术解决了这些问题:
- 可逆层(首次由Gomez et al., 2017提出)使得在整个模型中只需存储一份激活值,因此N倍的因子消失了。
- 在前馈层内分割激活值并分块处理,消除了dff因子,节省了前馈层内的内存。
- 基于局部敏感哈希的近似注意力计算,将注意力层中的O(L²)因子替换为O(L log L),从而允许处理长序列。
我们研究了这些技术,并表明它们对训练过程的影响可以忽略不计。分割激活值实际上只影响实现方式;它在数值上与Transformer中使用的层完全相同。使用可逆残差层替代标准残差层确实会改变模型,但在我们实验的所有配置中对训练的影响微乎其微。最后,注意力中的局部敏感哈希是一个较大的改变,可能会影响训练动态,具体取决于使用的并发哈希数量。我们研究了这一参数,并找到了一个既高效又能产生接近完整注意力结果的值。
我们在合成任务、文本任务(enwik8,序列长度为64K)和图像生成任务(imagenet-64生成,序列长度为12K)上进行了实验。在这两种情况下,Reformer都达到了与完整Transformer相当的结果,但运行速度更快,尤其是在文本任务上,并且内存效率提高了数个数量级。
2 局部敏感哈希注意力
点积注意力。Transformer中使用的标准注意力是缩放点积注意力(Vaswani et al., 2017)。输入包括维度为dk的查询和键,以及维度为
d
v
d_v
dv的值。查询与所有键的点积被计算出来,并除以
d
k
\sqrt{d_{k}}
dk进行缩放,然后应用softmax函数以获得值的权重。在实践中,注意力函数在一组查询上同时计算,这些查询被打包成一个矩阵Q。假设键和值也被打包成矩阵K和V,则输出矩阵定义为:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
(
1
)
\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V\qquad{(1)}
Attention(Q,K,V)=softmax(dkQKT)V(1)
多头注意力机制
在Transformer中,并不是使用单一的注意力函数来处理维度为dmodel的键、值和查询,而是将查询、键和值分别通过不同的、可学习的线性投影进行h次线性投影,投影到 d k 、 d k d_k、d_k dk、dk和 d v d_v dv维度。然后,对这些投影后的查询、键和值并行应用注意力机制,生成维度为dv的输出值。这些输出值被拼接起来并再次投影,最终得到结果。这种机制被称为多头注意力机制。
内存高效的注意力机制
为了计算注意力机制的内存使用情况,我们重点关注公式1中的注意力计算部分。假设Q、K和V的形状均为[batch size; length; dmodel]。主要问题在于QKᵀ项,其形状为[batch size; length; length]。在实验部分,我们在长度为64K的序列上训练模型——在这种情况下,即使批量大小为1,这也是一个64K × 64K的矩阵,以32位浮点数计算,将占用16GB内存。这是不切实际的,并且阻碍了Transformer在长序列上的应用。但需要注意的是, Q K T QKᵀ QKT矩阵并不需要完全存储在内存中。实际上,可以为每个查询 q i q_i qi单独计算注意力,只需在内存中计算一次 s o f t m a x ( q i K T d k ) V softmax(\frac{q_{i}K^{T}}{\sqrt{d_{k}}})V softmax(dkqiKT)V,然后在反向传播时根据需要重新计算梯度。这种计算注意力的方式可能效率较低,但它仅使用与序列长度成比例的内存。我们使用这种内存高效的注意力实现来运行实验部分中的完整注意力基线。
Q、K、V的来源
上述多头注意力机制操作的是键、查询和值,但通常我们只有一个形状为[batch size; length; dmodel]的激活张量A—例如,将句子中的标记嵌入为向量后得到的张量。
为了从A构建Q、K和V,Transformer使用了3个不同的线性层,将A分别投影到Q、K和V,每个线性层具有不同的参数。对于使用LSH(局部敏感哈希)注意力的模型,我们希望查询和键(Q和K)是相同的。这可以通过使用相同的线性层将A投影到Q和K,并使用另一个单独的线性层投影到V来实现。我们将这种模型称为共享QK的Transformer。实验表明,共享QK不会影响Transformer的性能,即使我们额外对键K的长度进行归一化,如实验部分第5节所示。
哈希注意力
对于LSH注意力,我们从两个张量开始:Q=K和V,形状为[batch size; length; dmodel]。我们保留了多头机制,并专注于公式1中的注意力计算。如前所述,主要问题在于 Q K T QKᵀ QKT项,其形状为[batch size; length; length]。但需要注意的是,我们实际上只对softmax(QKᵀ)感兴趣。由于softmax主要由最大的元素主导,因此对于每个查询qi,我们只需要关注K中与qi最接近的键。例如,如果K的长度为64K,对于每个qi,我们可能只需要考虑一小部分键,比如32或64个最接近的键。这样效率会高得多,但我们如何快速找到键中的最近邻呢?
局部敏感哈希
在高维空间中快速找到最近邻的问题可以通过局部敏感哈希(LSH)来解决。如果一种哈希方案将每个向量x分配到一个哈希值h(x),并且附近的向量大概率会得到相同的哈希值,而远处的向量则不会,那么这种哈希方案被称为局部敏感的。在我们的场景中,实际上只需要满足以下两点:1)附近的向量大概率会得到相同的哈希值;2)哈希桶的大小大概率相似。
我们通过以下随机投影方法实现这一点(见图[1])。为了得到 b b b个哈希值,我们首先固定一个大小为 [ d k , b ^ / 2 ] [d_k, \hat{b}/2] [dk,b^/2]的随机矩阵 R ~ \tilde{R} R~。然后定义 h ( x ) = arg max ( [ x R ~ ; − x R ] ) h(x) = \arg\max([x\tilde{R}; -xR]) h(x)=argmax([xR~;−xR]),其中 [ u ; v ] [u; v] [u;v]表示两个向量的拼接。这种方法是一种已知的LSH方案(Andoni et al. 2015),易于实现并适用于批量向量。
LSH注意力
在了解了我们的LSH方案和哈希注意力的基本思想后,我们现在将形式化本文中使用的LSH注意力。我们首先重写普通注意力的公式(1),每次针对单个查询位置
i
i
i:
o
i
=
∑
j
∈
P
i
exp
(
q
i
⋅
k
j
−
z
(
i
,
P
i
)
)
v
j
其中
P
i
=
{
j
:
i
≥
j
}
(
2
)
o_{i}=\sum_{j\in\mathcal{P}_{i}}\exp\left(q_{i}\cdot k_{j}-z(i,\mathcal{P}_{i})\right)v_{j}\quad\mathrm{其中~}\mathcal{P}_{i}=\{j:i\geq j\}\qquad{(2)}
oi=j∈Pi∑exp(qi⋅kj−z(i,Pi))vj其中 Pi={j:i≥j}(2)
我们引入符号
P
i
\mathcal{P}_i
Pi来表示位置
i
i
i的查询所关注的集合,
z
z
z表示分区函数(即softmax中的归一化项)。为了清晰起见,我们还省略了除以
d
k
\sqrt{d_k}
dk的缩放操作。
为了支持批处理,我们通常会在一个更大的集合
P
~
i
=
{
0
,
1
,
…
,
l
}
⊇
P
i
\widetilde{\mathcal{P}}_i=\{0,1,\ldots,l\}\supseteq\mathcal{P}_i
P
i={0,1,…,l}⊇Pi上执行注意力计算,同时屏蔽掉不在
P
i
\mathcal{P}_i
Pi中的元素:
o
i
=
∑
j
∈
P
~
i
exp
(
q
i
⋅
k
j
−
m
(
j
,
P
i
)
−
z
(
i
,
P
i
)
)
v
j
其中
m
(
j
,
P
i
)
=
{
∞
如果
j
∉
P
i
0
否则
(
3
)
o_i=\sum_{j\in\tilde{\mathcal{P}}_i}\exp\left(q_i\cdot k_j-m(j,\mathcal{P}_i)-z(i,\mathcal{P}_i)\right)v_j\quad\text{其中}\:m(j,\mathcal{P}_i)=\begin{cases}\infty&\text{如果}\:j\notin\mathcal{P}_i\\0&\text{否则}\end{cases}\qquad{(3)}
oi=j∈P~i∑exp(qi⋅kj−m(j,Pi)−z(i,Pi))vj其中m(j,Pi)={∞0如果j∈/Pi否则(3)
现在我们转向LSH注意力机制,可以将其理解为通过仅允许在单个哈希桶内进行注意力计算,来限制查询位置
i
i
i可以关注的目标项集合
P
i
\mathcal{P}_i
Pi。
P
i
=
{
j
:
h
(
q
i
)
=
h
(
k
j
)
}
(
4
)
\mathcal{P}_i=\{j:h(q_i)=h(k_j)\}\qquad{(4)}
Pi={j:h(qi)=h(kj)}(4)
图2(a-b)展示了完整注意力机制与哈希变体的对比示意图。图(a)描绘了完整注意力机制的注意力矩阵通常是稀疏的,但计算过程并未利用这种稀疏性。在图(b)中,查询和键已根据其哈希桶进行了排序。由于相似的项大概率会落在同一个桶中,因此可以通过仅允许在每个桶内进行注意力计算来近似完整的注意力模式。
在这种形式化中,哈希桶的大小往往不均匀,这使得跨桶的批处理变得困难。此外,桶中的查询数量和键数量可能不相等——事实上,一个桶可能包含许多查询但没有键。为了缓解这些问题,我们首先通过设置
k
j
=
q
j
∥
q
j
∥
k_j = \frac{q_j}{\|q_j\|}
kj=∥qj∥qj来确保
h
(
k
j
)
=
h
(
q
j
)
h(k_j) = h(q_j)
h(kj)=h(qj)。接着,我们按桶号对查询进行排序,并在每个桶内按序列位置排序;这定义了一种排序后的排列
i
↦
s
i
i \mapsto s_i
i↦si。在排序后的注意力矩阵中,来自同一桶的对会聚集在对角线附近(如图2c所示)。我们可以采用一种批处理方法,其中连续的
m
m
m个查询(排序后)相互关注,并关注前一个块(如图2d所示)。根据我们之前的符号表示,这对应于设置:
P
~
i
=
{
j
:
⌊
s
i
m
⌋
−
1
≤
⌊
s
j
m
⌋
≤
⌊
s
i
m
⌋
}
(
5
)
\widetilde{\mathcal{P}}_i=\left\{j:\left\lfloor\frac{s_i}{m}\right\rfloor-1\leq\left\lfloor\frac{s_j}{m}\right\rfloor\leq\left\lfloor\frac{s_i}{m}\right\rfloor\right\}\qquad{(5)}
P
i={j:⌊msi⌋−1≤⌊msj⌋≤⌊msi⌋}(5)
如果
f
max
i
∣
P
i
∣
<
m
\operatorname*{f}\operatorname*{max}_i|\mathcal{P}_i|<m
fmaxi∣Pi∣<m,则
P
i
⊆
P
~
i
\mathcal{P}_i\subseteq\widetilde{\mathcal{P}}_i
Pi⊆P
i。在实践中,我们设置
m
=
2
l
n
b
u
c
k
e
t
s
m=\frac{2l}{n_{buckets}}
m=nbuckets2l(其中
l
l
l是序列长度)。平均桶大小为
l
n
b
u
c
k
e
t
s
\frac{l}{n_{buckets}}
nbucketsl,并且我们假设一个桶增长到两倍大小的概率足够低。LSH注意力的整体过程总结在图[2]中。
多轮LSH注意力
在使用哈希时,总存在一个较小的概率,即相似的项仍然会落入不同的桶中。这种概率可以通过使用
n
r
o
u
n
d
s
n_{rounds}
nrounds个不同的哈希函数
{
h
(
1
)
,
h
(
2
)
,
…
}
\{h^{(1)}, h^{(2)}, \ldots\}
{h(1),h(2),…}进行多轮哈希来降低,具体如下:
P
i
=
⋃
r
=
1
n
r
o
u
n
d
s
P
i
(
r
)
其中
P
i
(
r
)
=
{
j
:
h
(
r
)
(
q
i
)
=
h
(
r
)
(
q
j
)
}
(
6
)
\mathcal{P}_i=\bigcup\limits_{r=1}^{n_{rounds}}\mathcal{P}_i^{(r)}\quad\text{其中}\mathcal{P}_i^{(r)}=\left\{j:h^{(r)}(q_i)=h^{(r)}(q_j)\right\}\qquad{(6)}
Pi=r=1⋃nroundsPi(r)其中Pi(r)={j:h(r)(qi)=h(r)(qj)}(6)
多轮情况本质上涉及并行执行
n
r
o
u
n
d
s
n_{rounds}
nrounds次LSH注意力;该过程的细节在附录A1中描述。
共享QK注意力的因果掩码
在Transformer解码器中,掩码(在公式3中用
m
(
j
,
P
i
)
m(j,\mathcal{P}_i)
m(j,Pi)表示)用于防止位置关注未来的信息。为了在LSH注意力中实现掩码,我们将每个查询/键向量与一个位置索引关联,使用与排序查询/键向量相同的排列重新排序位置索引,然后使用比较操作来计算掩码。
虽然不允许关注未来的信息,但Transformer的典型实现确实允许一个位置关注自身。这种行为在共享QK的设定中是不可取的,因为查询向量与自身的点积几乎总是大于查询向量与其他位置向量的点积。因此,我们修改了掩码机制,禁止一个标记关注自身,除非该标记没有其他有效的关注目标(例如,序列中的第一个标记)。
2.1 合成任务分析
为了验证LSH注意力的性能并研究其行为,我们从以下合成任务开始:复制符号序列。在这个任务中,每个训练和测试样本的形式为0w0w,其中w ∈ {1, …, N}* 是一个从1到N的符号序列(我们在实验中使用N = 127)。下面给出一个长度为3的单词w的示例。
为了研究LSH注意力,我们在上述形式的样本上训练一个语言模型,其中每个 w w w的长度为511(因此整个输入 0 w 0 w 0w0w 0w0w的长度为1024)。由于这是一个语言建模任务,我们总是根据之前的所有符号预测下一个符号,但我们会掩盖损失和准确率,只考虑输入后半部分的位置,即那些实际上可以预测的位置。
上述任务可以通过一个1层的Transformer模型完美解决(达到100%的准确率和0的损失)。需要注意的是,它需要非局部的注意力查找,因此任何依赖有限范围的稀疏注意力的模型都无法解决该任务。为了使训练简单快速,同时与NLP中使用的模型相似,我们使用一个1层的Transformer,其中 d m o d e l = d f f = 256 d_{model}=d_{ff}=256 dmodel=dff=256,并设置4个头。我们在4种不同的设置下训练了150K步:使用完整注意力、使用 n r o u n d s = 1 n_{rounds}=1 nrounds=1的LSH注意力、 n r o u n d s = 2 n_{rounds}=2 nrounds=2的LSH注意力以及 n r o u n d s = 4 n_{rounds}=4 nrounds=4的LSH注意力。
从表 Z \mathbb{Z} Z中总结的结果可以看出,使用完整注意力训练的模型可以立即用于LSH注意力,但会损失一些准确率。当从头开始使用LSH注意力训练时,使用4个哈希训练的模型也达到了几乎完美的准确率。有趣的是,当使用8个哈希进行评估时,准确率变得完美。而当使用2个或1个哈希进行评估时,准确率会下降。使用较少哈希训练的模型表现较差,但即使只使用1个哈希训练的模型,在使用8个哈希进行评估时也几乎表现完美。
3 可逆Transformer
如上节所示,如果允许近似,注意力的复杂度可以从长度的平方降低到线性。但从表 Π \Pi Π中可以看出,每个字段的开头都有一个 b ^ ⋅ n h ⋅ l \hat{b}\cdot n_h\cdot l b^⋅nh⋅l项: b ⋅ n h ⋅ l ⋅ d k b\cdot n_h\cdot l\cdot d_k b⋅nh⋅l⋅dk,或者 b ⋅ l ⋅ d m o d e l b\cdot l\cdot d_{model} b⋅l⋅dmodel的成本是无法避免的。实际上,每层之前的激活值大小已经是 b ⋅ l ⋅ d m o d e l b\cdot l\cdot d_{model} b⋅l⋅dmodel,因此整个模型的内存使用量至少为 b ⋅ l ⋅ d m o d e l ⋅ n l b\cdot l\cdot d_{model}\cdot n_l b⋅l⋅dmodel⋅nl。更糟糕的是,在Transformer的前馈层中,这一数字会上升到 b ⋅ l ⋅ d f f ⋅ n l b\cdot l\cdot d_{ff}\cdot n_{l} b⋅l⋅dff⋅nl。在大型Transformer中,通常设置 d f f = 4 K d_{ff}= 4K dff=4K和 n l = 16 n_{l}= 16 nl=16,因此当 l = 64 K l= 64K l=64K时,这将再次使用不切实际的 16 G B 16GB 16GB内存。在本节中,我们将展示如何通过使用可逆层首先解决 n l n_l nl部分的问题,然后展示分块如何帮助我们处理 d f f d_{ff} dff问题。每种方法对内存和时间复杂度的影响总结在表[3]中。
RevNets
可逆残差网络由
Gomez et al.
(
2017
)
\sqrt{\text{Gomez et al.}(2017)}
Gomez et al.(2017)提出,研究表明它们可以替代ResNets用于图像分类。其主要思想是允许从后续层的激活值中恢复任何给定层的激活值,仅使用模型参数。与必须为反向传播检查点中间值不同,随着反向传播从网络输出到输入进行,层可以逐个反转。普通的残差层执行一个函数
x
↦
y
x\mapsto y
x↦y,该函数对单个输入进行操作并产生单个输出,形式为
y
=
x
+
F
(
x
)
y=x+F(x)
y=x+F(x),而可逆层则对输入/输出对进行操作:
(
x
1
,
x
2
)
↦
(
y
1
,
y
2
)
(x_{1},x_{2})\mapsto(y_{1},y_{2})
(x1,x2)↦(y1,y2),并遵循以下公式:
y
1
=
x
1
+
F
(
x
2
)
y
2
=
x
2
+
G
(
y
1
)
y
1
=
x
1
+
F
(
x
2
)
y
2
=
x
2
+
G
(
y
1
)
(
7
)
y1=x1+F(x2)y2=x2+G(y1)y_1=x_1+F(x_2)\quad y_2=x_2+G(y_1)\qquad{(7)}
y1=x1+F(x2)y2=x2+G(y1)y1=x1+F(x2)y2=x2+G(y1)(7)
一个层可以通过减去(而不是添加)残差来反转:
x
2
=
y
2
−
G
(
y
1
)
x
1
=
y
1
−
F
(
x
2
)
(
8
)
x_2=y_2-G(y_1)x_1=y_1-F(x_2)\qquad{(8)}
x2=y2−G(y1)x1=y1−F(x2)(8)
可逆Transformer
我们将RevNet的思想应用于Transformer,将注意力和前馈层结合在revnet块中。在上面的符号中,F成为注意力层,而G成为前馈层。请注意,层归一化(Ba et al., 2016)被移到了残差块内部。
Y
1
=
X
1
+
A
t
t
e
n
t
i
o
n
(
X
2
)
Y
2
=
X
2
+
F
e
e
d
F
o
r
w
a
r
d
(
Y
1
)
(
9
)
Y_1=X_1+\mathrm{Attention}(X_2)\quad Y_2=X_2+\mathrm{FeedForward}(Y_1)\qquad{(9)}
Y1=X1+Attention(X2)Y2=X2+FeedForward(Y1)(9)
可逆Transformer的优势
可逆Transformer不需要在每一层存储激活值,因此消除了 n l n_l nl项。在第[5]节中,我们展示了在使用相同参数数量的情况下,它的性能与普通Transformer相同;我们通过使 x 1 x_1 x1和 x 2 x_2 x2的尺寸均为 d m o d e l d_{model} dmodel来实现这一点。
分块(Chunking)
虽然可逆性解决了 n l n_l nl项的问题,但更厚的层仍然可能占用大量内存。特别是前馈层,可能会使用维度为 d f f = 4 K d_{ff}=4K dff=4K或更高的中间向量。然而,前馈层中的计算在序列中的各个位置之间是完全独立的,因此可以将计算分成 c c c个块:
Y
2
=
[
Y
2
(
1
)
;
…
;
Y
2
(
c
)
]
=
[
X
2
(
1
)
+
FeedForward
(
Y
1
(
1
)
)
;
…
;
X
2
(
c
)
+
FeedForward
(
Y
1
(
c
)
)
]
(
10
)
Y_2=\left[Y_2^{(1)};\ldots;Y_2^{(c)}\right]=\left[X_2^{(1)}+\text{FeedForward}(Y_1^{(1)});\ldots;X_2^{(c)}+\text{FeedForward}(Y_1^{(c)})\right]\qquad{(10)}
Y2=[Y2(1);…;Y2(c)]=[X2(1)+FeedForward(Y1(1));…;X2(c)+FeedForward(Y1(c))](10)
通常情况下,该层通过对所有位置并行执行操作来进行批处理,但一次只操作一个块可以减少内存使用。公式8中的反向计算和反向传播也被分块处理。除了前馈层外,对于具有大词汇表(超过
d
m
o
d
e
l
d_{model}
dmodel个词类型)的模型,我们还会对输出的对数概率进行分块,并一次计算部分序列的损失。
分块、大批量和参数重用
通过分块和可逆层,我们在整个网络中用于激活值的内存与层数无关。然而,参数的数量仍然会随着层数的增加而增长。这个问题可以通过在某一层不进行计算时将其参数交换到CPU内存来解决。在标准Transformer中,这种操作效率较低,因为向CPU传输内存的速度较慢。但在Reformer中,批量大小乘以序列长度的值要大得多,因此使用参数进行的计算量可以分摊传输成本。
4 相关工作
由(Vaswani et al., 2017)提出的Transformer模型已广泛应用于自然语言任务,并进一步扩展到建模多种数据,如乐谱(Huang et al., 2018)和图像(Parmar et al., 2018; Ramachandran et al., 2019)。最值得注意的是,这类模型在极大规模语言模型的自我监督训练中取得了成功(Devlin et al., 2018; Radford et al., 2019)。
鉴于最先进的序列模型对计算的巨大需求,人们越来越关注如何减少Transformer模型的内存占用和计算需求。除了精度降低和梯度检查点(Sohoni et al., 2019)等标准方法外,最近还探索了Transformer模型自注意力机制的更高效版本(Sukhbaatar et al., 2019a;b)。
特别是,利用注意力层中的稀疏性已被证明是有效的。OpenAI提出了稀疏Transformer(Child et al., 2019),它利用了注意力的因子化稀疏表示。使用产品键注意力来增加键空间也被用于减少前馈层的内存需求,而不会损失性能(Lample et al., 2019)。
据我们所知,局部敏感哈希(LSH)此前并未直接应用于Transformer的注意力层。但之前使用外部记忆的神经网络研究已经处理过大规模的记忆问题。记忆网络的最初实现(Weston et al., 2014)以及后续的扩展工作(Bordes et al., 2015; Chandar et al., 2016)使用了数百万规模的记忆。这样做的代价是记忆必须在训练前固定。此外,由于在训练初期模型不太可能正确查询记忆,因此使用强监督来鼓励模型查询有用的记忆位置。这些提示要么由任务提供额外的监督信息,要么像Hill et al. (2015)中那样通过启发式方法确定。Santoro et al. (2016)消除了记忆必须在训练前固定的要求,但代价是记忆大小,后来Rae et al. (2016)缓解了这一问题。最后一篇论文考虑了使用近似最近邻(包括LSH和随机kd树)进行记忆查找,但仅用于外部记忆的查找。
5 实验
在本节中,我们展示了上述技术的实验结果。我们逐一分析这些技术,以明确哪些组合对性能有影响。我们首先展示可逆层和共享查询-键空间不会影响性能,然后分析哈希注意力,最后分析完整的Reformer模型。
我们在imagenet64和enwik8-64K任务上进行了实验,其中后者是enwik8的一个变体,被分块为 2 16 = 64 K 2^{16}=64K 216=64K个标记的子序列。我们使用3层模型进行消融实验,以便与常规Transformer进行比较,后者内存使用率高且执行完整的 O ( l 2 ) O(l^{2}) O(l2)注意力计算。所有实验的配置为 d m o d e l = 1024 d_{model}=1024 dmodel=1024, d f f = 4096 d_{ff}=4096 dff=4096, n h e a d s = 8 n_{heads}=8 nheads=8,总批量大小为8个序列。我们使用Adafactor优化器(Shazeer & Stern, 2018)训练这些模型。我们还评估了WMT 2014英德翻译任务,遵循Vaswani et al. (2017)的超参数设置。所有实验的训练均在8个设备(8个GPU或8个TPU v3核心)上并行进行。训练代码已公开提供。
共享QK的影响
我们首先考虑共享QK注意力对常规Transformer模型的影响。共享QK注意力设置 k j = q j ∥ q j ∥ k_j=\frac{q_j}{\|q_j\|} kj=∥qj∥qj,并防止标记关注自身(除非没有其他上下文可用)。在图B!的左侧,我们绘制了常规注意力和共享QK注意力的困惑度曲线。共享查询-键空间的表现并不比常规注意力差;事实上,在enwik8上,它的训练速度似乎略快。换句话说,切换到共享QK注意力并不会牺牲准确性。
可逆层的影响
在图[3]右侧的两个图中,我们比较了常规Transformer(Vaswani et al., 2017)与第[3]节中描述的可逆Transformer。这两个模型的参数数量相同,学习曲线也几乎相同。这些结果表明,可逆Transformer的内存节省并不会以准确性为代价。
机器翻译中的可逆层
我们还在英德机器翻译的编码器-解码器Transformer模型中评估了可逆层。我们首先在Transformer-base架构中使编码器和解码器完全可逆,发现训练100K步后,该模型的表现与Vaswani et al. (2017)相当。我们还评估了更多训练步数和更大模型的情况。Reformer模型非常节省内存,因此在后两个实验中,我们不需要通过共享嵌入和输出投影权重矩阵来节省内存。结果如表4所示。我们没有在此设置中应用LSH注意力,因为样本是单句,而句子通常较短。我们的典型LSH注意力配置在哈希和排序后使用128个标记的块,而WMT14测试集中的样本都短于128个标记。
Transformer中的LSH注意力
LSH注意力是对完整注意力的近似,如图4所示,随着哈希数量的增加,其准确性逐渐提高。当 n r o u n d s = 8 n_{rounds}=8 nrounds=8时,它已经几乎与完整注意力相匹配。模型的计算成本随着哈希数量的增加而增加,因此可以根据计算预算调整此超参数。此外,如表2所示,可以在评估时增加哈希数量以生成更准确的结果。在图5的右侧,我们绘制了不同注意力类型的速度与序列长度的关系,同时保持总标记数不变。我们发现,常规注意力在序列长度增加时变得更慢,而LSH注意力的速度保持稳定。
大型Reformer模型
为了验证Reformer确实可以在单个核心上拟合大型模型并在长序列上快速训练,我们在enwik8和imagenet64上训练了多达20层的大型Reformer模型。如图5所示,这些模型可以放入内存并训练。我们无法在这种情况下训练Transformer基线,因为它们速度太慢且内存需求过高,但我们看到随着层数的增加,性能有明显提升。在enwik8上,一个12层模型训练20K步,dropout率为0.1,在测试集上达到了1.19 bits/dim。我们还通过进一步调优和改进训练了一个12层Reformer模型,并在enwik8测试集上达到了1.05 bits/dim。
6 结论
Reformer结合了Transformer的建模能力与一种能够在长序列上高效执行且内存占用小的架构,即使对于层数较多的模型也是如此。我们相信,这将有助于大型、参数丰富的Transformer模型变得更加普及和易于使用。此外,处理长序列的能力为Reformer在许多生成任务中的应用开辟了道路。除了生成非常长的连贯文本外,Reformer还可以将Transformer模型的强大能力应用于其他领域,如时间序列预测、音乐、图像和视频生成。
多轮LSH注意力
在本节中,我们将更详细地描述多轮哈希版本的LSH注意力机制。我们首先重复正文中的公式(3),该公式描述了稀疏注意力的一般形式:
o
i
=
∑
j
∈
P
i
exp
(
q
i
⋅
k
j
−
m
(
j
,
P
i
)
−
z
(
i
,
P
i
)
)
v
j
where
m
(
j
,
P
i
)
=
{
∞
if
j
∉
P
i
0
otherwise
(
3
)
o_i = \sum_{j \in P_i} \exp(q_i \cdot k_j - m(j,P_i) - z(i,P_i)) v_j \quad \text{where } m(j,P_i) = \begin{cases} \infty & \text{if } j \notin P_i \\ 0 & \text{otherwise} \end{cases}\qquad{(3)}
oi=j∈Pi∑exp(qi⋅kj−m(j,Pi)−z(i,Pi))vjwhere m(j,Pi)={∞0if j∈/Piotherwise(3)
在多轮情况下,查询位置
i
i
i可以关注由公式(6)定义的键位置
P
i
P_i
Pi,我们在此也重复该公式:
P
i
=
⋃
r
=
1
n
rounds
P
i
(
r
)
where
P
i
(
r
)
=
{
j
:
h
(
r
)
(
q
i
)
=
h
(
r
)
(
q
j
)
}
(
6
)
P_i = \bigcup_{r=1}^{n\text{rounds}} P_i^{(r)} \quad \text{where } P_i^{(r)} = \{j : h^{(r)}(q_i) = h^{(r)}(q_j)\}\qquad{(6)}
Pi=r=1⋃nroundsPi(r)where Pi(r)={j:h(r)(qi)=h(r)(qj)}(6)
为了支持批处理,注意力计算在排序后的查询/键块上进行:
P
~
i
(
r
)
=
{
j
:
∣
s
i
(
r
)
m
∣
−
1
≤
s
j
(
r
)
m
≤
∣
s
i
(
r
)
m
∣
}
(
11
)
\tilde{P}_i^{(r)} = \{j : \left| \frac{s_i^{(r)}}{m} \right| - 1 \leq \frac{s_j^{(r)}}{m} \leq \left| \frac{s_i^{(r)}}{m} \right| \}\qquad{(11)}
P~i(r)={j:
msi(r)
−1≤msj(r)≤
msi(r)
}(11)
结合公式(3)和公式(6)得到:
o
i
=
∑
j
∈
P
~
i
exp
(
q
i
⋅
k
j
−
m
(
j
,
P
i
)
−
z
(
i
,
P
i
)
)
v
j
(
12
)
o_i=\sum_{j\in\tilde{\mathcal{P}}_i}\exp\left(q_i\cdot k_j-m(j,\mathcal{P}_i)-z(i,\mathcal{P}_i)\right)v_j\qquad{(12)}
oi=j∈P~i∑exp(qi⋅kj−m(j,Pi)−z(i,Pi))vj(12)
= ∑ r = 1 n r o u n d s exp ( z ( i , P i ( r ) ) − z ( i , P i ) ) ∑ j ∈ P ~ i ( r ) 1 N i , j exp ( q i ⋅ k j − m ( j , P i ( r ) ) − z ( i , P i ( r ) ) ) v j ( 13 ) \begin{aligned} =\sum_{r=1}^{n_{rounds}}\exp\left(z(i,\mathcal{P}_i^{(r)})-z(i,\mathcal{P}_i)\right)\sum_{j\in\widetilde{\mathcal{P}}_i^{(r)}}\frac{1}{N_{i,j}}\exp\left(q_i\cdot k_j-m(j,\mathcal{P}_i^{(r)})-z(i,\mathcal{P}_i^{(r)})\right)v_j \end{aligned}\qquad{(13)} =r=1∑nroundsexp(z(i,Pi(r))−z(i,Pi))j∈P i(r)∑Ni,j1exp(qi⋅kj−m(j,Pi(r))−z(i,Pi(r)))vj(13)
= ∑ r = 1 n r o u n d s exp ( z ( i , P i ( r ) ) − z ( i , P i ) ) o i ( r ) ( 14 ) =\sum_{r=1}^{n_{rounds}}\exp\left(z(i,\mathcal{P}_i^{(r)})-z(i,\mathcal{P}_i)\right)o_i^{(r)}\qquad{(14)} =r=1∑nroundsexp(z(i,Pi(r))−z(i,Pi))oi(r)(14)
o i ( r ) = ∑ j ∈ P ~ i ( r ) exp ( q i ⋅ k j − m i , j ( r ) − z ( i , P i ( r ) ) ) v j ( 15 ) o_i^{(r)}=\sum_{j\in\tilde{\mathcal{P}}_i^{(r)}}\exp\left(q_i\cdot k_j-m_{i,j}^{(r)}-z(i,\mathcal{P}_i^{(r)})\right)v_j\qquad{(15)} oi(r)=j∈P~i(r)∑exp(qi⋅kj−mi,j(r)−z(i,Pi(r)))vj(15)
w h e r e N i , j = ∣ { r ′ : j ∈ P i ( r ′ ) } ∣ a n d m i , j ( r ) = { ∞ i f j ∉ P i ( r ) 1 0 5 i f i = j log N i , j o t h e r w i s e ( 16 ) \mathrm{where}N_{i,j}=\left|\left\{r^{\prime}:j\in\mathcal{P}_i^{(r^{\prime})}\right\}\right|\mathrm{and}m_{i,j}^{(r)}= \begin{cases} \infty & \mathrm{if}j\notin\mathcal{P}_i^{(r)} \\ 10^5 & \mathrm{if}i=j \\ \log N_{i,j} & \mathrm{otherwise} & \end{cases}\qquad{(16)} whereNi,j= {r′:j∈Pi(r′)} andmi,j(r)=⎩ ⎨ ⎧∞105logNi,jifj∈/Pi(r)ifi=jotherwise(16)
每一轮LSH注意力都会生成一个向量 o i ( r ) o_i^{(r)} oi(r),该向量可以独立于其他轮次计算,除了引入一个 N i , j N_{i,j} Ni,j项以避免在构建 P i ( r ) \mathcal{P}_i^{(r)} Pi(r)集合的并集时重复计算元素。在我们的实现中,我们将 N i , j N_{i,j} Ni,j因子合并到掩码项 m i , j ( r ) m_{i,j}^{(r)} mi,j(r)中。
我们还修改了 m i , j ( r ) m_{i,j}^{(r)} mi,j(r),为 i = j i=j i=j的情况引入了一个特殊处理。添加此情况是因为标准Transformer中的因果掩码允许位置 i i i关注自身,而这在共享QK的设定中是不可取的。我们将掩码设置为一个较大但有限的值,以禁止自身注意力,除非某个标记没有其他有效的注意力目标。例如,序列中的第一个标记只能关注自身,因为没有先前的上下文可用。