Bootstrap

神经网络(五):U2Net图像分割网络


  参考论文:U2-Net: Going deeper with nested U-structure for salient object detection
  这篇文章基于显著目标检测任务提出,显著目标检测是指将图像中最吸引人的目标或者区域分割出来,因此只有前景和背景两个类别,相当于语义分割中的二分类任务。例如,下图中展示了三张图片进行显著目标检测后的结果:
在这里插入图片描述
其中,白色区域代表前景 ,即最吸引人的目标或区域,而黑色区域代表背景。
  在 U 2 N e t U^2Net U2Net被提出前,显著目标检测领域主要面临两个问题:

  • 1.现有的SOD网络大多基于当时已有的网络架构(主干网络)进行深度特征的提取,如 A l e x N e t 、 V G G 、 R e s N e t 、 R e s N e X t AlexNet、VGG、ResNet、ResNeXt AlexNetVGGResNetResNeXt等。这些网络最初是为图像分类设计的,它们提取代表语义含义的特征(如某一具体的猫、狗),而非局部细节和全局对比度信息(SOD的主要任务是将图像划分为前景与背景,而不注重某一具体特征的提取),使得这些模型在进行SOD任务时效率低下。
  • 2.当时的SOD网络架构不断通过向现有的主干网络中添加特征聚合模块以提取多级显著特征,使得模型过于复杂。且这些图像分类模型往往通常通过牺牲特征图的高分辨率来实现更深的架构,即特征图在早期阶段会被缩小到较低的分辨率,如ResNet和DenseNet会使用步长为2的卷积和步长为2的最大池化将特征图大小减小到输入图的四分之一。但是,高分辨率在图像分割中有着重要作用,这也使得这些模型并不适用SOD任务。

  为解决上述问题,提出了 U 2 N e t U^2Net U2Net网络结构:

  • 一种两级嵌套的U形结构,专为SOD 设计,无需使用任何来自图像分类的预训练主干。
  • 提出U型残差块RSU,能在不降低特征图分辨率的情况下提取阶段内多尺度特征。

一、网络结构

   U 2 N e t U^2Net U2Net本质上是 U N e t UNet UNet网络的嵌套,该网络的整体结构与 U N e t UNet UNet网络几乎相同,但所使用的上采样、下采样模块变成了小型的 U N e t UNet UNet网络。 U 2 − N e t U^2-Net U2Net网络的核心就是这些作为模块的小型 U N e t UNet UNet网络,并将其起名为 U 型残差块 ( R e S i d u a l U − b l o c k , R S U ) U型残差块(ReSidual U-block,RSU) U型残差块(ReSidualUblockRSU)。网络结构如下:
在这里插入图片描述
这些模块其实可以分为两种, E n c o d e r 1 − E n c o d e r 4 、 D e c o d e r 1 − D e c o d e r 4 Encoder1-Encoder4、Decoder1-Decoder4 Encoder1Encoder4Decoder1Decoder4采用的是同一种结构的残差块RSU-L,只不过深度不同,而Encoder5、Encoder6、Decoder5 采用的是另一种结构的残差块RSU-4F。整体流程可概况为:

  • Encoder阶段:每通过一个模块后都会两倍下采样,使用的是torch.nn.MaxPool2d
  • Decoder阶段:每通过一个模块后都会两倍上采用,使用的是torch.nn.functional.interpolate()
  • 跳跃链接:与 U N e t UNet UNet网络思路相同,将编码器的输出与解码器输出的特征图进行拼接,最后得到分割后的图像。

