Bootstrap

【深度学习】Pytorch:在 ResNet 中加入注意力机制

在这篇教程中,我们将介绍如何在 ResNet 网络中加入注意力机制模块。我们将通过对标准 ResNet50 进行改进,向网络中添加两个自定义的注意力模块,并展示如何实现这一过程。

为什么要加入注意力机制

注意力机制可以帮助神经网络专注于图像中重要的特征区域,从而提高模型的性能。在卷积神经网络中,加入注意力机制能够有效增强特征提取能力,减少冗余信息的干扰,尤其在处理复杂图像时,能够提升网络的表现。

在本教程中,我们将使用一种通用的注意力模块,您可以根据需求自行替换或改进该模块。

代码实现

导入依赖

我们需要以下 PyTorch 库来构建网络:

import torch
import torch.nn as nn
from torchvision import models

定义注意力模块

首先,我们需要定义一个注意力模块。这里我们使用了一个简单的通道注意力机制(如 SE 模块、CBAM 模块等),你可以根据需求选择不同类型的注意力模块。

假设我们已经有一个注意力模块类(AttentionModule),它的结构可以像下面这样:

class AttentionModule(nn.Module):
    def __init__(self, in_channels):
        super(AttentionModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // 16, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels // 16, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attention = self.conv1(x)
        attention = self.relu(attention)
        attention = self.conv2(attention)
        attention = self.sigmoid(attention)
        return x * attention

这段代码定义了一个简单的注意力模块。它通过两个卷积层和一个 Sigmoid 函数来生成一个通道注意力映射,并通过该映射加权输入特征图。

构建 ResNet 与注意力机制集成的模型

现在我们将创建一个新的模型类 ResNetWithAttention,该模型继承自 nn.Module,并将注意力模块插入到 ResNet 的关键位置。在这个示例中,我们将注意力模块插入到网络的卷积层输出之后,并在最后一层卷积层后再次插入。

class ResNetWithAttention(nn.Module):
    def __init__(self, attention_cls, pretrained=True):
        super(ResNetWithAttention, self).__init__()

        # 使用预训练的 ResNet50
        self.base_model = models.resnet50(pretrained=pretrained)

        # 创建注意力模块
        self.attention_layer1 = attention_cls(64)  # 第一层卷积后
        self.attention_layer2 = attention_cls(2048)  # 最后一层卷积后

    def forward(self, x):
        # ResNet50的前向传播过程
        x = self.base_model.conv1(x)  # 初始卷积层
        x = self.base_model.bn1(x)  # 批归一化
        x = self.base_model.relu(x)  # 激活函数

        # 第一个注意力模块:第一层卷积后
        x = self.attention_layer1(x)

        # 最大池化层
        x = self.base_model.maxpool(x)

        # ResNet的残差层
        x = self.base_model.layer1(x)
        x = self.base_model.layer2(x)
        x = self.base_model.layer3(x)
        x = self.base_model.layer4(x)

        # 第二个注意力模块:最后一层卷积后
        x = self.attention_layer2(x)

        # 平均池化
        x = self.base_model.avgpool(x)

        # 展平并通过全连接层
        x = torch.flatten(x, 1)
        x = self.base_model.fc(x)
        
        return x

在这个模型中,我们通过 attention_cls 参数动态地将任何类型的注意力模块传入模型。模型首先使用基础的 ResNet50 结构,之后我们将自定义的注意力模块应用到两个关键位置:一个是在第一层卷积之后,另一个是在最后的卷积层之后。

训练模型

使用该模型的训练过程与标准的 ResNet 模型相同。你可以像使用普通的 ResNet 模型一样训练和评估 ResNetWithAttention。下面是训练的一般流程:

# 示例:初始化模型并进行训练
attention_cls = AttentionModule  # 可以替换为其他类型的注意力模块
model = ResNetWithAttention(attention_cls)

# 选择优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 假设 train_loader 是数据加载器
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

可加入的注意力模块

通道注意力模块

  • SE (Squeeze-and-Excitation) 模块:最经典的通道注意力模块,使用全局平均池化后生成通道级注意力,通过全连接层建模通道之间的关系。

    class SEBlock(nn.Module):
        def __init__(self, in_channels, reduction=16):
            super(SEBlock, self).__init__()
            # 定义第一个全连接层,将输入通道数压缩为 in_channels // reduction
            self.fc1 = nn.Linear(in_channels, in_channels // reduction)
            # 定义第二个全连接层,将通道数恢复为原始输入通道数
            self.fc2 = nn.Linear(in_channels // reduction, in_channels)
            # 定义Sigmoid激活函数,用于生成注意力权重
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            # 获取输入张量的 batch_size 和 channels
            batch_size, channels, _, _ = x.size()
            # 对输入张量的空间维度(高度和宽度)进行全局平均池化
            squeeze = torch.mean(x, dim=(2, 3))
            # 通过第一个全连接层进行通道压缩
            squeeze = self.fc1(squeeze)
            # 通过ReLU激活函数和第二个全连接层进行通道扩展
            squeeze = self.fc2(F.relu(squeeze))
            # 使用Sigmoid生成注意力权重,并调整形状以匹配输入张量的维度
            attention = self.sigmoid(squeeze).view(batch_size, channels, 1, 1)
            # 将注意力权重应用到输入张量上,进行通道加权
            return x * attention
    
  • ECA (Efficient Channel Attention) 模块:通过 1D 卷积建模通道间的依赖关系,减少了计算量,提升了效率

    class ECABlock(nn.Module):
        def __init__(self, channels, kernel_size=3):
            super(ECABlock, self).__init__()
            # 定义1D卷积层,用于学习通道间的注意力权重
            self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2, bias=False)
    
        def forward(self, x):
            # 获取输入张量的 batch_size 和 channels
            batch_size, channels, _, _ = x.size()
            # 对输入张量的空间维度(高度和宽度)进行全局平均池化,并调整形状
            y = F.adaptive_avg_pool2d(x, 1).view(batch_size, channels, 1)
            # 调整形状以适配1D卷积的输入格式
            y = y.view(batch_size, 1, channels)
            # 通过1D卷积层学习通道间的注意力权重
            y = self.conv(y)
            # 使用Sigmoid激活函数生成注意力权重
            y = torch.sigmoid(y)
            # 将注意力权重应用到输入张量上,进行通道加权
            return x * y.view(batch_size, channels, 1, 1).expand_as(x)
    

空间注意力模块

  • CBAM (Convolutional Block Attention Module) 模块:结合了通道注意力和空间注意力,首先进行通道注意力加权,然后通过空间卷积生成空间注意力。

    class CBAM(nn.Module):
        def __init__(self, in_channels, reduction=16):
            super(CBAM, self).__init__()
            # 通道注意力模块(SEBlock),用于学习通道间的注意力权重
            self.channel_attention = SEBlock(in_channels, reduction)
            # 空间注意力模块,使用1x1卷积核学习空间注意力权重
            self.spatial_attention = nn.Conv2d(2, 1, kernel_size=7, padding=3)
    
        def forward(self, x):
            # 应用通道注意力模块,对输入特征进行通道加权
            x = self.channel_attention(x)
            # 计算输入特征在通道维度上的平均值
            avg_out = torch.mean(x, dim=1, keepdim=True)
            # 计算输入特征在通道维度上的最大值
            max_out, _ = torch.max(x, dim=1, keepdim=True)
            # 将平均值和最大值拼接在一起
            spatial_out = torch.cat([avg_out, max_out], dim=1)
            # 通过空间注意力模块学习空间注意力权重
            spatial_out = self.spatial_attention(spatial_out)
            # 使用Sigmoid激活函数生成空间注意力权重
            spatial_attention = torch.sigmoid(spatial_out)
            # 将空间注意力权重应用到输入特征上,进行空间加权
            return x * spatial_attention
    
  • Coordinate Attention 模块:通过空间坐标信息提升特征建模能力,增强空间特征表达。

    class CoordinateAttention(nn.Module):
        def __init__(self, in_channels, reduction=16):
            super(CoordinateAttention, self).__init__()
            # 定义1x1卷积层,用于压缩通道数
            self.fc = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1)
            # 定义1x1卷积层,用于恢复通道数
            self.fc_out = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1)
    
        def forward(self, x):
            # 获取输入张量的 batch_size、channels、height 和 width
            batch_size, channels, height, width = x.size()
            # 对输入张量的空间维度(高度和宽度)进行全局平均池化
            avg_out = torch.mean(x, dim=[2, 3], keepdim=True)
            # 对输入张量的空间维度(高度和宽度)进行全局最大池化
            max_out = torch.amax(x, dim=[2, 3], keepdim=True)
            # 通过1x1卷积层压缩通道数
            avg_out = self.fc(avg_out)
            max_out = self.fc(max_out)
            # 通过1x1卷积层恢复通道数
            avg_out = self.fc_out(avg_out)
            max_out = self.fc_out(max_out)
            # 将平均池化和最大池化的结果相加,并应用到输入张量上
            out = x * (avg_out + max_out)
            return out
    

