Bootstrap

Swin Transformer论文解读


论文: 《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》
代码: https://github.com/microsoft/Swin-Transformer

创新点

Transformer从语言应用到视觉,主要有两大挑战:
1、图像中目标尺寸变化大;
2、与文本中单词量相比,图像中像素数量更大;
为了解决这些差异,我们提出了一种分层Transformer,其特征是用 Shifted windows 计算的。移位窗口方案通过将 self-attention 计算限制在不重叠的本地窗口上,同时还允许跨窗口连接,从而带来更高的效率;与图像大小相关的的线性复杂度;
性能:
在分类任务ImageNet-1k上,top1 accuracy达到87.3%;在COCO test数据集上检测性能达到58.7AP,分割性能51.1AP;

算法

Swin Transformer结构如图3a所示,
在这里插入图片描述
Swin Transformer流程如下:
1、输入 ( N , 3 , 224 , 224 ) (N, 3, 224, 224) (N,3,224,224)图像,经过Patch Partition以及Linear Embedding输出为 ( N , 96 , 56 ∗ 56 ) (N, 96, 56*56) (N,96,5656,此处C=96,(该过程通过kernal为patch size=4卷积实现);
2、进入stage1,Swin Transformer Block结构如图3b,主要包括输入W-MSA及SW-MSA。首先进行进行归一化,window partition,window size=7,输出为维度为 ( 64 N , 7 , 7 , 96 ) (64N,7,7,96) (64N,7,7,96);该结果经过attention模块(W-MSA)+FFN;(64个window)(MSA结构可以参考之前文章Transformer结构解读
3、上述输出经过归一化,x进行偏移window size//2大小,而后进行window partition,window size=7,以及attention模块(SW-MSA)+FFN;
4、上述输出经过patch merging降低分辨率;进入stage2阶段,以此类推;

Patch Merging

为了生成多级分辨率,随着网络加深,通过Patch Merging层实现;
第一个Patch Merging,stage2阶段,对于 ( C , W / 4 , H / 4 ) (C,W/4,H/4) (CW/4H/4)的输入,将2*2个patch归为一组,channel 变为4C,分辨率变为之前1/4,输出为 ( 4 C , W / 8 , H / 8 ) (4C,W/8,H/8) (4C,W/8,H/8),其过程类似feat reshape操作;而后降维为2C输出 ( 2 C , W / 8 , H / 8 ) (2C,W/8,H/8) (2C,W/8,H/8)
stage3、stage4输出分辨率分别为 ( 4 C , W / 16 , H / 16 ) , ( 8 C , W / 32 , H / 32 ) (4C,W/16,H/16),(8C,W/32,H/32) (4C,W/16,H/16),(8C,W/32,H/32).

W-MSA

全局self-attention计算成本高,使用window based self-attendtion可降低计算量。如式1为全局attention,式2为使用窗口attention计算复杂度;
self-attention的计算瓶颈在于全图匹配QK,改为窗口内,可大大降低计算量,如式2。
在这里插入图片描述
式1,MSA计算复杂度分析如下,
1、MSA生成三个特征向量Q,K,V过成: Q = x ∗ W Q , K = x ∗ W K , V = x ∗ W V Q=x*W^Q,K=x*W^K,V=x*W^V Q=xWQ,K=xWK,V=xWV 。x​的维度是 ( h w , C ) (hw,C) (hw,C),W​的维度是 ( C , C ) (C,C) (C,C)​,那么这三项的计算复杂度是​ 3 h w C 2 3hwC^2 3hwC2
2、Attention中计算 Q K T QK^T QKT:Q,K,V的维度是 ( h w , C ) (hw,C) (hw,C) ,因此该过程的复杂度是 ( h w ) 2 C (hw)^2C (hw)2C
3、Softmax之后乘V得到Z:因为 Q K T QK^T QKT的维度是 ( h w , h w ) (hw,hw) (hw,hw),所以该过程的复杂度是​ ( h w ) 2 C (hw)^2C (hw)2C;
4、 Z Z Z乘矩阵 W Z W^Z WZ​得到最终输出:它的复杂度是​ h w C 2 hwC^2 hwC2
因此MSA计算复杂度为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2+2(hw)^2C 4hwC2+2(hw)2C
式2,W-MSA计算复杂度分析如下,每个window有 M ∗ M M*M MM个Patch,
1、计算复杂度为, 3 ∗ ( h / M ) ∗ ( w / M ) ∗ M 2 ∗ C 2 = 3 h w C 2 3*(h/M)*(w/M)*M^2*C^2=3hwC^2 3(h/M)(w/M)M2C2=3hwC2;
2、计算复杂度为, ( h / M ) ∗ ( w / M ) ∗ ( M 2 ) ∗ ( M 2 ) ∗ C = h w M 2 C (h/M)*(w/M)*(M^2)*(M^2)*C=hwM^2C (h/M)(w/M)(M2)(M2)C=hwM2C
3、计算复杂度为, ( h / M ) ∗ ( w / M ) ∗ ( M 2 ) ∗ ( M 2 ) ∗ C = h w M 2 C (h/M)*(w/M)*(M^2)*(M^2)*C=hwM^2C (h/M)(w/M)(M2)(M2)C=hwM2C
4、计算复杂度为, ( h / M ) ∗ ( w / M ) ∗ M 2 C 2 = h w C 2 (h/M)*(w/M)*M^2C^2=hwC^2 (h/M)(w/M)M2C2=hwC2
因此W-MSA计算复杂度为 3 h w C 2 + 2 M 2 h w C 3hwC^2+2M^2hwC 3hwC2+2M2hwC

SW-MSA

固定窗口缺少窗口之间联系,限制模型表达能力。为此引入shifted window partition方法,如图2所示;
在这里插入图片描述
layer l使用均匀窗口分割(W-MSA),layer l+1使用shifted 窗口分割(SW-MSA),生成窗口穿过layer l窗口的边界。
使用shifted window partition将生成更多窗口,并且个别窗口比较小,如图2右侧红色窗口;
因此作者使用cyclic-shift,与常规窗口分割相比,窗口数量相等;如图4所示

在这里插入图片描述
但是如图4所示,经过shift后一些在特征图中本不相邻区域,出现在同一窗口,这些区域之间不应进行attention计算,因此作者引入mask机制;
如下图所示,源自issue38
在这里插入图片描述
左图为原图经过偏移后结果,以window1举例,QKV为区域1和区域2穿插,但区域1与区域2不想计算attention,区域2由左侧shift而来,因此对应attn_mask如右上角,0与-100如棋盘状穿插,attn_mask加入softmax计算,-100对应结果趋近于0,起到mask作用。

位置偏置

作者在计算Attention过程中增加位置相关偏置项,如式4,
在这里插入图片描述

结构变体

Swin结构变体包括以下几种:
在这里插入图片描述
Window大小默认M=7,MSA中每个head的序列维度d=32,C表示stage1中隐藏层channel数;layer numbers表示每个stage层数

实验

ImageNet分类

在ImageNet-1K数据上分类性能及计算量比较如表1所示,
在这里插入图片描述

COCO目标检测

图2a作者比较Swin-T与ResNet50在四个检测框架下性能;
图2b作者使用Cascade Mask R-CNN检测框架,比较不同模型容量的Swin Transformer与ResNe(X)t性能;在相似参数量、计算量、FPS下Swin Transformer均取得不错性能;
图2c最好结果与之前SOTA进行对比;
在这里插入图片描述

ADE20K语义分割

如表3
在这里插入图片描述

消融实验

如表4,在分类、检测、分割任务上,作者进行如下比较试验;

  1. 是否使用shift window,结果表明shift window性能优于固定窗口;
  2. 不同位置编码方式比较实验,不使用位置编码、ViT中绝对位置编码、式4中增加位置相关偏置项,结果表明增加位置相关偏置项性能提升;式4中去掉 √ d \surd d d项,性能下降;
    在这里插入图片描述
    作者比较不同attention方法,耗时情况,如表5
    在这里插入图片描述
    Swin Transformer略快于Performer,同时性能优于Performer,如表6,
    在这里插入图片描述

结论

作者提出一种视觉Transformer,Swin Trnsformer,通过W-MSA及SW-MSA降低计算复杂度,生成多层级特征,证明SW-MSA有效性;在COCO目标检测及ADE20K语义分割任务上均取得SOTA性能;

;