Bootstrap

Transformer架构中基于窗口的的自注意力机制W-MSA和滑动窗口自注意力机制SW-MSA的实现

在Transformer架构中,自注意力机制(Self-Attention Mechanism)是实现序列建模和上下文信息捕捉的核心机制。自注意力机制允许模型在处理序列数据时根据序列中各个位置的信息来动态地分配注意力。

在基于窗口的自注意力机制W-MSA中,通过设置一个固定大小的窗口来约束注意机制的计算范围,只计算窗口内的位置之间的相似度,并且将注意分数归一化得到权重。这样可以减少计算量,并保持局部依赖性。W-MSA可以有效地处理长序列,但窗口的大小会直接影响模型性能。

而滑动窗口自注意力机制SW-MSA是对W-MSA的改进。它引入了一个滑动窗口,通过多次计算W-MSA来覆盖整个序列。这样可以提高模型的感知范围,使得模型能够捕捉到整个序列中的全局依赖性。SW-MSA在处理长序列时可以在一定程度上解决窗口大小的限制问题。

W-MSA和SW-MSA是自注意力机制在Transformer架构中的两种变体。它们可以根据具体任务和序列长度的不同选择合适的方式来实现自注意力机制,以提高模型的性能和表达能力。

1.窗口的划分:

def window_partition(x, window_size):  # 窗口划分
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape  # 特征图的形状分别代表,一次处理的样本数量,宽,高,通道
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)  # 窗口划分HW/MM,window_size=M
    # x.view()用于重新塑造张量 x 的形状而不改变其底层数据
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    # contiguous():这是一个用于确保张量内存中元素连续排列的操作
    return windows

2.经过移动的窗口的复位

def window_reverse(windows, window_size, H, W):  # 窗口反转,便会原来的4x4的窗口
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))  # 一次性处理的patch的个数
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)  # 将向右下角移动的窗口归位
    # 将B分割成(H // window_size) x (W // window_size)个小窗口,每个小窗口大小为window_size x window_size,并按照一定的顺序重新排列它们
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)  # 维度变换之后再调整大小
    return x

窗口自注意力机制W-MSA

class WindowAttention(nn.Module):  # (SW-MSA)shift window multi self attention 滑动窗口注意力机制
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim  # 通道个数
        self.window_size = window_size  # 窗口的宽和高
        self.num_heads = num_heads  # 注意力的头目数量
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias,相对位置偏差
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww,个二维张量堆叠在一起,得到一个三维张量
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww计算点之间的相对位置嵌入
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0,右移window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1  # 所有元素向下移动window_size[1] - 1
        # 所有元素都向右下方移动 self.window_size[0]-1 个单位
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 将相对中心点的宽度偏移值乘以 (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww,-1 表示对张量的最后一个维度进行求和
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # 三个 Q、K、V 分别对应着输出张量的前 dim,中间 dim 和后 dim 部分
        self.attn_drop = nn.Dropout(attn_drop)  # 随机令一些输入为0
        self.proj = nn.Linear(dim, dim)
        # 一个线性层,用于将输入张量中每个位置的特征向量映射到一个更高维度的输出张量。在这个例子中,输入张量的维度是 dim,通过线性变换后,输出张量的维度仍然是 dim

        self.proj_drop = nn.Dropout(proj_drop)  # 随机丢弃操作

        trunc_normal_(self.relative_position_bias_table, std=.02)  # 随机生成一些数值,并将其中超过 0.02 倍标准差的数值进行截断
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # 获取q,k,v的值
        # reshape()输出张量重新整形为一个五维张量
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # Q*K

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # Q*K+B

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)  # 归一化
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # softmax[(Q*K+B)]*V
        x = self.proj(x)  # 线性映射
        x = self.proj_drop(x)  # 随机丢弃一些输入
        return x

    def extra_repr(self) -> str:  # 返回一个字符串,描述了一些模块的关键参数信息
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
        # 阐述窗口的通道个数,窗口大小以及多头自注意力机制中的头的数量

    def flops(self, N):  # 计算模型中所需的浮点运算量(flops)的函数。它根据给定的窗口大小 N 来计算相应的 flops
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

4.基于滑动窗口SW-MSA自注意力机制的实现:

 

class SwinTransformerBlock(nn.Module):  # (SW-MSA)滑动窗口注意力机制的试下,RSTB
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim  # 输入通道的数量
        self.input_resolution = input_resolution  # 输入特征小图的分辨率
        self.num_heads = num_heads  # 多头的个数
        self.window_size = window_size  # 窗口大小,多少个patch
        self.shift_size = shift_size  # SW-MSA滑动窗口滑动大小
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:  # 如果窗口大小大于输入图片的大小,就不进行窗口划分
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)  # 窗口大小等于图像大小
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)  # LayerNorm1
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)  # MSA多头自注意力的计算

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)  # LayerNorm2
        mlp_hidden_dim = int(dim * mlp_ratio)  # 隐藏的输入通道的数量
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)  # MLP

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)  # 计算注意力机制中的掩码
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1,初始化推向掩码
        h_slices = (slice(0, -self.window_size),  # slice 函数来生成用于切片操作的索引,从位置0开始到位置-self.window_size结束的切片操作
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))  # 高度的大小
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))  # 从位置-self.shift_size开始到最后位置结束的切片操作
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1  # 掩码的设置,需要的地方设置为1,不需要的地方设置为0

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1,掩码的窗口划分
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 维度调整
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 创建注意力机制中的掩码张量,以便在自注意力或其他形式的掩码注意力中屏蔽掉不需要的位置
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        # masked_fill 函数的作用是将输入张量中满足掩码条件的元素替换为指定的值

        return attn_mask

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)  # 定义第一个LayerNorm层
        x = x.view(B, H, W, C)

        # cyclic shift实现窗口的滑动
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # 窗口滚动,于输入张量 x,shifts 参数用于指定每个维度需要滚动的步数,dims 参数用于指定需要滚动的维度
        else:
            shifted_x = x

        # partition windows窗口划分
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C滚动后的窗口划分
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C,窗口注意力机制W-MSA
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))  # 计算窗口注意力W-MSA

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C,滑动窗口复位

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))  # 滚动的方式复位
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)  # 第一个残差连接
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # MLP加上残差连接

        return x

    def extra_repr(self) -> str:  # 打印模型细节信息
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):  # 计算量和模型复杂度
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

;