双重注意力模块

  • Dual Attention 模块:结合了通道和空间的双重注意力机制,增强了特征的表征能力。

    class DualAttentionBlock(nn.Module):
        def __init__(self, in_channels, reduction=16):
            super(DualAttentionBlock, self).__init__()
            # 通道注意力模块(SEBlock),用于学习通道间的注意力权重
            self.channel_attention = SEBlock(in_channels, reduction)
            # 空间注意力模块(CBAM),用于学习空间上的注意力权重
            self.spatial_attention = CBAM(in_channels, reduction)
    
        def forward(self, x):
            # 应用通道注意力模块,对输入特征进行通道加权
            x = self.channel_attention(x)
            # 应用空间注意力模块,对输入特征进行空间加权
            x = self.spatial_attention(x)
            return x
    

全局依赖建模模块

  • Non-local 模块:通过自注意力机制建模全局依赖关系,提升对长距离特征的建模能力。

    class NonLocalBlock(nn.Module):
        def __init__(self, in_channels):
            super(NonLocalBlock, self).__init__()
            # 输入通道数
            self.in_channels = in_channels
            # 中间通道数,通常为输入通道数的一半
            self.inter_channels = in_channels // 2
            
            # 定义1x1卷积层,用于生成查询(query)特征
            self.query_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
            # 定义1x1卷积层,用于生成键(key)特征
            self.key_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
            # 定义1x1卷积层,用于生成值(value)特征
            self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
            
            # 定义Softmax函数,用于计算注意力权重
            self.softmax = nn.Softmax(dim=-1)
    
        def forward(self, x):
            # 获取输入张量的 batch_size、通道数、高度和宽度
            batch_size, C, H, W = x.size()
            # 通过查询卷积层生成查询特征,并调整形状
            query = self.query_conv(x).view(batch_size, self.inter_channels, -1)
            # 通过键卷积层生成键特征,并调整形状
            key = self.key_conv(x).view(batch_size, self.inter_channels, -1)
            # 通过值卷积层生成值特征,并调整形状
            value = self.value_conv(x).view(batch_size, C, -1)
            
            # 计算查询特征和键特征的相似度(亲和矩阵)
            affinity = torch.bmm(query.transpose(1, 2), key)
            # 使用Softmax计算注意力权重
            attention = self.softmax(affinity)
            # 将注意力权重应用到值特征上,得到加权输出
            out = torch.bmm(value, attention.transpose(1, 2))
            # 调整输出形状以匹配输入张量的维度
            out = out.view(batch_size, C, H, W)
            
            # 将加权输出与输入张量相加,实现残差连接
            return out + x
    
  • Attention U-Net 模块:在 U-Net 结构中引入注意力模块,适用于图像分割任务,能够自适应地选择重要区域进行特征增强。

    class AttentionGate(nn.Module):
        def __init__(self, in_channels):
            super(AttentionGate, self).__init__()
            # 定义门控通道数和中间通道数
            gating_channels = in_channels
            inter_channels = in_channels // 2
          
            # 定义1x1卷积层,用于处理输入特征
            self.conv1 = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
            # 定义1x1卷积层,用于处理门控特征
            self.conv2 = nn.Conv2d(gating_channels, inter_channels, kernel_size=1)
            # 定义1x1卷积层,用于生成注意力权重
            self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)
            # 定义Sigmoid激活函数,用于生成注意力权重
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            # 门控信号与输入特征相同
            gating = x
            # 对输入特征进行1x1卷积
            x1 = self.conv1(x)
            # 对门控信号进行1x1卷积
            x2 = self.conv2(gating)
            # 将两个卷积结果相加,并通过ReLU激活函数
            attention = self.sigmoid(self.psi(F.relu(x1 + x2)))
            # 将注意力权重应用到输入特征上,进行加权
            return x * attention
    

总结

通过将注意力模块集成到 ResNet 中,我们能够增强模型对重要特征的关注,从而提高性能。你可以根据需要选择不同的注意力机制,并在模型中任意位置插入这些模块。

悦读

道可道,非常道;名可名,非常名。 无名,天地之始,有名,万物之母。 故常无欲,以观其妙,常有欲,以观其徼。 此两者,同出而异名,同谓之玄,玄之又玄,众妙之门。

;