Bootstrap

图像分割:Unet的pytorch代码实现(二)Unet模型代码

模型结构

Unet模型是一个U型结构的模型,包含编码器,解码器两部分。编码器是对图像下采样和提取特征,解码器是对图像上采样并恢复图像空间细节。

Unet模型结构

U-Net模型的编码器部分通过一系列的卷积层和池化层逐步减小特征图的尺寸,同时增加特征通道的数量,从而捕捉到更深层次的语义信息。每个编码块通常由两个3x3的卷积层(提取特征)和一个2x2的最大池化层(下采样)组成。

在解码器部分,U-Net通过跳跃连接将编码器中的特征图与解码器中的特征图相结合,这样可以帮助保留更多的空间信息并提高定位准确性。解码器中的每个块通常包括一个上采样操作(如反卷积或插值法),用于逐步恢复特征图的尺寸,然后是两个3x3的卷积层。并通过拼接操作与编码器相应层级的特征图进行结合。

下采样部分

下采样是通过减少数据的空间分辨率或尺寸来降低数据量的过程

本文使用一个核大小为3*3的卷积操作进行下采样,将特征图的大小减半。

例如有一个4*4的矩阵,经过 padding=1 后的操作得到一个6*6的矩阵。

然后使用一个3*3卷积核提取特征。

当步长为2时,卷积后的矩阵大小就为2*2的矩阵,实现了下采样。

上采样部分

本文使用最近邻插值进行上采样:这种方法是最简单的插值技术,它直接取距离目标像素最近的源像素值作为目标像素的值。优点是计算简单、速度快,但缺点是可能导致图像质量下降,因为这种插值方式会在图像中引入明显的锯齿效应(锯齿效应是一种在图像处理和计算机图形学中常见的视觉现象,表现为物体边缘出现不规则、锯齿状的伪影)。

先使用一个1*卷积核对特征图的通道数进行减半,然后再使用插值法进行上采样。

卷积部分

本文的模型包含两个卷积层、两个批量归一化层、两个Dropout层和一个PReLU激活函数。这个模块用于构建卷积神经网络中的卷积块,用于特征提取。

跳跃连接部分

编码器部分的特征图先使用一个1*1卷积核变换通道数,然后使用torch.cat()与解码器部分的特征图连接起来。

代码实现

import torch
import torch.nn as nn
from torch.nn import functional

下采样部分

# 下采样,使用卷积进行,图像缩小2倍
class DownSample(nn.Module):
    def __init__(self, channels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.Down_block = nn.Sequential(
            # 一个3*3卷积。padding=1, stride = 2:保证卷积后特征图大小减半
            nn.Conv2d(channels, channels, 3, padding=1, stride=2,
                      padding_mode='reflect', bias=False),  # 卷积
            nn.BatchNorm2d(channels),  # 批归一化
            nn.PReLU()
        )

    def forward(self, input):
        # print(input.shape, self.Down_block(input).shape, "Down")
        return self.Down_block(input)

上采样部分

# 上采样 使用插值法,并1*1卷积将通道数减半,进行上采样,图像扩大2倍
class UpSample(nn.Module):
    def __init__(self, channels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.Conv_1_1 = nn.Conv2d(channels, channels // 2, 1, stride=1, bias=False)

    def forward(self, input):
        # 最近邻插值法
        up_data = functional.interpolate(input, scale_factor=2, mode='nearest')
        conv_1 = self.Conv_1_1(up_data)
        # print(up_data.shape, conv_1.shape, input.shape)
        return conv_1

卷积部分

# 卷积块 用于提取特征
class Conv_block(nn.Module):
    def __init__(self, in_channels, out_channels, Dropout=0.3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.block = nn.Sequential(
            # 3*3卷积块,填充1,步长1,填充模式为反射
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding=1, padding_mode='reflect',
                      kernel_size=3, bias=False),
            # 批归一化
            nn.BatchNorm2d(out_channels),
            # Dropout
            nn.Dropout(Dropout),
            # 激活函数
            nn.PReLU(),

            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, padding=1, padding_mode='reflect',
                      kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(Dropout),
            nn.PReLU(),
        )

    def forward(self, input):
        return self.block(input)

Unet部分

这里包含了跳跃连接

# Unet网络结构,包含编码器和解码器
class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 编码器部分
        self.c1 = Conv_block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_block(512, 1024)
        # 解码器部分
        self.u1 = UpSample(1024)
        # 对c1输出的特征图1*1卷积, 用于跳跃连接
        self.c1_u = Conv_block(1024, 512)
        self.u2 = UpSample(512)
        # 对c2输出的特征图1*1卷积, 用于跳跃连接
        self.c2_u = Conv_block(512, 256)
        self.u3 = UpSample(256)
        # 对c3输出的特征图1*1卷积, 用于跳跃连接
        self.c3_u = Conv_block(256, 128)
        self.u4 = UpSample(128)
        # 对c4输出的特征图1*1卷积, 用于跳跃连接
        self.c4_u = Conv_block(128, 64)
        self.c_out = nn.Conv2d(64, 1, 1, 1
                               , bias=False)
        # 激活函数
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        S_1 = self.c1(input)
        S_2 = self.c2(self.d1(S_1))
        S_3 = self.c3(self.d2(S_2))
        S_4 = self.c4(self.d3(S_3))
        S_5 = self.c5(self.d4(S_4))

        # print(S_4.shape, self.u1(S_5).shape)
        S_6 = self.c1_u(torch.cat((self.u1(S_5), S_4), dim=1))
        S_7 = self.c2_u(torch.cat((self.u2(S_6), S_3), dim=1))
        S_8 = self.c3_u(torch.cat((self.u3(S_7), S_2), dim=1))
        S_9 = self.c4_u(torch.cat((self.u4(S_8), S_1), dim=1))
        S_out = self.c_out(S_9)
        return self.sigmoid(S_out)

验证模型准确性

def main():
    x = torch.randn(1, 3, 512, 512)
    net = Unet()
    net = net
    print(net(x).shape)


if __name__ == "__main__":
    main()

 能正常输出

torch.Size([1, 1, 512, 512])

;