Bootstrap

video-swin-transfomer代码讲解

原文链接:https://blog.csdn.net/ly59782/article/details/120823052

Swin-Transformer和Video Swin-Transformer大同小异,感觉最大的区别就是2D的改到了3D,其实操作都是一样的,就是多了一个维度,所以主要还是基于2d讲解的,然后类比一下3d就好啦,讲的是tiny版本的。

源码git传送门:https://github.com/SwinTransformer/Video-Swin-Transformer

目录

类定义

预处理

stage

block

W-MSA

SW-MSA


  

类定义

首先看类定义,主要的函数如下


 
 
  1. class SwinTransformer3D(nn.Module):
  2. """ Swin Transformer backbone.
  3. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  4. """
  5. def __init__( self,
  6. pretrained=None,
  7. pretrained2d=True,
  8. # 原swin-transformer是4(然后tuple到4x4),而这里是4x4x4,多了一个时间维度
  9. patch_size=(4,4,4),
  10. in_chans=3,
  11. embed_dim=96,
  12. depths=[2, 2, 6, 2],
  13. num_heads=[3, 6, 12, 24],
  14. window_size=(2,7,7),
  15. mlp_ratio=4.,
  16. qkv_bias=True,
  17. qk_scale=None,
  18. drop_rate=0.,
  19. attn_drop_rate=0.,
  20. drop_path_rate=0.2,
  21. norm_layer=nn.LayerNorm,
  22. patch_norm=False,
  23. frozen_stages=-1,
  24. use_checkpoint=False):
  25. super().__init__()
  26. self.pretrained = pretrained
  27. self.pretrained2d = pretrained2d
  28. self.num_layers = len(depths)
  29. self.embed_dim = embed_dim
  30. self.patch_norm = patch_norm
  31. self.frozen_stages = frozen_stages
  32. self.window_size = window_size
  33. self.patch_size = patch_size
  34. """
  35. # 预处理图片序列到patch_embed,对应流程图中的Linear Embedding,
  36. # 具体做法是用3d卷积,形状变化为BCDHW -> B,C,D,Wh,Ww 即(B,96,T/4,H/4,W/4),
  37. # 要注意的是,其实在stage 1之前,即预处理完成后,已经是流程图上的T/4 × H/4 × W/4 × 96
  38. """
  39. # split image into non-overlapping patches
  40. self.patch_embed = PatchEmbed3D(
  41. patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
  42. norm_layer=norm_layer if self.patch_norm else None)
  43. """
  44. # ViT在输入会给embedding进行位置编码.实验证明位置编码效果不好
  45. # 所以Swin-T把它作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码
  46. # 这里video-Swin-T 直接去掉了位置编码
  47. # ViT会单独加上一个可学习参数,作为分类的token.
  48. # 而Swin-T则是直接做平均,输出分类,有点类似CNN最后的全局平均池化层
  49. """
  50. # 经过一层dropout,至此预处理结束
  51. self.pos_drop = nn.Dropout(p=drop_rate)
  52. """
  53. # 流程图中每个stage,即代码中的BasicLayer,由若干个block组成,
  54. # 而block的数目由depths列表中的元素决定,这里是[2,2,6,2].
  55. # 每个block就是W-MSA(window-multihead self attention)或者SW-MSA(shift window multihead self attention),
  56. # 一般有偶数个block,两种SA交替出现,比如6个block,0,2,4是W-MSA,1,3,5是SW-MSA.
  57. # 前三个stage的最后会用PatchMerging进行下采样(代码中是前三个stage每个stage最后,流程图上画的是后三个,每个stage最前面做,其实是一样的)
  58. # 操作为将临近2*2范围内的patch(即4个为一组)按通道cat起来,经过一个layernorm和linear层, 实现维度下采样、特征加倍的效果,具体见PatchMerging类注释
  59. """
  60. # stochastic depth
  61. # 随机深度,用这个来让每个stage中的block数目随机变化,达到随机深度的效果
  62. # torch.linspace()生成0到0.2的12个数构成的等差数列,如下
  63. # [0, 0.01818182, 0.03636364, 0.05454545, 0.07272727 0.09090909,
  64. # 0.10909091, 0.12727273, 0.14545455, 0.16363636, 0.18181818, 0.2]
  65. dpr = [x.item() for x in torch.linspace( 0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  66. # build layers
  67. self.layers = nn.ModuleList()
  68. for i_layer in range(self.num_layers): # 流程图中的4个stage,对应代码中4个layers
  69. layer = BasicLayer(
  70. dim= int(embed_dim * 2**i_layer), #96 x 2^n,对应流程图上的C,2C,4C,8C
  71. depth=depths[i_layer], #[2,2,6,2]
  72. num_heads=num_heads[i_layer], #[3, 6, 12, 24],
  73. window_size=window_size, # (8,7,7)
  74. mlp_ratio=mlp_ratio, # 4
  75. qkv_bias=qkv_bias, # True
  76. qk_scale=qk_scale, # None
  77. drop=drop_rate, # 0
  78. attn_drop=attn_drop_rate, # 0
  79. drop_path=dpr[ sum(depths[:i_layer]): sum(depths[:i_layer + 1])], # 依据上面算的dpr
  80. norm_layer=norm_layer, # nn.LayerNorm
  81. downsample=PatchMerging if i_layer<self.num_layers- 1 else None, # 前三个stage后要用PatchMerging下采样,
  82. use_checkpoint=use_checkpoint)
  83. self.layers.append(layer)
  84. self.num_features = int(embed_dim * 2**(self.num_layers- 1)) # 96*8
  85. # add a norm layer for each output
  86. self.norm = norm_layer(self.num_features)
  87. self._freeze_stages()

预处理

预处理图片序列到patch_embed,对应流程图中的Linear Embedding,具体做法是用3d卷积,从BCDHW->B,C,D,Wh,Ww 即(B,96,T/4,H/4,W/4),以后都假设HW为224X224,T为32,那么形状为(B,96,8,56,56),最后经过一层dropout,至此预处理结束 。要注意的是,其实在stage 1之前,即预处理完成后,已经是流程图上的T/4 × H/4 × W/4 × 96。主要函数实现:


 
 
  1. class PatchEmbed3D(nn.Module):
  2. """ Video to Patch Embedding.
  3. Args:
  4. patch_size (int): Patch token size. Default: (2,4,4).
  5. in_chans (int): Number of input video channels. Default: 3.
  6. embed_dim (int): Number of linear projection output channels. Default: 96.
  7. norm_layer (nn.Module, optional): Normalization layer. Default: None
  8. """
  9. def __init__( self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
  10. super().__init__()
  11. self.patch_size = patch_size
  12. self.in_chans = in_chans
  13. self.embed_dim = embed_dim
  14. self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  15. if norm_layer is not None:
  16. self.norm = norm_layer(embed_dim)
  17. else:
  18. self.norm = None
  19. def forward( self, x):
  20. """Forward function."""
  21. # padding
  22. _, _, D, H, W = x.size() #BCDHW
  23. #DHW正好对应patch_size[0],patch_size[1],patch_size[2],防止除不开先pad
  24. if W % self.patch_size[ 2] != 0:
  25. x = F.pad(x, ( 0, self.patch_size[ 2] - W % self.patch_size[ 2]))
  26. if H % self.patch_size[ 1] != 0:
  27. x = F.pad(x, ( 0, 0, 0, self.patch_size[ 1] - H % self.patch_size[ 1]))
  28. if D % self.patch_size[ 0] != 0:
  29. x = F.pad(x, ( 0, 0, 0, 0, 0, self.patch_size[ 0] - D % self.patch_size[ 0]))
  30. x = self.proj(x) # B C D Wh Ww, 其中D Wh Ww表示经过3d卷积后特征的大小
  31. if self.norm is not None: #默认会使用nn.LayerNorm,所以下面程序必运行
  32. D, Wh, Ww = x.size( 2), x.size( 3), x.size( 4)
  33. x = x.flatten( 2).transpose( 1, 2) #B, C, D, Wh, Ww -> B, C, D*Wh*Ww ->B,D*Wh*Ww, C
  34. #因为要层归一化,所以要拉成上面的形状,把C放在最后
  35. x = self.norm(x)
  36. x = x.transpose( 1, 2).view(- 1, self.embed_dim, D, Wh, Ww) #又拉回 B, C, D, Wh, Ww
  37. return x

ViT在输入会给embedding进行位置编码。实验证明位置编码效果不好所以Swin-T把它作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码(见下文中block部分的W-MSA)。这里video-Swin-T 直接去掉了位置编码

stage

流程图中每个stage,对应代码中的BasicLayer,由若干个block组成,而block的数目由depths列表中的元素决定,这里是[2,2,6,2]. 每个block就是W-MSA(window-multihead self attention)或者SW-MSA(shift window multihead self attention),一般有偶数个block,两种SA交替出现,比如6个block,0,2,4是W-MSA,1,3,5是SW-MSA. 前三个stage的最后会用PatchMerging进行下采样.(代码中是前三个stage每个stage最后,流程图上画的是后三个,每个stage最前面做,其实是一样的). 操作为将临近2*2范围内的patch(即4个为一组)按通道cat起来,经过一个layernorm和linear层, 实现维度下采样、特征加倍的效果,具体见PatchMerging类注释


 
 
  1. class BasicLayer(nn.Module):
  2. """ A basic Swin Transformer layer for one stage.
  3. """
  4. def __init__( self,
  5. dim, # 以第一层为例 为96
  6. depth, #以第一层为例 为2
  7. num_heads, #以第一层为例 为3
  8. window_size=(1,7,7), # (8,7,7)
  9. mlp_ratio=4.,
  10. qkv_bias=False, #true
  11. qk_scale=None,
  12. drop=0.,
  13. attn_drop=0.,
  14. drop_path=0.,#以第一层为例 为[0, 0.01818182]
  15. norm_layer=nn.LayerNorm,
  16. downsample=None, #PatchMerging
  17. use_checkpoint=False):
  18. super().__init__()
  19. self.window_size = window_size # (8,7,7)
  20. self.shift_size = tuple(i // 2 for i in window_size) #(4,3,3)
  21. self.depth = depth # 2
  22. self.use_checkpoint = use_checkpoint
  23. # build blocks
  24. self.blocks = nn.ModuleList([
  25. SwinTransformerBlock3D(
  26. dim=dim, #96
  27. num_heads=num_heads, # 3
  28. window_size=window_size,
  29. # 第一个block的shiftsize=(0,0,0),也就是W-MSA不进行shift,第2个shiftsize=(4,3,3)
  30. shift_size=( 0, 0, 0) if (i % 2 == 0) else self.shift_size,
  31. mlp_ratio=mlp_ratio,
  32. qkv_bias=qkv_bias, # true
  33. qk_scale=qk_scale, # None
  34. drop=drop,
  35. attn_drop=attn_drop,
  36. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  37. norm_layer=norm_layer,
  38. use_checkpoint=use_checkpoint,
  39. )
  40. for i in range(depth)]) # depth = 2
  41. self.downsample = downsample
  42. if self.downsample is not None:
  43. self.downsample = downsample(dim=dim, norm_layer=norm_layer)
  44. def forward( self, x):
  45. """ Forward function.
  46. """
  47. # calculate attention mask for SW-MSA
  48. B, C, D, H, W = x.shape
  49. window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)
  50. x = rearrange(x, 'b c d h w -> b d h w c')
  51. Dp = int(np.ceil(D / window_size[ 0])) * window_size[ 0] # 1*8
  52. Hp = int(np.ceil(H / window_size[ 1])) * window_size[ 1] # 56/7 *7
  53. Wp = int(np.ceil(W / window_size[ 2])) * window_size[ 2] # 56/7 *7
  54. # 计算一个attention_mask用于SW-MSA,怎么shitfed以及mask如何推导见后文
  55. attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) # (8,7,7) (0,3,3)
  56. # 以第一个stage为例,里面有2个block,第一个block进行W-MSA,第二个block进行SW-MSA
  57. # 如何W-MSA SW-MSA 见下述
  58. for blk in self.blocks:
  59. x = blk(x, attn_mask)
  60. #改变形状,把C放到最后一维度(因为PatchMerging里有layernom和全连接层)
  61. x = x.view(B, D, H, W, - 1)
  62. # 用PatchMerging 进行patch的拼接和全连接层 实现下采样
  63. if self.downsample is not None:
  64. x = self.downsample(x)
  65. x = rearrange(x, 'b d h w c -> b c d h w')
  66. return x

 


 
 
  1. class PatchMerging(nn.Module):
  2. """ Patch Merging Layer
  3. """
  4. def __init__( self, dim, norm_layer=nn.LayerNorm):
  5. super().__init__()
  6. self.dim = dim
  7. #用全连接层把C由4C->2C,因为是4个cat一起所以是4C
  8. self.reduction = nn.Linear( 4 * dim, 2 * dim, bias= False)
  9. self.norm = norm_layer( 4 * dim)
  10. def forward( self, x):
  11. """ Forward function.
  12. """
  13. B, D, H, W, C = x.shape
  14. # padding
  15. pad_input = (H % 2 == 1) or (W % 2 == 1)
  16. if pad_input:
  17. x = F.pad(x, ( 0, 0, 0, W % 2, 0, H % 2))
  18. x0 = x[:, :, 0:: 2, 0:: 2, :] # B D H/2 W/2 C
  19. x1 = x[:, :, 1:: 2, 0:: 2, :] # B D H/2 W/2 C
  20. x2 = x[:, :, 0:: 2, 1:: 2, :] # B D H/2 W/2 C
  21. x3 = x[:, :, 1:: 2, 1:: 2, :] # B D H/2 W/2 C
  22. # 每2X2个patch cat到一起
  23. x = torch.cat([x0, x1, x2, x3], - 1) # B D H/2 W/2 4*C
  24. x = self.norm(x) # 层归一化
  25. x = self.reduction(x) # 全连接层 降维
  26. return x

block

首先梳每个block的理整体脉络,和普通的transformer的encoder一样,只不过把MSA变成W-MSA或者SW-MSA


 
 
  1. class SwinTransformerBlock3D(nn.Module):
  2. """ Swin Transformer Block.
  3. """
  4. def __init__( self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
  5. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
  6. act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
  7. super().__init__()
  8. self.dim = dim
  9. self.num_heads = num_heads
  10. self.window_size = window_size
  11. self.shift_size = shift_size
  12. self.mlp_ratio = mlp_ratio
  13. self.use_checkpoint=use_checkpoint
  14. assert 0 <= self.shift_size[ 0] < self.window_size[ 0], "shift_size must in 0-window_size"
  15. assert 0 <= self.shift_size[ 1] < self.window_size[ 1], "shift_size must in 0-window_size"
  16. assert 0 <= self.shift_size[ 2] < self.window_size[ 2], "shift_size must in 0-window_size"
  17. self.norm1 = norm_layer(dim)
  18. self.attn = WindowAttention3D(
  19. dim, window_size=self.window_size, num_heads=num_heads,
  20. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
  21. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  22. self.norm2 = norm_layer(dim)
  23. mlp_hidden_dim = int(dim * mlp_ratio)
  24. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  25. def forward_part1( self, x, mask_matrix):
  26. B, D, H, W, C = x.shape
  27. # 1 先计算出当前block的window_size, 和shift_size
  28. window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
  29. # 2 经过一个layer_norm
  30. x = self.norm1(x)
  31. # pad一下特征图避免除不开
  32. # pad feature maps to multiples of window size
  33. pad_l = pad_t = pad_d0 = 0
  34. pad_d1 = (window_size[ 0] - D % window_size[ 0]) % window_size[ 0]
  35. pad_b = (window_size[ 1] - H % window_size[ 1]) % window_size[ 1]
  36. pad_r = (window_size[ 2] - W % window_size[ 2]) % window_size[ 2]
  37. x = F.pad(x, ( 0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
  38. _, Dp, Hp, Wp, _ = x.shape
  39. # 3 判断是否需要对特征图进行shift
  40. # cyclic shift
  41. if any(i > 0 for i in shift_size):
  42. shifted_x = torch.roll(x, shifts=(-shift_size[ 0], -shift_size[ 1], -shift_size[ 2]), dims=( 1, 2, 3))
  43. attn_mask = mask_matrix
  44. else:
  45. shifted_x = x
  46. attn_mask = None
  47. # 4 将特征图切成一个个的窗口(都是reshape操作)
  48. # partition windows
  49. x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C
  50. # 5 通过attn_mask是否为None判断进行W-MSA还是SW-MSA
  51. # W-MSA/SW-MSA
  52. attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C
  53. # 6 把窗口在合并回来,看成4的逆操作,同样都是reshape操作
  54. # merge windows
  55. attn_windows = attn_windows.view(- 1, *(window_size+(C,))) #(B*num_windows, window_size, window_size, C)
  56. shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C
  57. # 7 如果之前shitf过,也要还原回去
  58. # reverse cyclic shift
  59. if any(i > 0 for i in shift_size):
  60. x = torch.roll(shifted_x, shifts=(shift_size[ 0], shift_size[ 1], shift_size[ 2]), dims=( 1, 2, 3))
  61. else:
  62. x = shifted_x
  63. # 去掉pad
  64. if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
  65. x = x[:, :D, :H, :W, :].contiguous()
  66. return x
  67. def forward_part2( self, x):
  68. # 经过FFN
  69. return self.drop_path(self.mlp(self.norm2(x)))
  70. def forward( self, x, mask_matrix):
  71. """ Forward function.
  72. Args:
  73. x: Input feature, tensor size (B, D, H, W, C).
  74. mask_matrix: Attention mask for cyclic shift.
  75. """
  76. # tranformer的常规操作,包含MSA、残差连接、dropout、FFN,只不过MSA变成W-MSA或者SW-MSA
  77. shortcut = x
  78. if self.use_checkpoint:
  79. x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
  80. else:
  81. x = self.forward_part1(x, mask_matrix)
  82. x = shortcut + self.drop_path(x)
  83. if self.use_checkpoint:
  84. x = x + checkpoint.checkpoint(self.forward_part2, x)
  85. else:
  86. x = x + self.forward_part2(x)
  87. return x

W-MSA

先来看没有Shift的基于Window的注意力机制是如何做的,传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量,主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码


 
 
  1. class WindowAttention3D(nn.Module):
  2. """ Window based multi-head self attention (W-MSA) module with relative position
  3. """
  4. def __init__( self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
  5. super().__init__()
  6. self.dim = dim
  7. self.window_size = window_size # Wd, Wh, Ww
  8. self.num_heads = num_heads
  9. head_dim = dim // num_heads # 每个注意力头对应的通道数
  10. self.scale = qk_scale or head_dim ** - 0.5
  11. # define a parameter table of relative position bias
  12. # 设置一个形状为(2*Wd-1*2*(Wh-1) * 2*(Ww-1), nH)的可学习变量 ,用于后续的位置编码
  13. self.relative_position_bias_table = nn.Parameter(
  14. torch.zeros(( 2 * window_size[ 0] - 1) * ( 2 * window_size[ 1] - 1) * ( 2 * window_size[ 2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
  15. # 获取窗口内每对token的相对位置索引
  16. # get pair-wise relative position index for each token inside the window
  17. coords_d = torch.arange(self.window_size[ 0])
  18. coords_h = torch.arange(self.window_size[ 1])
  19. coords_w = torch.arange(self.window_size[ 2])
  20. coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
  21. coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
  22. #利用广播机制 ,分别在第二维 ,第一维 ,插入一个维度 ,进行广播相减 ,得到 3, Wd*Wh*Ww, Wd*Wh*Ww的张量
  23. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
  24. relative_coords = relative_coords.permute( 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
  25. #因为采取的是相减 ,所以得到的索引是从负数开始的 ,所以加上偏移量 ,让其从0开始
  26. relative_coords[:, :, 0] += self.window_size[ 0] - 1 # shift to start from 0
  27. relative_coords[:, :, 1] += self.window_size[ 1] - 1
  28. relative_coords[:, :, 2] += self.window_size[ 2] - 1
  29. # 后续我们需要将其展开成一维偏移量 而对于(1 ,2)和(2 ,1)这两个坐标 在二维上是不同的,
  30. # 但是通过将x,y坐标相加转换为一维偏移的时候,他的偏移量是相等的,所以对其做乘法以进行区分
  31. relative_coords[:, :, 0] *= ( 2 * self.window_size[ 1] - 1) * ( 2 * self.window_size[ 2] - 1)
  32. relative_coords[:, :, 1] *= ( 2 * self.window_size[ 2] - 1)
  33. #在最后一维上进行求和 ,展开成一个一维坐标 ,并注册为一个不参与网络学习的常量
  34. relative_position_index = relative_coords. sum(- 1) # Wd*Wh*Ww, Wd*Wh*Ww
  35. self.register_buffer( "relative_position_index", relative_position_index)
  36. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  37. self.attn_drop = nn.Dropout(attn_drop)
  38. self.proj = nn.Linear(dim, dim)
  39. self.proj_drop = nn.Dropout(proj_drop)
  40. # 截断正态分布初始化
  41. trunc_normal_(self.relative_position_bias_table, std= .02)
  42. self.softmax = nn.Softmax(dim=- 1)
  43. def forward( self, x, mask=None):
  44. """ Forward function.
  45. Args:
  46. x: input features with shape of (num_windows*B, N, C)
  47. mask: (0/-inf) mask with shape of (num_windows, N, N) or None
  48. """
  49. # numWindows*B, N, C ,其中N=window_size_d * window_size_h * window_size_w
  50. B_, N, C = x.shape
  51. # 然后经过self.qkv这个全连接层后进行reshape到(3, numWindows*B, num_heads,N, c//num_heads)
  52. # 3表示3个向量,刚好分配给q,k,v,
  53. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute( 2, 0, 3, 1, 4)
  54. q, k, v = qkv[ 0], qkv[ 1], qkv[ 2] # B_, nH, N, C
  55. # 根据公式,对q乘以一个scale缩放系数,
  56. # 然后与k(为了满足矩阵乘要求,需要将最后两个维度调换)进行相乘.
  57. # 得(numWindows*B, num_heads, N, N)的attn张量
  58. q = q * self.scale # selfattention公式里的根号下dk
  59. attn = q @ k.transpose(- 2, - 1)
  60. # 之前我们针对位置编码设置了个形状为(2*Wd-1*2*(Wh-1) * 2*(Ww-1), numHeads)的可学习变量.
  61. # 我们用计算得到的相对编码位置索引self.relative_position_index选取,
  62. # 得到形状为(nH, Wd*Wh*Ww, Wd*Wh*Ww)的编码,加到attn张量上
  63. relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(- 1)].reshape(
  64. N, N, - 1) # Wd*Wh*Ww,Wd*Wh*Ww,nH
  65. relative_position_bias = relative_position_bias.permute( 2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww
  66. attn = attn + relative_position_bias.unsqueeze( 0) # B_, nH, N, N
  67. # 剩下就是跟transformer一样的softmax,dropout,与V矩阵乘,再经过一层全连接层和dropout
  68. if mask is not None:
  69. # mask.shape = nW, N, N, 其中N = Wd*Wh*Ww
  70. nW = mask.shape[ 0]
  71. # 将mask加到attention的计算结果再进行softmax,
  72. # 由于mask的值设置为-100,softmax后就会忽略掉对应的值,从而达到mask的效果
  73. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 1).unsqueeze( 0)
  74. attn = attn.view(- 1, self.num_heads, N, N)
  75. attn = self.softmax(attn)
  76. else:
  77. attn = self.softmax(attn)
  78. attn = self.attn_drop(attn)
  79. x = (attn @ v).transpose( 1, 2).reshape(B_, N, C)
  80. x = self.proj(x)
  81. x = self.proj_drop(x)
  82. return x

SW-MSA

SW-MSA,这里比较复杂,是swinTransformer精髓之处

首先理解下如何shitfed的

为什么cyclic shift? 图一可以看出,partition后Windows的数量变多了,从4个变成了9个大小不一致的窗,我们希望每个window是单独做attention的,for循环做显然不好。其实在代码里,是通过对特征图移位实现的,把切成边角料的小块又拼在一起,把A拼接到右下角,C向下平移,B向右平移,最后组合成4个大小一致的window。但这又引入一个问题, 例如右下角的窗口由好几个小窗组成,上面说到了我们希望每个window是单独做attention的,所以引入mask,保证A窗不与C窗进行attention。

代码里对特征图移位是通过torch.roll来实现的,下面是示意图

为什么mask?

 我们给window编号(下左图),然后按上面讲的shitf,得到右图

 我们有提到过,希望每个窗口内的内容单独做注意力机制,也就是说希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。如下图,只取有颜色的部分而忽略灰色部分,这就用到mask,让灰色部分的值为-100,softmax后忽略掉对应的值,有色部分为0

如何计算得到mask的?

首先上代码,slice表示切片操作,我们以二维为例讲解,先不考虑d维度。所以h和w都是在(0,-7),(-7,-3),(-3,None)切片循环的,然后给不同切片的位置填上标号


 
 
  1. def compute_mask( D, H, W, window_size, shift_size, device):
  2. img_mask = torch.zeros(( 1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1
  3. cnt = 0
  4. # 切片操作,假设不看d维度,见详解图
  5. for d in slice(-window_size[ 0]), slice(-window_size[ 0], -shift_size[ 0]), slice(-shift_size[ 0], None):
  6. for h in slice(-window_size[ 1]), slice(-window_size[ 1], -shift_size[ 1]), slice(-shift_size[ 1], None):
  7. for w in slice(-window_size[ 2]), slice(-window_size[ 2], -shift_size[ 2]), slice(-shift_size[ 2], None):
  8. img_mask[:, d, h, w, :] = cnt
  9. cnt += 1
  10. mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1
  11. mask_windows = mask_windows.squeeze(- 1) # nW, ws[0]*ws[1]*ws[2]
  12. # nW, 1, ws[0]*ws[1]*ws[2] - nW, ws[0]*ws[1]*ws[2],1会触发广播机制,将维度不匹配维度中维度为1的复制然后匹配上
  13. attn_mask = mask_windows.unsqueeze( 1) - mask_windows.unsqueeze( 2)
  14. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(- 100.0)).masked_fill(attn_mask == 0, float( 0.0))
  15. return attn_mask

按照上述填写编号的代码,假设窗口大小M=7,图片H=2M,W=2M,shiftwindow_size=M//2=3,我们的到如下图的mask,1表示那一部分区域里面值全填1。我们看这个图和上面讲的shitf后的窗口其实是一样的,有4个window,其中3个window是由不同小窗口组成的,我们要进行mask,

 按照代码接下来进行window_partition,使形状变为(B*num_windows, window_size*window_size, C),即(nW,M^2,1),window_partition函数内全是reshape操作,这里不展开。

然后squeeze去掉最后一个维度,然后做了一个减法 mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2),也就是(nW,1,M^2)-(nW,M^2,1),此时会触发广播机制,将维度不匹配维度中维度为1的复制然后匹配上,以第二个window为例,

首先M^2的向量画出来就是

 (nW,1,M^2)会把原来1行M^2列的向量复制M^2行,得到下图A

(nW,M^2,1)会把原来M^2行1列的向量复制M^2列,得到下图B

 

 然后A-B,每一个小块就变成了下图,然后把非0的地方填充-100,在后续代码中会忽略这些位置的值来实现mask

 最后自己可以在脑海中想象下加上D维度之后3维的操作,其实是一样的。

SW-MSA前向传播中不同的代码地方为


 
 
  1. if mask is not None:
  2. # mask.shape = nW, N, N, 其中N = Wd*Wh*Ww
  3. nW = mask.shape[ 0]
  4. # 将mask加到attention的计算结果再进行softmax,
  5. # 由于mask的值设置为-100,softmax后就会忽略掉对应的值,从而达到mask的效果
  6. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 1).unsqueeze( 0)
  7. attn = attn.view(- 1, self.num_heads, N, N)
  8. attn = self.softmax(attn)
  9. else:
  10. attn = self.softmax(attn)

就是比W-MSA在attn结果上多加了一个mask的值,使不想要的位置的值无限小,softmax后就会被忽略,从而达到mask的效果。

参考文献:

2021-Swin Transformer Attention机制的详细推导_小毛激励我好好学习的博客-CSDN博客

图解swin transformer【附代码解读】-技术圈 (proginn.com)

;