完整源码位于文章底部
1.ASPP是什么
ASPP:Atrous Spatial Pyramid Pooling,空洞空间卷积池化金字塔。
简单理解就是个至尊版池化层,其目的与普通的池化层一致,尽可能地去提取特征。
2.网络结构
对于输入图像input:
- 用一个1×1的卷积对input进行降维
- 用一个padding为6,dilation为6,核大小为3×3的卷积层进行卷积
- 用一个padding为12,dilation为12,核大小为3×3的卷积层进行卷积
- 用一个padding为18,dilation为18,核大小为3×3的卷积层进行卷积
- 用一个尺寸为input大小的池化层将input池化为1×1,再用一个1×1的卷积进行降维,最后上采样回原始输入大小
最后将这五层的输出进行concat,并用1×1卷积层降维至给定通道数,得到最终输出。
可以看到,ASPP本质由一个1×1的卷积 (最左侧绿色)+ 池化金字塔(中间三个蓝色) + ASPPPooling(最右侧三层)组成。而池化金字塔各层的膨胀因子是可以自定的,从而实现自由的多尺度特征提取。
接再来对torchvision中的ASPP源代码(deeplabv3.py)进行简单解读。
3.ASPPConv
方法头:
def __init__(self, in_channels, out_channels, dilation):
- in_channels:输入通道数
- out_channels:输出通道数
- dilation:膨胀率
事实上空洞卷积层与一般卷积层之间的差别就在于这个膨胀率,其控制的是卷积时的padding以及dilation。通过不同的填充以及与膨胀,可以获取不同尺度的感受野,提取多尺度的信息。注意卷积核尺寸始终是保持3×3不变的。
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False)
完整源码如下:
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()
]
super(ASPPConv, self).__init__(*modules)
4.ASPPPolling
ASPPPolling首先是一个AdaptiveAvgPool2d层。
所谓自适应均值池化,其自适应的地方在于不需要指定kernel size和stride,只需要指定最后的输出尺寸(这里为1×1)。通过将各通道的特征图分别压缩至1×1,从而提取各通道的特征,进而获取全局的特征:
nn.AdaptiveAvgPool2d(1)
然后是一个1×1的卷积层,对上一步获取的特征进行进一步的提取,并降维:
nn.Conv2d(in_channels, out_channels, 1, bias=False),
完整源码如下:
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU())
def forward(self, x):
size = x.shape[-2:]
for mod in self:
x = mod(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
需要注意的是,在ASPPPolliing的网络结构部分,只是对特征进行了提取;而在forward方法中,除了顺序执行网络的各层外,最终还将特征图从1×1上采样回原来的尺寸。
5.ASPP
方法头:
def __init__(self, in_channels, atrous_rates, out_channels=256)
- in_channels:输入通道数
- atrous_rates:膨胀因子
- out_channels:输出通道数,默认为256
1. 最开始是一个1×1的卷积层,进行降维:
super(ASPP, self).__init__()
modules = []
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()))
2. 构建"池化金字塔"。对于给定的膨胀因子atrous_rates,叠加相应的空洞卷积层,提取不同尺度下的特征:
rates = tuple(atrous_rates)
for rate in rates:
modules.append(ASPPConv(in_channels, out_channels, rate))
3. 添加空洞池化层:
modules.append(ASPPPooling(in_channels, out_channels))
4. 输出层,用于对ASPP各层叠加后的输出,进行卷积操作,得到最终结果:
self.project = nn.Sequential(
nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5))
☆ forward
对于forward方法,其顺序执行ASPP的各层,将各层的输出按通道叠加,并通过输出层的conv->bn->relu->dropout降维至给定通道数,获取最终结果:
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
# (B, C, H, W), dim = 1, 按通道拼接
res = torch.cat(res, dim=1)
return self.project(res)
6.完整源码
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()
]
super(ASPPConv, self).__init__(*modules)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU())
def forward(self, x):
size = x.shape[-2:]
for mod in self:
x = mod(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels=256):
super(ASPP, self).__init__()
modules = []
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()))
rates = tuple(atrous_rates)
for rate in rates:
modules.append(ASPPConv(in_channels, out_channels, rate))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5))
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)