1.1第一种block结构

  本地和全局上下文信息对于显著对象检测和图像分割任务都非常重要,现代CNN网络设计中VGG、ResNet、DenseNet 等,一般使用1x1或3x3的小型卷积核提取特征。但在SOD任务中,由于它们的感受野太小而无法捕捉全局信息,使得浅层的输出特征图仅包含局部特征。在下图(图 ( a ) − ( c ) (a)-(c) (a)(c))中给出了具有小感受野的典型现有卷积块。为从浅层获得高分辨率特征图的更多全局信息,最直接的想法是扩大感受野,图 ( e ) (e) (e)是一种双向消息传递模块(见论文ieee),它试图通过使用扩张卷积扩大感受野来提取局部和非局部特征,以原始分辨率对输入特征图进行多次扩张卷积(尤其是在早期阶段)需要太多的计算和内存资源。
  为解决上述问题,本文提出了RSU模块(图 ( e ) (e) (e),L表示RSU的深度,图中L=7):

  • 输入卷积层:局部特征提取的普通卷积层,用于将尺寸为 ( H , W , C i n ) (H,W,C_{in}) (H,W,Cin)输入特征图x添加到中间映射 F 1 ( x ) F_1(x) F1(x)中。
  • 高度为L、类似 U N e t UNet UNet的对称编码器-解码器结构(RSU模块):以 F 1 ( x ) F_1(x) F1(x)作为输入,学习多尺度上下文信息。较大的 L 会生成更深的RSU、更多的池化操作、更大的感受野范围以及更丰富的局部和全局特征。配置此参数可以从具有任意空间分辨率的输入特征图中提取多尺度特征。从逐渐下采样的特征图中提取多尺度特征,并通过渐进式上采样、连接和卷积,将特征图还原为高分辨率特征图。此设计可减轻因使用大比例的直接上采样而导致的精细细节损失,也更加适用于SOD任务。用 U ( x ) U(x) U(x)表示RSU模块,则提取的信息可表示为 U ( F 1 ( x ) ) U(F_1(x)) U(F1(x))
  • 残差连接:通过求和操作融合局部特征和多尺度特征,可表示为 F 1 ( x ) + U ( F 1 ( x ) ) F_1(x)+U(F_1(x)) F1(x)+U(F1(x))

在这里插入图片描述

   R S U − 7 RSU-7 RSU7真实结构如下图所示:
在这里插入图片描述
  回到 U 2 − N e t U^2-Net U2Net结构,该RSU的使用场景有:

  • Encoder1 和 Decoder1 采用的是 RSU-7 结构。
  • Encoder2 和 Decoder2 采用的是 RSU-6 结构。
  • Encoder3 和 Decoder3 采用的是 RSU-5 结构。
  • Encoder4 和 Decoder4 采用的是 RSU-4 结构。

可见,相邻 block 相差一次下采样和一次上采样,例如 RSU-6 相比于 RSU-7 少了一个下采样卷积和上采样卷积部分,RSU-7 是下采样 32 倍和上采样 32 倍,RSU-6 是下采样 16 倍和上采样 16 倍。代码实现如下:


import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):    #实现conv2d+BN+ReLU操作                                                      
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        # dilation用于实现空洞卷积
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
 
def _upsample_like(src,tar):
    src = F.interpolate(src,size=tar.shape[2:],mode='bilinear',align_corners=True)     
    return src
 
 
### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):                          #En_1   
 
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)              #CBR1
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)              #CBR2
 
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)           
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)             
 
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
 
        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
 
        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
 
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
 
        hx = x
        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
 
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
 
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
 
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
 
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
 
        hx6 = self.rebnconv6(hx)
 
        hx7 = self.rebnconv7(hx6)                                  
 		#实现残差连接
        hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
 
        hx6dup = _upsample_like(hx6d,hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
 
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
 
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
 
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
 
        hx2dup = _upsample_like(hx2d,hx1)
 
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
 
        return hx1d + hxin

在这里插入图片描述

1.2第二种block结构

  数据经过 E n 1 − E n 4 En_1-En4 En1En4下采样处理后对应特征图的分辨率就已经相对比较小了,如果再继续下采样就会丢失很多上下文信息。为保留上下文信息,在 E n c o d e r 5 、 E n c o d e r 6 、 D e c o d e r 5 Encoder5、Encoder6、Decoder5 Encoder5Encoder6Decoder5中将原始RSU中的上采样、下采样结构换成了空洞卷积操作,从而得到了 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本。此时 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。在这里插入图片描述
需要注意,在 E n c o d e r 5 Encoder5 Encoder5中特征图大小已经到了18*18,非常小(也因此不需要再下采样),故采用了空洞卷积操作,目的在不改变特征图大小的情况下增大感受野。故在代码中使用了dalition=2、4、8 E n c o d e r 6 、 D e c o d e r 5 Encoder6、Decoder5 Encoder6Decoder5同理。这一特点在原图中显示为使用了虚线构成的长方体块。

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
 
class REBNCONV(nn.Module):                                                          #CBL
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
        
### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()
 
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
 
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
 
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
 
    def forward(self,x):
        hx = x

        hxin = self.rebnconvin(hx)
 
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
 
        hx4 = self.rebnconv4(hx3)
 
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
 
        return hx1d + hxin

1.3特征图融合模块

  在通过编码、解码器的运算后,最后通过特征图融合模块(红框标出)将 D e 1 、 D e 2 、 D e 3 、 D e 4 、 D e 5 、 E n 6 De_1、De_2、De_3、De_4、De_5、En_6 De1De2De3De4De5En6模块的输出分别通过一个3x3的卷积层(卷积层的卷积核个数均为1),并通过双线性插值将得到的特征图还原回输入图像的大小,之后将得到的6个特征图进行拼接(Concatenation),最后再经过一个1x1的卷积层以及sigmoid激活函数,最终得到融合之后的图像。
在这里插入图片描述

1.4损失函数

   U 2 N e t U^2Net U2Net使用多监督算法构建损失函数。网络输出不仅仅包含最终特征图,还包含前面6个不同尺度的特征图,即,不仅要监督网络输出,还要监督中间融合特征图。 损失函数计算公式:
在这里插入图片描述
其中, M = 1 , 2 , 3 , . . . , 6 M=1,2,3,...,6 M=1,2,3,...,6 l s i d e m l_{side}^{m} lsidem表示特征图 S u p 1 、 S u p 2 、 . . . 、 S u p 6 Sup1、Sup2、...、Sup6 Sup1Sup2...Sup6的损失,而 l f u s e l_{fuse} lfuse表示最终特征图的损失, w w w则表示两种损失的权重参数(论文给出的源码中全为1)。 l s i d e l_{side} lside l f u s e l_{fuse} lfuse采用二值交叉熵(standard binary cross-entropy)进行计算:
在这里插入图片描述
其中, ( r , c ) (r,c) (r,c)表示像素坐标值, ( H , W ) (H,W) (H,W)表示图像高度和宽度, P G ( r , c ) P_{G(r,c)} PG(r,c)表示标签图像素灰度值, P S ( r , c ) P_{S(r,c)} PS(r,c)表示预测的图像素灰度值。

1.5总体网络架构

在这里插入图片描述
   U 2 N e t U^2Net U2Net主要由三部分组成:

  • 一个六级编码器:在 E n 1 、 E n 2 、 E n 3 、 E n 4 En_1、En_2、En_3、En_4 En1En2En3En4中分别使用 R S U 7 、 R S U 6 、 R S U 5 、 R S U 4 RSU7、RSU6、RSU5、RSU4 RSU7RSU6RSU5RSU4 L L L通常根据输入特征图的空间分辨率进行配置。对于高宽较大的特征图,使用更大的 L L L来捕获更多的大比例度信息。 E n 5 En_5 En5 E n 6 En_6 En6中特征图的分辨率相对较低,进一步降低这些特征图的采样会导致有用的上下文丢失。因此,在 E n 5 En_5 En5 E n 6 En_6 En6阶段,都使用 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本,其用膨胀卷积代替池化和上采样操作这意味着 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。
  • 一个五级解码器: D e 5 De_5 De5阶段同样使用 R S U − 4 F RSU-4F RSU4F,并且每个解码器都使用前一阶段的上采样特征图和来自其对称编码器阶段的特征图的串联作为输入(跳跃连接)。
  • 特征图融合模块:将六个侧输出显著性特征图上采样到输入图像的尺寸,之后使用融合操作,并通过1x1卷积层和sigmoid函数生成最终的显著性特征图。

  研究中将3320320的图像裁剪为3288288大小输入模型,最终得到1288288的图像分割结果(二值图像):
在这里插入图片描述

1.6代码汇总

import torch
import torch.nn as nn
import torch.nn.functional as F

class REBNCONV(nn.Module):
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self,x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):

    src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

    return src


### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)

        hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)


        hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

1.7普通残差块与RSU对比

在这里插入图片描述
  普通残差块的操作可概况为 H ( x ) = F 2 ( F 1 ( x ) ) + x H(x)=F_2(F_1(x))+x H(x)=F2(F1(x))+x,其中, F 1 、 F 2 F_1、F_2 F1F2代表权重层,此处设为卷积运算。RSU 和残差块的区别在于,RSU 用类似 U N e t UNet UNet的结构替换了普通卷积运算,并将原始特征替换为卷积层提取的局部特征信息: H R S U ( x ) = U ( F 1 ( x ) ) + F 1 ( x ) H_{RSU}(x)=U(F_1(x))+F_1(x) HRSU(x)=U(F1(x))+F1(x),其中 U U U表示多层U型结构。这种设计使网络能够直接从每个残差块中提取来自多个尺度的特征(设置不同的深度值L即可)。更值得注意的是, U 结构导致模型的计算开销很小,因为大多数操作都应用于下采样的特征图。下图中给出了RSU中 F 1 、 U F_1、U F1U的含义:
在这里插入图片描述

  残差块性能比较:
在这里插入图片描述

  • PLN:普通卷积块。
  • RES:残差块。
  • DSE:密集块。
  • INC:初始块。
  • RSU:U型残差块。

二、代码复现

https://github.com/xuebinqin/U-2-Net/tree/master

;