题目:训练面向数据高效的ViT和使用注意力的知识蒸馏
提示:本篇文章是ViT经典解读系列的一个重要内容,解决原始的ViT依赖数据的短板,并结合知识蒸馏,将注意力灵活运用起来,值得参考与借鉴。
如果不了解Transformer,推荐阅读: 深入解读Vision Transformer:拒绝做半瓢水
摘要
最近,完全基于注意力的神经网络被用来解决图像理解任务,例如图像分类。这些高性能的ViT在成百上千万幅图像上预训练,使用了一个很大的基础设施,因此限制了他们的适用性。
在这篇文章中,我们产生了一个有竞争力的无CNN的Transformer,仅仅在ImageNet上训练。我们在一台单计算机上花了不到三天时间去训练他们。我们所指的ViT(86M参数)在ImageNet上获取了Top1的83.1%识别率(单裁剪,没有额外的数据)。
更重要的是,我们提出了一个老师-同学策略针对Transformer。它依赖一个蒸馏令牌,保证这个学生能够通过注意力从老师学习。我们展示了这种基于令牌的蒸馏的趣味,特别当使用ConvNet作为老师。这使得我们在ImageNet和迁移到其他任务上,获得与ConvNet相当的结果。代码和模型已开源。
1 介绍
卷积神经网络已经成为图像理解任务的主要设计范式,因为最开始在图像分类任务证明。他们获得成功的因素之一是大量训练数据集的可获取性,即ImageNet。受到NLP中基于注意力模型所获得的成功启发,越来越多的人将注意力机制应用到到CNN内的结构感兴趣,如SENet等。最近,一些研究人员提出混合结构将Transformer移植到CNN中解决视觉问题。
Vaswani 提出将Transformer应用到视觉中,是直接使用NLP中的一个架构,但是将图像块作为输入来解决图像分类任务。他们的文章展示了ViT在大规模数据集(JFT-300M)上预训练后取得了满意的结果。这篇文章总结到“Transformer不能很好泛化,当在数据不充分的时候”,并且训练这个模型需要海量计算资源的参与。
在这篇文章中,我们在一个8GPU的单节点上训练一个ViT,花费53小时预训练,和20h微调,就能获得与ConvNet相当的性能,他们的参数量和高效性差不多的时候。它仅使用ImageNet作为单独的训练集。我们建立在原始的ViT的基础上,和在timm库上改进。使用我们面向数据高效的ViT,即DeiT,我们展示了比先前结果一个很大的提升,如下图。我们的腐蚀实验详细了展示了超参和关键因素,例如重复增强。
我们解决了另一个问题:如何蒸馏这些模型?我们引入了一个基于令牌的策略,针对Transformer并用Deit D D D表示,展示了它替代常用蒸馏的优越性。
总结起来,我们工作作出了以下贡献:
- 我们展示了我们不含有卷积的网络结构能够获取相当的结果,在ImageNet上没有额外数据下与当前最先进的方法相比较。他们是在4GPU的单节点上训练三天学习到的,我们其他两个新的模型如DeiT-S和Deit-Ti包含更少的参数,可以看成是ResNet-50和ResNet-18的对应物。
- 我们引入了一个新的蒸馏过程,基于蒸馏令牌,它扮演着class token一样的角色,除了它旨在复现老师估计的label。这两种tokens通过注意力来交互。这个针对transformer的策略超过了原始的蒸馏一大截。
- 有意思的是,使用我们的蒸馏,ViT从ConvNet学习到更多,相对于从另一个相当的Transformer中。
- 我们模型在ImageNet上预训练后,迁移到在其他下游任务如细分的分类如CIFAR-10/100,Oxford-102 flowers,Stanford Cars等上也能获得相当的性能,
这篇文章的组织结构如下:第二部分阅读了相关工作;第三部分关注面向图像分类的transformer;第四部分介绍我们transformer的蒸馏策略;第五部分,实验将CNNs和最近的transformers,同样包含我们蒸馏的transformer进行了对比与分析;第六部分介绍我们的训练策略。它包含了我们DeiT数据高效训练的广泛消融,这对DeiT中关键成分提供了一些简洁。第七部分总结。
2 相关工作
Image Classification。
图像分类是计算机视觉的核心,以至于常作为一个基准来衡量图像理解的进步。任何的进步通常会转化为其他相关任务的进步,如检测和分割。自从2012年AlexNet,CNN已经统治了这个基准,并且成为事实上的标准。ImageNet上最领先的方法的发展反应了CNN架构和学习的进展。
尽管有不少尝试去使用Transformer实现图像分类,直到最近他们的性能不如CNN。然而,混合的架构将CNN和transformer结合起来,包括自注意力机制,已经在分类,检测,分割,视频处理,无监督的目标发掘和统一的文本视频任务等上展示了可竞争的结果。
最近,ViT不使用任何的卷积结构,已经在ImageNet上追上了与最领先方法的差距。这个性能是令人难忘的,因为CNN的方法已经从多年的调试和优化受益匪浅。然后,基于ViT的研究,需要大量的数据上进行预训练才能实现迁移学习的有效性。在这篇文章中,我们实现了仅仅依赖ImageNet-1K就能获得很强的性能。
Transformer结构
Transformer架构,最早来源于机器翻译中的Vaswani结构,当前成为了所有NLP任务的参考模型。对于图像分类中的卷积网络的性能提升,也被transformer所启发。例如,SENet,SKNet和ResNeSt探索类似于Transformer中自注意力机制。
Knowledge Distillation,KD知识蒸馏
原文:Knowledge Distillation (KD), introduced by Hinton et al. [24], refers to the training paradigm in which a student model leverages “soft” labels coming from a strong teacher network. This is the output vector of the teacher’s softmax function rather than just the maximum of scores, wich gives a “hard” label. Such a training improves the performance of the student model (alternatively , it can be regarded as a form of compression of the teacher model into a smaller one – the student). On the one hand the teacher’s soft labels will have a similar effect to labels smoothing [58]. On the other hand as shown by Wei et al. [54] the teacher’s supervision takes into account the effects of the data augmentation, which sometimes causes a misalignment between the real label and the image. For example, let us consider image with a “cat” label that represents a large landscape and a small cat in a corner. If the cat is no longer on the crop of the data augmentation it implicitly changes the label of the image. KD can transfer inductive biases [1] in a soft way in a student model using a teacher model where they would be incorporated in a hard way . For example, it may be useful to induce biases due to convolutions in a transformer model by using a convolutional model as teacher. In our paper we study the distillation of a transformer student by either a convnet or a transformer teacher. We introduce a new distillation procedure specific to transformers and show its superiority.
被Hinton介绍的知识蒸馏,指的是一种训练范式,其中,一个学生模型借助从很强的老师模型中产生的软标签。这个老师模型的softmax函数的输出向量,而不仅仅是给出硬标签的最大的得分。这样的一种训练方式提升了学生模型的性能。或者说,它可以看作是老师模型到一个更小的学生模型的压缩。一方面,老师模型产生的软标签和标签平滑(label smoothing)相似的效果。另外一方面如Wei方法展示,老师模型的有监督考虑数据增强的效果,这有时候会导致真实标签与图像之间的错位。例如,让我们考虑一个带有“猫”标签的图像,其中大部分是草坪,而一只小猫在角落。如果这个猫不在裁剪后的照片中,这将直接改变图像的标签。知识蒸馏能够使用教师以一种软方式迁移归纳偏置到学生模型中,而原本是一种硬方式来合并的。例如,通过卷积模型作为老师模型来诱导迁移模型中的卷积偏置是有效的。在我们的文章中,我们研究了一中蒸馏方式,即transformer作为学生模型,CNN或Transformer作为老师模型,我们引入了一个新的蒸馏过程,针对transformer,并且展示它的优越性。
这篇文章对知识蒸馏,讲得非常透彻;同时还对Transformer的回顾,也很精彩。继续翻译理解。
3 ViT:回顾
在这部分,我们简单回顾与ViT相关的基础,并进一步讨论位置编码和分辨率。
Multi-head self-attention layers (MSA),多头注意力
注意力机制基于一种可训练的带有(键,值)向量对的联想记忆。使用内积将查询向量
q
u
e
r
y
∈
R
d
query \in R^{d}
query∈Rd与一组键值向量
k
k
k(
k
∈
R
k
×
d
k \in R^{k \times d}
k∈Rk×d)相匹配。这个内积然后被缩放,和使用softmax函数进行归一化,去获取
k
k
k个权重。这个注意力的输出是一组
k
k
k个值向量的加权和。对于长度为N的查询向量(
Q
∈
R
N
×
d
Q \in R^{N \times d}
Q∈RN×d)构成的序列,将产生一个大小
N
×
d
N \times d
N×d的输出矩阵:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
/
d
)
V
Attention(Q, K, V) = Softmax(QK^{T}/\sqrt{d})V
Attention(Q,K,V)=Softmax(QKT/d)V
其中,Softmax函数用在输入举证的每一列上,
d
\sqrt{d}
d项为了合适的归一化。
在ViT中提出了一个自注意力层。询问,键和值矩阵是N个输入向量(
X
∈
R
N
×
D
X \in R^{N \times D}
X∈RN×D)构成的序列与线性变换
W
Q
,
W
K
,
W
V
W_{Q}, W_{K}, W_{V}
WQ,WK,WV想计算的结果:
Q
=
X
W
Q
,
K
=
X
W
K
,
V
=
X
W
V
Q = XW_{Q}, K = XW_{K}, V = XW_{V}
Q=XWQ,K=XWK,V=XWV
其中,添加限制
k
=
N
k=N
k=N,表示输入向量之间的所有注意力都计算了。
Transformer block for image,针对图像的Transformer模块
为了得到ViT中的完整的Transformer模块,我们在MSA模块顶部添加了一个前馈网络,即FeedForward Network。FFN主要由两个线性层构成,被一个GeLu激活层所分割。第一个线性层将维度从D扩展到4D,第二层将4D维度缩小到D。MSA和FFN层都执行残差操作由于skip连接和一个Layer Normalization。
为了获得一个Transformer去处理图像,我们的工作基于ViT模型去构造。ViT是一个简单的优雅的架构能够处理图像输入,图像就是是一个输入令牌的序列一样。固定尺度的RGB图像被分解上N个图像块,每一个图像块是由 16 × 16 16 \times 16 16×16个像素构成的块状。 N = 14 × 14 N = 14 \times 14 N=14×14。每一个块被一个线性层投影并保持其整体维度为 3 × 16 × 16 3 \times 16 \times 16 3×16×16。
Transformer模块如上所述,对于图像块嵌入的顺序是不变的,因此不用考虑他们的相对位置。位置信息被合并为固定的或者可训练的位置嵌入。他们在进入第一个transformer模块前,添加到patch嵌入中,然后输入到transformer模块栈中。
The class token,标签令牌
标签令牌是一个可训练的向量,在输入第一层前添加到patch嵌入中,然后通过transformer层和一个线性层投影去预测分类。这个标签令牌从NLP继承,并背离了计算机视觉中常用的池化层去预测标签。因此,Transformer处理(N+1)个块,其中N个令牌块,一个label token。其中,只有类向量用于预测输出。这个结构迫使自注意力在patch tokens和class token之间去传播信息:在训练过程中,这个有监督的信息仅仅来自于class embedding,而patch tokens是这个模型的唯一变量输入。
Fixing the positional encoding across resolutions。
Touvron等展示在一个低的分辨率上训练然后在高分辨率去微调时值得期待的。这加速了整个训练的而过程,和提升了精度,在流行的数据增强方法下。当增加一张输入图像的分辨率,我们保持每个块的尺度不变,那么patch块的个数将增加。由于transformer模块的结构和标签令牌的原因,不需要改变模型和分类器去处理这些tokens。反而,需要调整位置嵌入,因为有N个patch,每一个patch都有一个位置信息。ViT在改变分辨率时,对位置编码进行了插值,并证明该方法适用于随后的微调阶段。
Distillation through attention
在这部分,我们建设我们可以获取一个强大的图像分类器作为老师模型。它可以是一个ConVNet或者混合的分类器。我们解决如何从老师模型学习一个Transformer。如同我们将在第五部分看到,通过比较精度与输出之间的平衡,通过Transformer来替换CNN是有效的。这部分涵盖蒸馏的两个方向:即硬蒸馏hard distillation和软蒸馏,已经经典蒸馏与令牌蒸馏。
Soft distillation,软蒸馏
软蒸馏最小化教师模型与学生模型的softmax之间的KL散度。假设教师模型的输出为
Z
t
Z_{t}
Zt,学生模型的输出为
Z
s
Z_{s}
Zs。蒸馏的温度用
τ
\tau
τ表示,
λ
\lambda
λ表示平衡KL散度损失
K
L
KL
KL与在真实标签
y
y
y上的交叉熵
L
C
E
L_{CE}
LCE的系数,
φ
\varphi
φ表示softmax函数。蒸馏目标函数定义为:
L
g
l
o
b
a
l
=
(
1
−
λ
)
L
C
E
(
φ
(
Z
s
)
,
y
)
+
λ
τ
2
K
L
(
φ
(
Z
s
/
τ
)
,
φ
(
Z
t
/
τ
)
)
L_{global} = (1-\lambda)L_{CE}(\varphi(Z_{s}),y) + \lambda\tau ^{2}KL(\varphi(Z_{s}/\tau),\varphi(Z_{t}/\tau))
Lglobal=(1−λ)LCE(φ(Zs),y)+λτ2KL(φ(Zs/τ),φ(Zt/τ))
Hard-label distillation,硬蒸馏
我们介绍蒸馏的一个变体,我们将老师模型的决策作为一个真实的标签。假设老师模型的决策表示为:
y
t
=
a
r
g
m
a
x
c
Z
t
(
c
)
y_{t} = argmax_cZ_t(c)
yt=argmaxcZt(c)。与这个硬标签蒸馏相关的目标函数定义为:
L
g
l
o
b
a
l
h
a
r
d
D
i
s
t
i
l
l
=
1
2
L
C
E
(
φ
(
Z
s
)
,
y
)
+
1
2
L
C
E
(
φ
(
Z
s
)
,
y
t
)
L_{global}^{hardDistill} = \frac{1}{2}L_{CE}(\varphi(Z_{s}),y) + \frac{1}{2}L_{CE}(\varphi(Z_{s}),y_{t})
LglobalhardDistill=21LCE(φ(Zs),y)+21LCE(φ(Zs),yt)
对于给定的图像,教师模型的硬标签可能会根据特定的数据增强而改变。我们将会发现这个选择比传统的更好,同时是无参数的并且概念上更简单:老师模型的预测
y
t
y_{t}
yt与真实标签扮演着一样的角色。
同样注意到硬标签能够通过label smoothing转化为软标签,其中真实的标签看成有一个
(
1
−
ε
)
(1-\varepsilon )
(1−ε)概率,剩余的
ε
\varepsilon
ε在剩下的类中共享。我们固定这个参数
ε
=
0.1
\varepsilon=0.1
ε=0.1在我们使用真实标签的所有实验中。
Figure : Our distillation procedure: we simply include a new distillation token. It interacts with the class and patch tokens through the self-attention layers. This distillation token is employed in a similar fashion as the class token, except that on output of the network its objective is to reproduce the (hard) label predicted by the teacher, instead of true label. Both the class and distillation tokens input to the transformers are learned by back-propagation.
我们的蒸馏过程:我们仅包含了一个新的蒸馏令牌。它通过自注意力层与class和patch令牌交互。这个蒸馏令牌与class token的使用方式差不多,除了在网络输出上它的目标函数是复现老师模型预测的硬标签,而不是真实的标签。输入到transformer中的类和整理令牌都是通过反向传播的方式学习的。
Distillation token蒸馏令牌
We add a new token, the distillation token, to the initial embeddings (patches and class token). Our distillation token is used similarly as the class token: it interacts with other embeddings through self-attention, and is output by the network after the last layer. Its target objective is given by the distillation component of the loss. The distillation embedding allows our model to learn from the output of the teacher, as in a regular distillation, while remaining complementary to the class embedding.
在这部分我们关注我们提出的蒸馏过程,如上图所示。我们添加了一个新的令牌,即蒸馏令牌到原始的嵌入(包含class和patch token)中。我们的蒸馏令牌与class token使用差不多:它通过自注意力与其他嵌入交互,并且在最后一层由网络输出。它的目标函数由损失的蒸馏部分给出。这个蒸馏嵌入使得我们的模型能够从老师模型的输出中学习,正如常规的蒸馏,同时保留与类嵌入的互补。
有意思的一点是,我们观察到学习到的类和蒸馏令牌向不同的向量收敛:这些标记之间的余弦相似度等于0.06。随着每一层计算class和distillation嵌入,他们通过网络逐渐变得更加相似,一直到他们最后一层具有cos=0.93的高相似度,但仍然低于1。这是预期的,因为他们的目标是产生相似但不是相同的目标。
我们验证了我们的蒸馏令牌向模型里添加了一些东西,而不是简单地添加一个与目标标签相关联的附件类标签:我们在没有使用老师模型的伪标签,而是带有两个类标签的transformer做实验。即使我们独立地随机初始化他们,在训练过程中,他们向两个相同的向量收敛,且(cos=0.999)。并且输出的嵌入也是准相同的。这个添加的类令牌没有引入任何东西去提升性能。反而,我们的蒸馏令牌在原始的ViT模型带来显著的提升,在实验部分将会验证。