Bootstrap

swin-transformer的理解以及tensorflow2的实现

实现了下面的代码,也就实现了 Swin Transformer 的 swin-block 部分。我觉得原理还挺简单的,但实现代码是真的有点麻烦。shift 那里还好,主要是 mask 和 relative_pos 的计算,其中大量的 reshape 和 transpose 很容易把人绕晕。

custom_function.py

from tensorflow.keras.layers import Input
import tensorflow as tf


def window_partition(x,window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: 4-D tensor whose static shape is unpacked as (batch, H, W, C).
            H and W are floor-divided by ``window_size``.
        window_size: Side length of each square window.

    Returns:
        Tensor of shape (-1, window_size, window_size, C); the windows of
        all samples are stacked along the leading axis.
    """
    _, height, width, channels = x.shape.as_list()
    rows = height // window_size
    cols = width // window_size

    # Split each spatial axis into (number of windows, window size).
    grid = tf.reshape(
        x,
        shape=[-1, rows, window_size, cols, window_size, channels],
    )
    # Bring the two window-count axes together:
    # (B, rows, w, cols, w, C) -> (B, rows, cols, w, w, C)
    grid = tf.transpose(grid, [0, 1, 3, 2, 4, 5])
    return tf.reshape(grid, shape=[-1, window_size, window_size, channels])

def window_reverse(windows,window_size,H,W,C):
    """Reassemble square windows into a full feature map.

    Args:
        windows: Tensor that can be reshaped to
            (-1, H // window_size, W // window_size, window_size, window_size, C).
        window_size: Side length of each square window.
        H: Height of the reassembled feature map.
        W: Width of the reassembled feature map.
        C: Channel count.

    Returns:
        Tensor of shape (-1, H, W, C).
    """
    rows = H // window_size
    cols = W // window_size

    grid = tf.reshape(
        windows,
        shape=[-1, rows, cols, window_size, window_size, C],
    )
    # Interleave window index and in-window offset per spatial axis:
    # (B, rows, cols, w, w, C) -> (B, rows, w, cols, w, C)
    grid = tf.transpose(grid, [0, 1, 3, 2, 4, 5])
    return tf.reshape(grid, shape=[-1, H, W, C])

def drop_path(inputs,drop_prob,is_training):
    """Randomly zero whole samples of a batch and rescale the survivors.

    Args:
        inputs: Tensor whose first axis indexes the samples.
        drop_prob: Probability of dropping a sample. ``None`` and ``0`` both
            disable dropping.
        is_training: Falsy values (including ``None``) disable dropping.

    Returns:
        ``inputs`` itself when dropping is disabled; otherwise
        ``inputs / keep_prob`` multiplied by a per-sample 0/1 mask.
    """
    # Treat a missing rate like a zero rate: DropPathLayer defaults to
    # drop_prob=None, which used to fail on `1.0 - None`.
    if (not is_training) or drop_prob is None or drop_prob == 0.:
        return inputs
    keep_prob = 1.0 - drop_prob

    # One random draw per sample, broadcast over every other axis.  The rank
    # comes from the static shape because len() of the symbolic tensor
    # returned by tf.shape() is not defined when tracing a graph.
    rank = len(inputs.shape)
    shape = (tf.shape(inputs)[0],) + (1,) * (rank - 1)

    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + tf.random.uniform(shape, dtype=inputs.dtype)
    binary_tensor = tf.floor(random_tensor)

    # Dividing by keep_prob keeps the expected value of the output unchanged.
    return tf.math.divide(inputs, keep_prob) * binary_tensor


# if __name__ == '__main__':
#     inputs = Input(shape=[56,56,96],batch_size=2)
#     windows = window_partition(inputs,7)
#     print(windows.shape)
#     x = window_reverse(windows,7,56,56,96)
#     print(x.shape)

custom_layer.py

from tensorflow.keras.layers import (Layer,Input,LayerNormalization,
                                    Dense,Dropout,Conv2D,)
from tensorflow.keras.activations import gelu
import tensorflow as tf
import numpy as np
from custom_function import (drop_path)
from custom_function import window_partition, window_reverse


class MLPLayer(Layer):
    """Two-layer feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

    Args:
        hidden_features: Width of the first Dense layer.
        drop_rate: Rate of the Dropout layer applied after each Dense layer.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.

    The second Dense layer is created in ``build`` so that its width equals
    the last dimension of the input.
    """

    def __init__(self,hidden_features=None,drop_rate=0.,**kwargs):
        super(MLPLayer,self).__init__(**kwargs)

        self.hidden_features = hidden_features
        self.drop_rate = drop_rate

        self.fc1 = Dense(self.hidden_features)
        self.drop = Dropout(self.drop_rate)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super(MLPLayer,self).get_config()
        # Only constructor arguments belong here.  `out_features` is derived
        # from the input shape in build(): it is not accepted by __init__ and
        # does not exist before the layer is built.
        config.update({"hidden_features":self.hidden_features,
                       "drop_rate":self.drop_rate})
        return config

    def build(self, input_shape):
        """Create the output projection with the same width as the input."""
        self.out_features = input_shape[-1]
        self.fc2 = Dense(self.out_features)
        self.built = True

    def call(self,inputs):
        """Apply the feed-forward block to ``inputs``."""
        x = self.fc1(inputs)
        x = gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x

class WindowAttentionLayer(Layer):
    """Multi-head self-attention over windows with a learned relative position bias.

    Args:
        dim: Channel size of the input tokens; also the output size.
        window_size: Pair ``(window_height, window_width)``.
        num_heads: Number of attention heads; ``dim // num_heads`` is the
            per-head size.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_scale: Optional query scaling factor.  A falsy value selects
            ``head_dim ** -0.5``.
        attn_drop_rate: Dropout rate applied to the attention weights.
        proj_drop_rate: Dropout rate applied after the output projection.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self,dim,window_size,num_heads,qkv_bias=True,
                qk_scale=None,attn_drop_rate=0.,
                proj_drop_rate=0.,**kwargs):
        super(WindowAttentionLayer,self).__init__(**kwargs)

        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim//num_heads
        # Keep the raw argument so get_config() can round-trip it.
        self.qk_scale = qk_scale
        self.scale = qk_scale or (self.head_dim ** (-0.5))
        self.qkv_bias = qkv_bias
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate

        self.qkv = Dense(self.dim*3,use_bias=self.qkv_bias)
        self.attn_drop = Dropout(self.attn_drop_rate)
        self.proj = Dense(self.dim)
        self.proj_drop = Dropout(self.proj_drop_rate)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super(WindowAttentionLayer,self).get_config()
        # Keys must match the __init__ parameter names exactly; derived
        # values (head_dim, scale) are recomputed by the constructor.
        config.update({"dim":self.dim,
                       "window_size":self.window_size,
                       "num_heads":self.num_heads,
                       "qkv_bias":self.qkv_bias,
                       "qk_scale":self.qk_scale,
                       "attn_drop_rate":self.attn_drop_rate,
                       "proj_drop_rate":self.proj_drop_rate})
        return config

    def build(self, input_shape):
        """Create the bias table and the index that maps token pairs into it."""
        # One learnable bias per head for every possible (dy, dx) offset
        # between two positions of the same window.
        self.relative_position_bias_table = self.add_weight(
            shape=[(2*self.window_size[0]-1)*(2*self.window_size[1]-1),
                    self.num_heads],
            initializer=tf.initializers.Zeros(),
            trainable=True
        )

        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        # (2, Wh, Ww): row and column coordinate of every window position.
        coords = np.stack(np.meshgrid(coords_h,coords_w,indexing='ij'))
        coords_flatten = coords.reshape(2,-1)
        # (2, N, N): pairwise coordinate differences, N = Wh * Ww.
        relative_coords = coords_flatten[:,:,None] - coords_flatten[:,None,:]
        relative_coords = relative_coords.transpose([1,2,0])
        # Shift both offsets so they start at 0 ...
        relative_coords[:,:,0] +=self.window_size[0] - 1
        relative_coords[:,:,1] +=self.window_size[1] - 1
        # ... then flatten (dy, dx) into a single row index of the table.
        relative_coords[:,:,0] *= 2*self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).astype(np.int64)
        self.relative_position_index = tf.Variable(
            initial_value=tf.convert_to_tensor(relative_position_index),
            trainable=False
        )
        self.built = True

    def _split_heads(self,t,N):
        """Reshape (batch, N, dim) to (batch, num_heads, N, head_dim)."""
        t = tf.reshape(t,shape=[-1,N,self.num_heads,self.head_dim])
        return tf.transpose(t,[0,2,1,3])

    def call(self,x,mask=None):
        """Run windowed self-attention.

        Args:
            x: Tensor with static shape (batch, N, C), one row per window.
            mask: Optional additive mask of shape (nW, N, N).  The leading
                axis of ``x`` is regrouped as (-1, nW) so that each window
                receives its own mask.

        Returns:
            Tensor reshaped to (-1, N, C) and passed through the output
            projection and its dropout.
        """
        _,N,C = x.shape.as_list()
        qkv = self.qkv(x)
        q,k,v = tf.split(qkv,3,axis=-1)
        q = self._split_heads(q,N)
        k = self._split_heads(k,N)
        v = self._split_heads(v,N)

        q = self.scale * q
        # (batch, heads, N, N) attention logits.
        attn = tf.matmul(q,k,transpose_b=True)

        # Look up one bias per token pair and head: (N*N, heads) -> (heads, N, N).
        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            tf.reshape(self.relative_position_index,shape=[-1])
        )
        relative_position_bias = tf.reshape(relative_position_bias,
                    shape=[self.window_size[0]*self.window_size[1],
                           self.window_size[0]*self.window_size[1],
                           -1])
        relative_position_bias = tf.transpose(relative_position_bias,
                                            [2,0,1])
        attn = attn + tf.expand_dims(relative_position_bias,axis=0)

        if mask is not None:
            mask = tf.convert_to_tensor(mask)
            nW = mask.shape[0]
            # Broadcast the mask as (1, nW, 1, N, N) over batch and heads.
            mask = tf.expand_dims(tf.expand_dims(mask,axis=1),axis=0)
            attn = tf.reshape(attn,shape=[-1,nW,self.num_heads,N,N]) + \
                    tf.cast(mask,attn.dtype)
            attn = tf.reshape(attn,shape=[-1,self.num_heads,N,N])

        attn = tf.nn.softmax(attn,axis=-1)
        attn = self.attn_drop(attn)

        # (batch, heads, N, head_dim) -> (batch, N, heads, head_dim) -> (batch, N, C)
        x = tf.transpose(tf.matmul(attn,v),[0,2,1,3])
        x = tf.reshape(x,shape=[-1,N,C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
class DropPathLayer(Layer):
    """Keras layer wrapper around ``drop_path``.

    Args:
        drop_prob: Probability of dropping a sample during training.
            ``None`` (the default) disables dropping.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self,drop_prob=None,**kwargs):
        super(DropPathLayer,self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self,x,training=None):
        """Apply ``drop_path`` to ``x`` using the given training flag."""
        # The stored value is left untouched so get_config() round-trips it;
        # only the value handed to drop_path is normalised, because a None
        # rate would otherwise fail on `1.0 - None` in training mode.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x,drop_prob,training)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super(DropPathLayer,self).get_config()
        config.update({"drop_prob":self.drop_prob})
        return config


class SwinTransformerBlockLayer(Layer):
    """One Swin Transformer block: (shifted) window attention followed by an MLP.

    Args:
        dim: Channel size of the input tokens.
        input_resolution: Pair ``(H, W)``; the input must hold ``H * W`` tokens.
        num_heads: Number of attention heads.
        window_size: Side length of the square attention window.
        shift_size: Cyclic shift applied before windowing; ``0`` disables it.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias.
        qk_scale: Optional query scaling factor passed to the attention layer.
        drop_rate: Dropout rate for the attention output projection and the MLP.
        attn_drop_rate: Dropout rate for the attention weights.
        drop_path_prob: Drop probability of the two residual branches.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.

    When the smaller side of ``input_resolution`` does not exceed
    ``window_size``, the window is shrunk to that side and shifting is
    disabled.
    """

    def __init__(self,dim,input_resolution,num_heads,window_size=7,
                shift_size=0,mlp_ratio=4.,qkv_bias=True,qk_scale=None,
                drop_rate=0.,attn_drop_rate=0.,drop_path_prob=0.,
                **kwargs):
        super(SwinTransformerBlockLayer,self).__init__(**kwargs)

        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_prob = drop_path_prob

        # A single window already covers the whole map: no partition, no shift.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        assert 0<=self.shift_size<self.window_size,'偏移必须在0-window_size之间'

        self.norm1 = LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttentionLayer(self.dim,(self.window_size,self.window_size),
                                        self.num_heads,self.qkv_bias,self.qk_scale,
                                        self.attn_drop_rate,self.drop_rate)
        self.drop_path = DropPathLayer(self.drop_path_prob)
        self.norm2 = LayerNormalization(epsilon=1e-5)
        mlp_hidden_dim = int(dim*self.mlp_ratio)
        self.mlp = MLPLayer(hidden_features=mlp_hidden_dim,
                            drop_rate=self.drop_rate)

    def build(self,input_shape):
        """Pre-compute the attention mask used when the windows are shifted."""
        if self.shift_size > 0:
            H,W = self.input_resolution
            img_mask = np.zeros([1,H,W,1])
            # Each axis is cut into three bands: the part untouched by the
            # cyclic shift, the part that stays in the last window, and the
            # part that wraps around from the opposite border.
            h_slices = (slice(0,-self.window_size),
                        slice(-self.window_size,-self.shift_size),
                        slice(-self.shift_size,None))
            w_slices = (slice(0,-self.window_size),
                        slice(-self.window_size,-self.shift_size),
                        slice(-self.shift_size,None))
            # Give each of the 3 x 3 regions its own label, so positions
            # from different regions can be told apart inside a window.
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:,h,w,:] = cnt
                    cnt += 1

            img_mask = tf.convert_to_tensor(img_mask)
            mask_windows = window_partition(img_mask,self.window_size)
            # (nW, w, w, 1) -> (nW, w*w)
            mask_windows = tf.reshape(mask_windows,shape=[
                -1,self.window_size*self.window_size
            ])
            # (nW, 1, N) - (nW, N, 1) -> (nW, N, N); non-zero where the two
            # positions carry different region labels.
            attn_mask = tf.expand_dims(mask_windows,axis=1) - tf.expand_dims(mask_windows,axis=2)

            # Large negative logit for cross-region pairs, 0 elsewhere.  Both
            # branches share the mask's dtype, so no mixed-dtype select occurs.
            attn_mask = tf.where(attn_mask!=0,
                                 tf.constant(-100.0,dtype=attn_mask.dtype),
                                 tf.zeros_like(attn_mask))
            self.attn_mask = tf.Variable(initial_value=attn_mask,
                            trainable=False)
        else:
            self.attn_mask = None
        self.built = True

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super(SwinTransformerBlockLayer,self).get_config()
        config.update({"dim":self.dim,
                       "input_resolution":self.input_resolution,
                       "num_heads":self.num_heads,
                       "window_size":self.window_size,
                       "shift_size":self.shift_size,
                       "mlp_ratio":self.mlp_ratio,
                       "qkv_bias":self.qkv_bias,
                       "qk_scale":self.qk_scale,
                       "drop_rate":self.drop_rate,
                       "attn_drop_rate":self.attn_drop_rate,
                       "drop_path_prob":self.drop_path_prob,
                       })
        return config

    def call(self,x):
        """Apply the block to ``x`` of static shape (batch, H*W, C).

        Returns:
            Tensor reshaped to (-1, H*W, C).
        """
        H,W = self.input_resolution
        _,L,C = x.shape.as_list()
        assert L == H*W, 'input feature has wrong size.'

        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x,shape=[-1,H,W,C])

        # Cyclic shift towards the top-left corner.
        if self.shift_size > 0:
            shifted_x = tf.roll(x,shift=[-self.shift_size,-self.shift_size],
                                axis=[1,2])
        else:
            shifted_x = x

        # Partition into windows and flatten each window to a token sequence.
        x_windows = window_partition(shifted_x,self.window_size)
        x_windows = tf.reshape(x_windows,
                        shape=[-1,self.window_size*self.window_size,C])

        # W-MSA (mask is None) or SW-MSA (mask built in build()).
        attn_windows = self.attn(x_windows,mask=self.attn_mask)

        # Merge the windows back into a feature map.
        attn_windows = tf.reshape(attn_windows,
                                shape=[-1,self.window_size,self.window_size,C])
        shifted_x = window_reverse(attn_windows,self.window_size,H,W,C)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = tf.roll(shifted_x,
                        shift=[self.shift_size,self.shift_size],
                        axis=[1,2])
        else:
            x = shifted_x

        x = tf.reshape(x,shape=[-1,H*W,C])

        # Residual connections around attention and around the MLP.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

class PatchMergingLayer(Layer):
    """Halve the spatial resolution by merging each 2x2 patch neighbourhood.

    Args:
        input_resolution: Pair ``(H, W)``; the input must hold ``H * W`` tokens.
        dim: Channel size used to size the output projection (``2 * dim``).
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self,input_resolution,dim,**kwargs):
        super().__init__(**kwargs)

        self.input_resolution = input_resolution
        self.dim = dim

        self.norm = LayerNormalization(epsilon=1e-5)
        self.reduction = Dense(2*self.dim,use_bias=False)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config["input_resolution"] = self.input_resolution
        config["dim"] = self.dim
        return config

    def call(self,x):
        """Merge patches of ``x`` (static shape (batch, H*W, C)).

        Returns:
            Tensor with (H // 2) * (W // 2) tokens, normalised and projected
            to ``2 * dim`` channels.
        """
        H,W = self.input_resolution
        _,L,C = x.shape.as_list()
        assert L==H*W, 'input feature has wrong size'
        assert H%2==0 and W%2==0, f'x size ({H}*{W}) are not even.'

        grid = tf.reshape(x,shape=[-1,H,W,C])

        # Take the four members of every 2x2 neighbourhood, each of shape
        # (B, H/2, W/2, C), in the order (row, col) = (0,0), (1,0), (0,1), (1,1).
        corners = [grid[:, row::2, col::2, :]
                   for col in (0, 1)
                   for row in (0, 1)]
        merged = tf.concat(corners, axis=-1)
        merged = tf.reshape(merged, shape=[-1, (H // 2) * (W // 2), 4 * C])

        return self.reduction(self.norm(merged))
    

class PatchEmbeddingLayer(Layer):
    """Turn an image into a sequence of patch embeddings with a strided convolution.

    Args:
        img_size: Pair ``(height, width)`` the input images must have.
        patch_size: Pair ``(patch_height, patch_width)``; used as both the
            kernel size and the stride of the convolution.
        embed_dims: Number of output channels per patch.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self,img_size=[224,224],patch_size=[4,4],
                embed_dims=96,**kwargs):
        super(PatchEmbeddingLayer,self).__init__(**kwargs)

        # Copy the sequences: the defaults are shared list objects, so
        # storing them directly would let a mutation on one instance leak
        # into every later instance.
        self.img_size = list(img_size)
        self.patch_size = list(patch_size)
        self.embed_dims = embed_dims

        patchs_resolution = [self.img_size[0]//self.patch_size[0],
                            self.img_size[1]//self.patch_size[1]]

        self.patchs_resolution = patchs_resolution
        self.num_patches = patchs_resolution[0] * patchs_resolution[1]

        self.proj = Conv2D(self.embed_dims,self.patch_size,
                            self.patch_size)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super(PatchEmbeddingLayer,self).get_config()
        # patchs_resolution and num_patches are derived in __init__ and are
        # not constructor arguments, so they must not appear in the config.
        config.update({"img_size":self.img_size,
                       "patch_size":self.patch_size,
                       "embed_dims":self.embed_dims})
        return config

    def call(self,x):
        """Embed ``x`` of static shape (batch, H, W, C).

        Returns:
            Tensor of shape (-1, h*w, c), where (h, w, c) is the static
            output shape of the convolution.
        """
        _,H,W,C = x.shape.as_list()
        assert H==self.img_size[0] and W==self.img_size[1], \
            f'input img size ({H}*{W}) does not match model ({self.img_size[0]}*{self.img_size[1]}).'

        x = self.proj(x)
        _,h,w,c = x.shape.as_list()
        # Flatten the patch grid into a token sequence.
        x = tf.reshape(x,shape=[-1,h*w,c])

        return x


if __name__ == '__main__':
    # Shape smoke test: patch embedding, one plain and one shifted Swin
    # block, then patch merging.
    inputs = Input(shape=[224,224,3])

    # Patch embedding with the layer's default settings.
    patch_embed = PatchEmbeddingLayer()
    x = patch_embed(inputs)
    print(f'patch_embedding之后的输出大小(b,56*56,96): {x.shape}')

    # A pair of Swin Transformer blocks: first without shift, then with a
    # shift of 3.  Both use 3 heads, a window of 7 and an MLP ratio of 4.
    plain_block = SwinTransformerBlockLayer(
        dim=96,
        input_resolution=[224//4,224//4],
        num_heads=3,
        window_size=7,
        shift_size=0,
        mlp_ratio=4,
    )
    x = plain_block(x)
    print(f'经过一个没有shift的STB之后的输出大小(b,56*56,96): {x.shape}')

    shifted_block = SwinTransformerBlockLayer(
        dim=96,
        input_resolution=[224//4,224//4],
        num_heads=3,
        window_size=7,
        shift_size=3,
        mlp_ratio=4,
    )
    x = shifted_block(x)
    print(f'经过一个经过shift的STB之后的输出大小(b,56*56,96): {x.shape}')

    # Patch merging halves height and width and doubles the channels.
    patch_merge = PatchMergingLayer(input_resolution=[224//4,224//4],dim=96)
    x = patch_merge(x)
    print(f'经过patch_merging之后的输出大小(b,28*28,96*2): {x.shape}')

patch_embedding之后的输出大小(b,56*56,96): (None, 3136, 96)
经过一个没有shift的STB之后的输出大小(b,56*56,96): (None, 3136, 96)
经过一个经过shift的STB之后的输出大小(b,56*56,96): (None, 3136, 96)
经过patch_merging之后的输出大小(b,28*28,96*2): (None, 784, 192)

;