Bootstrap

2021年12月提出的一种全局注意力机制方法 | 即插即用

Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions

paper:https://arxiv.org/pdf/2112.05561.pdf

摘要

        人们研究了多种注意机制,以提高各种计算机视觉任务的性能。然而,先前的方法忽略了保留通道和空间方面的信息以增强跨维度交互的意义。因此,论文提出了一种全局注意机制,通过减少信息的损失和提高全局特征的交互,提高深度神经网络的性能。引入了具有多层感知器的三维排列的通道注意力和一个卷积空间注意力模块。对CIFAR-100和ImageNet-1K上提出的图像分类任务机制的评估表明,该方法在ResNet和轻量级移动网络上都稳定地优于最近的几种注意机制。

论文背景     

        近年来,注意机制的提高在多种应用中,并引起了研究兴趣。使用编码器-解码器残余注意模块对特征图进行细化,以获得更好的性能。SENet是第一个使用通道注意和通道级特征融合来抑制不重要的通道。然而,它在抑制不重要的像素方面效率较低。后来的CBAM注意机制同时考虑了空间维度和通道维度。然而,它们都忽略了通道-空间的相互作用,从而失去了跨维信息。考虑到跨维度交互作用的重要性,TAM通过利用每对三维通道、空间宽度和空间高度之间的注意权重来提高效率。然而,注意操作仍然每次应用于两个维度,而不是所有三个维度。为了放大跨维度的交互作用。本文研究注意机制跨越空间通道的维度,提出了一种“全局”注意机制,保留信息以放大“全局”跨维度的相互作用,在所有三个维度上捕捉重要特征的注意力机制。命名为全局注意机制(GAM)。

论文主要思想

        设计一种注意力机制,以减少信息损失和放大全局维度交互特征。通过采用了CBAM中的顺序通道-空间注意机制,并重新设计了子模块。

        给定输入特征图F_{1} , 中特征图 F_{2}和输出特征图F_{3}定义为:

F_{2}=M_{c}(F_{1})\bigotimes F_{1}

F_{3}=M_{s}(F_{2})\bigotimes F_{2}

其中M_{c}M_{s}分别表示通道注意力和空间注意力特征图谱;\bigotimes表示按元素进行乘法操作

通道注意子模块使用三维排列来保留三维信息。然后,它用一个两层的MLP(多层感知器)来放大了跨维的通道-空间依赖关系。(MLP是一种具有还原比r的编码器-解码器结构,与BAM相同)通道注意子模块如图2所示。

 在空间注意子模块中,为了关注空间信息,使用了两个卷积层来进行空间信息融合。与BAM相同也使用相同的通道注意子模块的减少比r。同时,由于最大池化操作减少了信息,并产生了负向贡献。在该模块中删除了池化操作,以进一步保留特性映射。因此,空间注意模块有时会显著增加参数的数量。为了防止参数的显著增加,采用了带有通道混洗的组卷积。在ResNet50中,没有群卷积的空间注意子模块如图3所示。

Keras实现 

以下是根据论文实现的keras版本(支持Tensorflow1.x)。


class GAMAttention(Layer):
    def __init__(self, filters, rate=4):
        super(GAMAttention, self).__init__()
        self.input_dims = int(filters)
        self.out_dims = int(filters)
        self.reduce_dims = int(filters / rate)
        self.rate = rate

        # channel attention
        self.channel_linear_reduce = Dense(self.reduce_dims)
        self.channel_activation = Activation('relu')
        self.channel_linear = Dense(self.input_dims)

        # spatial attention
        self.spatial_con_reduce = Conv2D(self.reduce_dims, kernel_size=7, padding='same')
        self.spatial_bn_reduce = BatchNormalization()
        self.spatial_activation = Activation('relu')
        self.spatial_con = Conv2D(self.out_dims, kernel_size=7, padding='same')
        self.spatial_bn = BatchNormalization()

        self.in_shape = None

    def build(self, input_shape):
        assert input_shape[-1] == self.out_dims, 'input filters must equal to input of channel'
        self.in_shape = input_shape
        # channel attention
        self.channel_linear_reduce.build((input_shape[0], input_shape[1]*input_shape[2], input_shape[3]))
        self.trainable_weights += self.channel_linear_reduce.trainable_weights
        self.channel_linear.build((input_shape[0], input_shape[1]*input_shape[2], self.reduce_dims))
        self.trainable_weights += self.channel_linear.trainable_weights

        # spatial attention
        self.spatial_con_reduce.build(input_shape)
        self.trainable_weights += self.spatial_con_reduce.trainable_weights
        self.spatial_bn_reduce.build((input_shape[0], input_shape[1], input_shape[2], self.reduce_dims))
        self.trainable_weights += self.spatial_bn_reduce.trainable_weights
        self.spatial_con.build((input_shape[0], input_shape[1], input_shape[2], self.reduce_dims))
        self.trainable_weights += self.spatial_con.trainable_weights
        self.spatial_bn.build(input_shape)
        self.trainable_weights += self.spatial_bn.trainable_weights

        super(GAMAttention, self).build(input_shape)  # Be sure to call this at the end

    def call(self, f1, **kwargs):

        # channel attention
        tmp = Reshape((-1, self.input_dims))(f1)
        tmp = self.channel_linear_reduce(tmp)
        tmp = self.channel_activation(tmp)
        tmp = self.channel_linear(tmp)
        mc = Reshape(self.in_shape[1:])(tmp)
        f2 = mc * f1

        # spatial attention
        tmp = self.spatial_con_reduce(f2)
        tmp = self.spatial_bn_reduce(tmp)
        tmp = self.spatial_activation(tmp)
        tmp = self.spatial_con(tmp)
        ms = self.spatial_bn(tmp)
        f3 = ms * f2

        return f3

声明:本内容来源网络,版权属于原作者,图片来源原论文。如有侵权,联系删除。

创作不易,欢迎大家点赞评论收藏关注!(想看更多最新的模型压缩文献欢迎关注浏览我的博客)

;