1. GoogLeNet 浅析
GoogLeNet 于论文《Going deeper with convolutions》中提出,并一举斩获 2014 年 ImageNet 挑战赛的冠军。一般而言,增加网络深度和宽度是提升网络性能最直接的方法,但这样也会带来诸多问题:
- 参数太多,如果训练数据集有限,很容易产生过拟合;
- 网络越大,参数越多,计算复杂度越大,应用难度越高;
- 网络越深,越容易出现梯度弥散问题。
GoogLeNet 共有 22 层网络,但其参数量却比 AlexNet 和 VGG 小很多。GoogLeNet 主要通过以下方法来提升网络的性能:
- 通过 Inception 结构融合多尺度特征信息;
- 使用 1x1 卷积进行降维,减少计算量;
- 添加辅助分类器,缓解梯度弥散问题,帮助训练;
- 丢弃全连接层,使用平均池化层,大幅减少模型参数量。
GoogLeNet 网络结构
GoogLeNet 的网络结构如下,共有 22 层(蓝色部分)。关于 GoogLeNet 网络的分析,主要关注在两个方面,分别是 Inception 结构(红框部分)和辅助分类器(绿框部分)。
Inception 结构
Inception 结构图如下,左图是 Inception 的原始结构,右图是加上降维功能的结构。Inception 结构能够提取不同尺度的特征,同时利用稀疏矩阵计算的原理来加速收敛。此外,论文作者认为池化也具有提取特征的功能,因此在第四个分支也使用了最大池化。
Inception 结构共有 4 个分支,输入特征经由各分支得到 4 个输出,之后在通道维度进行拼接得到最终输出。对于各分支而言,需要通过 stride 和 padding 来保证得到同样大小的输出。对比左图,右图在分支 2,3,4 上加入了 1x1 卷积以降维,在增加非线性表达能力的同时,减少参数量和计算量。下图给出 1x1 卷积降维的图示。
辅助分类器
辅助分类器主要用以缓解梯度弥散问题,GoogLeNet 网络中使用了两个辅助分类器,二者的结构是一模一样的,结构参数如下:
- 第一层:平均池化下采样层,池化核大小为 5x5,stride=3;
- 第二层:卷积层,卷积核大小为 1x1,stride=1,卷积核个数为 128;
- 第三层:全连接层,共1024 个节点;
- 第四层:全连接层,节点数为 1000,对应类别数。
2. 代码实现(PyTorch)
2.1 Inception 结构
具有降维功能的 Inception 结构如上图所示,其具体结构参数如下:
- 分支 1:是卷积核大小为 1x1 的卷积层,stride=1;
- 分支 2:是卷积核大小为 3x3 的卷积层,stride=1,padding=1;
- 分支 3:是卷积核大小为 5x5 的卷积层,stride=1,padding=2;
- 分支 4:是池化核大小为 3x3 的最大池化下采样,stride=1,padding=1。
import torch
import torch.nn as nn
class Inception(nn.Module):
def __init__(self, in_c, o_c):
super(Inception, self).__init__()
self.conv1 = nn.Conv2d(in_c, o_c, kernel_size=(1, 1))
self.branch2 = nn.Sequential(
nn.Conv2d(in_c, 4, kernel_size=(1, 1)),
nn.Conv2d(4, o_c, kernel_size=(3, 3), padding=1)
)
self.branch3 = nn.Sequential(
nn.Conv2d(in_c, 4, kernel_size=(1, 1)),
nn.Conv2d(4, o_c, kernel_size=(5, 5), padding=2)
)
self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
nn.Conv2d(in_c, 4, kernel_size=(1, 1))
)
def forward(self, x):
bran1 = self.conv1(x)
bran2 = self.branch2(x)
bran3 = self.branch3(x)
bran4 = self.branch4(x)
outputs = [bran1, bran2, bran3, bran4]
out = torch.cat(outputs, 1)
return out
if __name__ == "__main__":
x = torch.rand((8, 3, 32, 32)) # (B, C, H, W)
net = Inception(3, 8)
out = net(x)
2.2 GoogLeNet 网络
本人并未实现 GoogLeNet 网络,但在此处给出 github 地址以供有心人参考。
【参考】