Bootstrap

经典CNN模型(五):ResNet(PyTorch详细注释版)

一. ResNet 神经网络介绍

ResNet,即残差神经网络(Residual Neural Network),是由何凯明(Kaiming He)、张祥雨、任少卿和孙剑在 2015 年提出的一种深度学习架构,旨在解决深度神经网络中的退化问题。在传统的深度神经网络中,随着网络层数的增加,训练误差会逐渐增大,这种现象被称为“退化”。ResNet 通过引入一种特殊的残差学习框架来解决这一问题,允许网络学习更深层次的结构,而不遭受退化问题的影响。

二. 概念拓展

1. 残差块

ResNet 的核心创新是残差块(Residual Block)。一个标准的残差块包含两个或更多的卷积层,以及一个从输入直接连接到块输出的“捷径连接”(Shortcut Connection)。这个捷径连接允许输入信号直接传递到块的输出端,与经过卷积层处理后的信号相加。这样,残差块的目标就变成了学习一个残差函数,即输入到输出之间的差异,而不是整个映射函数。
在ResNet中,残差块是构建整个网络的基础单元。残差块主要由两部分构成:一个或多个卷积层和一个捷径连接(直连边,Shortcut Connection)。下面我将详细解释这两种类型的残差块以及它们的工作原理:

1.1 标准残差块(见图左

在标准残差块中,输入 x x x 通过一系列卷积层(通常是两个3x3的卷积层)进行特征变换,得到输出 F ( x ) F(x) F(x)。与此同时,输入 x x x 直接通过直连边传递,然后与 F ( x ) F(x) F(x) 相加得到最终的输出 y y y。即: y = F ( x , W i ) + x y = F(x, W_i) + x y=F(x,Wi)+x这里的 W i W_i Wi 是残差块中卷积层的权重参数。这种结构只有在输入和输出的维度相同的情况下才能成立,因为只有维度相同,输入和输出才能直接相加。

1.2 降采样残差块(见图右

然而,在网络的某些层,我们可能需要减小特征图的尺寸或者改变通道数,这时输入和输出的维度就不一致了。为了解决这个问题,ResNet 引入了降采样残差块,它在直连边上增加了一个 1x1 的卷积层,用于调整输入的维度,使之与输出维度匹配,以便进行相加操作。降采样残差块的公式可以表示为: y = F ( x , W i ) + W s ( x ) y = F(x, W_i) + W_s(x) y=F(x,Wi)+Ws(x)这里的 W s W_s Ws 表示 1x1 卷积层的权重参数,它负责对输入 x x x 进行降维或升维,以匹配 F ( x , W i ) F(x, W_i) F(x,Wi) 的维度。

1.3 直连边的作用

直连边在残差块中的作用是至关重要的,它保证了即使当 F ( x ) F(x) F(x) 接近零时,网络仍然能学习到恒等映射,即 y ≈ x y \approx x yx。这样,即使网络变得非常深,也不至于出现梯度消失或梯度爆炸的问题,从而能够训练出更深的网络。

1.4 端到端训练

ResNet 的设计使得整个网络可以进行端到端的反向传播训练。通过残差块的这种结构,即使在网络非常深的情况下,反向传播算法也可以有效地更新所有层的权重,因为每一步的梯度计算都不会因为深度的增加而显著衰减。

2. BN层(Batch Normalization)

ResNet 在卷积层之后使用了批量归一化(Batch Normalization,BN)层,这有助于稳定和加速训练过程。BN 层可以减少内部协变量移位,使得网络更容易训练。
Batch Normalization(BN,批量归一化)是深度学习中一项重要的技术,它在训练深度神经网络时能够显著加速收敛速度,并提高模型的泛化能力。BN 层的主要原理和作用如下:

2.1 原理
  1. 归一化输入:在每一层的前向传播过程中,BN 层对输入的每个 mini-batch 进行归一化,使得该批次数据的分布具有零均值和单位方差。具体来说,对于一个mini-batch 中的每个特征维度,BN 计算该批次的均值和方差,然后使用这些统计量将特征值归一化。

  2. 缩放和平移:BN 层在归一化后,使用可学习的参数 γ(缩放因子)和 β(偏移量)对数据进行缩放和平移,以恢复或调整数据的分布。这两个参数在训练过程中会被优化,以适应网络的学习需求。

2.2 作用
  1. 加速训练:BN 层通过减少内部协变量移位(Internal Covariate Shift),即减轻了网络训练过程中输入分布的变化,从而加速了训练速度。这是因为每一层的输入分布保持相对稳定,网络不需要花费额外的精力去适应不断变化的输入分布。

  2. 允许使用较高学习率:由于 BN 层能够稳定训练过程,因此可以使用较高的学习率,这进一步加快了收敛速度。

  3. 提高泛化能力:BN 层具有轻微的正则化效果,这有助于提高模型的泛化能力,减少过拟合。

  4. 简化初始化:BN 层的存在降低了对网络权重初始化的要求,因为 BN 层能够自动调整输入的分布,使得权重的初始化不再那么敏感。

  5. 替代Dropout:在某些情况下,BN 层的效果可以替代 Dropout 层,因为 BN 层本身就有一定的正则化作用,从而可以减少或消除 Dropout 的使用。

  6. 打乱样本训练顺序:BN 层可以视为在一定程度上打乱了样本的训练顺序,因为每个 mini-batch 的样本被混合在一起进行归一化,这有助于提高模型的准确性。

2.3 BN 层计算过程

BN 层需要计算一个 minibatch input feature( x i x_i xi)中所有元素的均值 μ μ μ 和方差 σ σ σ,然后对 x i x_i xi 减去均值除以标准差,最后利用可学习参数 γ γ γ β β β 进行仿射变换,即可得到最终的 BN 输出 y i y_i yi
μ B ← 1 m ∑ i = 1 m x i σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 x ^ i ← x i − μ B σ B 2 + ϵ y i ← γ x ^ i + β ≡ B N γ , β ( x i ) \begin{array}{l} \mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^{m} x_{i} \\ \sigma_{\mathcal{B}}^{2} \leftarrow \frac{1}{m} \sum_{i=1}^{m}\left(x_{i}-\mu_{\mathcal{B}}\right)^{2} \\ \widehat{x}_{i} \leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}} \\ y_{i} \leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{BN}_{\gamma, \beta}\left(x_{i}\right) \end{array} μBm1i=1mxiσB2m1i=1m(xiμB)2x iσB2+ϵ xiμByiγx i+βBNγ,β(xi)
具体过程为:

  • 1.计算样本均值。
  • 2.计算样本方差。
  • 3.样本数据标准化处理。
  • 4.进行平移和缩放处理。

引入了 γ \gamma γ β \beta β 两个参数。来训练 γ γ γ β β β 两个参数。引入了这个可学习重构参数 γ \gamma γ β \beta β,让我们的网络可以学习恢复出原始网络所要学习的特征分布。需要注意的是,既然 BN 是在 batch 维度上做 normalization,那么 BN 涉及的几个变量的维度就应该是一个 feature map的size,比如 BN layer 的输入是 N × C × W × H,那么 BN 的参数维度就应该是 C × W × H,也就是说 BN 计算的是 batch 内 feature map 每个点的均值,方差,以及 γ \gamma γ β \beta β

三. ResNet 神经网络结构

ResNet 各配置神经网络结构如下,接下来以 ResNet18 为例,简要介绍网络结构。
在这里插入图片描述
在这里插入图片描述
ResNet18 是一种轻量级的残差神经网络,它拥有 18 个可训练的卷积层,加上输入层和输出层,一共构成 18 层的深度。ResNet18 是在计算机视觉任务中非常流行的一种模型,尤其是在资源有限的情况下,它能够提供较好的性能。下面是 ResNet18的详细结构:

输入层

  • 卷积层:ResNet18 开始于一个 7x7 的卷积层,步幅为 2,通常用于处理 3 通道的 RGB 图像输入。此层输出的特征图大小为输入图像的一半。

最大池化层

  • Max Pooling:紧接着卷积层的是一个 3x3 的最大池化层,步幅同样为 2,这进一步减小了特征图的尺寸。

主体部分

ResNet18 的主体部分由四个阶段组成,每个阶段包含一组残差块。这些阶段通常称为layer1、layer2、layer3和layer4,每个阶段的残差块数量和卷积核的数量有所不同。

Layer1
  • 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 64。在第一个残差块中,输入和输出的维度相同,因此可以使用简单的捷径连接(Identity Shortcut)。在第二个残差块中,输入和输出的维度相同,故同样使用 Identity Shortcut。
Layer2
  • 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 128。在第一个残差块中,因为要改变特征图的尺寸和通道数,所以捷径连接会使用 1x1 的卷积层来匹配输出的维度。第二个残差块使用 Identity Shortcut。
Layer3
  • 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 256。类似于 Layer2,第一个残差块会使用 1x1 的卷积层来调整捷径连接的维度,而第二个残差块则使用 Identity Shortcut。
Layer4
  • 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 512。同样地,第一个残差块需要使用 1x1 的卷积层来匹配捷径连接的输出维度,而第二个残差块使用 Identity Shortcut。

输出层

  • 全局平均池化:在最后一个残差块之后,使用全局平均池化(Global Average Pooling,GAP)层将特征图转换为固定长度的向量,通常为 512 维。

  • 全连接层:GAP 层的输出被馈送到一个全连接层,其节点数等于分类任务的类别数。例如,对于 ImageNet 数据集,该层的输出节点数为 1000。

  • Softmax层:全连接层的输出被送入 Softmax 层,产生每个类别的概率预测。

四. ResNet 模型亮点

ResNet(残差神经网络)是深度学习领域的一项重要创新,其设计亮点和贡献主要包括以下几个方面:

  1. 残差学习
    ResNet 最核心的创新是引入了残差学习框架,它允许网络学习残差函数而非完整的输入到输出的映射。这通过添加捷径连接(shortcut connections)实现,使得网络能够直接将输入传递到几层之后,与中间层的输出相加。这种设计使得网络可以训练更深的层次,而不会遇到梯度消失或梯度爆炸的问题。

  2. 深度可扩展性
    通过残差块的设计,ResNet 能够有效地扩展至非常深的层次,例如 ResNet-152 有超过 150 层,这是在 ResNet 提出之前难以实现的。这表明即使在极深的网络中,ResNet 也能保持良好的性能和训练稳定性。

  3. 批规范化(Batch Normalization)
    ResNet 在卷积层之后使用了批规范化(Batch Normalization,BN),这有助于加速训练过程并减少内部协变量移位。BN 层在 ResNet 中扮演了关键角色,它确保了每一层的输入分布保持稳定,从而简化了训练过程。

  4. 避免过拟合
    ResNet 通过残差块的简洁设计,实际上减少了模型的复杂度,因为网络不需要学习复杂的输入输出映射,而是专注于学习输入和输出之间的差异。这有助于减少过拟合的风险。

  5. 通用性和灵活性
    ResNet 的模块化设计使得它非常灵活,容易扩展和适应不同的任务和数据集。残差块可以轻松堆叠,以创建不同深度的网络,同时还可以通过改变卷积核的大小和数量来调整网络的宽度。

  6. 性能提升
    ResNet 在多项计算机视觉任务上表现出色,特别是在大规模数据集如 ImageNet 上的图像分类任务中,它达到了当时最高的准确率,同时也展示了良好的泛化能力。

  7. 启发后续研究
    ResNet 的成功启发了后续许多深度学习模型的创新,包括更深的网络结构、更高效的残差块变体以及其他类型的捷径连接,如密集连接(DenseNet)和金字塔网络(PyramidNet)等。

总之,ResNet 通过其独特的残差学习框架和模块化设计,不仅解决了深度神经网络训练的难题,还极大地推动了深度学习领域的发展,成为深度学习研究中的一个标志性成果。

五. ResNet 代码实现

开发环境配置说明:本项目使用 Python 3.6.13 和 PyTorch 1.10.2 构建,适用于CPU环境。

  • model.py:定义网络模型
  • train.py:加载数据集并训练,计算 loss 和 accuracy,保存训练好的网络参数
  • predict.py:用自己的数据集进行分类测试
  1. model.py
import torch.nn as nn
import torch

#   定义18层网络和34层网络的残差结构
class BasicBlock(nn.Module):
    #   expansion对应残差结构中,主分支的卷积核数有没有发生变化
    #   18层和34层的网络没有变化,50层、101层和152层的网络发生变化
    expansion = 1

    #   downsample下采样参数,用于残差分支的尺寸维度缩放
    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        #   BN层
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        #   下采样方法
        self.downsample = downsample

    def forward(self, x):
        #   分支线上的输出
        #   将x赋值给identity,捷径上不执行下采样的输出值
        identity = x
        #   判断downsample=None,对捷径执行下采样操作并输出
        if self.downsample is not None:
            identity = self.downsample(x)

        #   主支线上的输出
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        #   主分支输出加上捷径分支输出
        out += identity
        out = self.relu(out)

        return out


#   定义50层网络、101层网络和152层网络的残差结构
class Bottleneck(nn.Module):
    #   expansion对应残差结构中,主分支的卷积核数有没有发生变化
    #   50层、101层和152层的网络发生变化,其残差结构中第三层的卷积核个数为前两层的四倍,例如64—64—256
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.dowmsample = downsample

    def forward(self, x):
        identity = x
        if self.dowmsample is not None:
            identity = self.dowmsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

#   定义ResNet网络
class ResNet(nn.Module):
    #   block:残差结构 block_num(list):残差结构数量 include_top=True:方便在ResNet上搭建其他网络
    def __init__(self, block, block_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, block_num[0])
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)

        #   输出层+全连接层
        if self.include_top:
            self.avepool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        #   对卷积层初始化
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    #   定义生成残差结构的方法
    #   block:残差结构 channel:第一层卷积核的个数 block_num:残差结构数量
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        #   判断通道数是否发生变化,来执行下采样操作
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        #   添加第一层残差结构
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        #   根据expansion来生成实线和虚线的残差结构
        self.in_channel = channel * block.expansion

        #   残差结构中除了第一层均为实线结构,将其依次添加到layers中
        for _ in range(1, block_num):  #    从1开始,即实线残差结构从第二层开始
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    #   正向传播过程
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avepool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x

def resnet18(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet50(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

def resnet152(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, include_top=include_top)
  1. train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from model import resnet34
import os
import json
import torchvision.models.resnet


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)

data_transform = {
    "train" : transforms.Compose([transforms.RandomResizedCrop(224),   # 随机裁剪
                                  transforms.RandomHorizontalFlip(),   # 随机翻转
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val" : transforms.Compose([transforms.Resize(256),      # 长宽比不变,最小边长缩放到256
                                transforms.CenterCrop(224),  # 中心裁剪到 224x224
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

#   获取数据集所在的根目录
#   通过os.getcwd()获取当前的目录,并将当前目录与".."链接获取上一层目录
data_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

#   获取花类数据集路径
image_path = data_root + "/data_set/flower_data/"

#   加载数据集
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])

#   获取训练集图像数量
train_num = len(train_dataset)

#   获取分类的名称
#   {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx

#   采用遍历方法,将分类名称的key与value反过来
cla_dict = dict((val, key) for key, val in flower_list.items())

#   将字典cla_dict编码为json格式
json_str = json.dumps(cla_dict, indent=4)
with open("class_indices.json", "w") as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = DataLoader(validate_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=0)

#   定义模型
net = resnet34()   # 实例化模型
net.to(device)
model_weight_path = "./resnet34-pre.pth"
#   载入模型权重
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)

#   定义输入特征矩阵的深度
inchannel = net.fc.in_features
#   重新赋值全连接层
net.fc = nn.Linear(inchannel, 5)

loss_function = nn.CrossEntropyLoss()   # 定义损失函数
#pata = list(net.parameters())   # 查看模型参数
optimizer = optim.Adam(net.parameters(), lr=0.0001)  # 定义优化器

#   设置存储权重路径
save_path = './resNet34.pth'
best_acc = 0.0
for epoch in range(1):
    # train
    net.train()  # 用来管理Dropout方法:训练时使用Dropout方法,验证时不使用Dropout方法
    running_loss = 0.0  # 用来累加训练中的损失
    for step, data in enumerate(train_loader, start=0):
        #   获取数据的图像和标签
        images, labels = data

        #   将历史损失梯度清零
        optimizer.zero_grad()

        #   参数更新
        outputs = net(images.to(device))                   # 获得网络输出
        loss = loss_function(outputs, labels.to(device))   # 计算loss
        loss.backward()                                    # 误差反向传播
        optimizer.step()                                   # 更新节点参数

        #   打印统计信息
        running_loss += loss.item()
        #   打印训练进度
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()

    # validate
    net.eval()  # 关闭Dropout方法
    acc = 0.0
    #   验证过程中不计算损失梯度
    with torch.no_grad():
        for data_test in validate_loader:
            test_images, test_labels = data_test
            outputs = net(test_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            #   acc用来累计验证集中预测正确的数量
            #   对比预测值与真实标签,sum()求出预测正确的累加值,item()获取累加值
            acc += (predict_y == test_labels.to(device)).sum().item()
        accurate_test = acc / val_num
        #   如果当前准确率大于历史最优准确率
        if accurate_test > best_acc:
            #   更新历史最优准确率
            best_acc = accurate_test
            #   保存当前权重
            torch.save(net.state_dict(), save_path)
        #   打印相应信息
        print("[epoch %d] train_loss: %.3f  test_accuracy: %.3f"%
              (epoch + 1, running_loss / step, acc / val_num))

print("Finished Training")
  1. predict.py
import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.RandomResizedCrop(224),   # 随机裁剪
     transforms.RandomHorizontalFlip(),   # 随机翻转
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

#   载入图像
img = Image.open("./郁金香.png")

#   [N, C, H, W]
#   图像预处理
img = data_transform(img)

#   增加 batch 维度
img = torch.unsqueeze(img, dim=0)

# 读取 class_indict
try:
    json_file = open("./class_indices.json", "r")
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

#   创建模型
model = resnet34(num_classes=5)
#   加载权重
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
#   屏蔽Dropout
model.eval()

with torch.no_grad():
    #   model(img)将图像输入模型得到输出,采用squeeze压缩维度,即将Batch维度压缩掉
    output = torch.squeeze(model(img))
    #   采用softmax将最终输出转化为概率分布
    predict = torch.softmax(output, dim=0)
    #   获取概率最大处的索引值
    predict_cla = torch.argmax(predict).numpy()
#   打印类别名称及其对应的预测概率
print(class_indict[str(predict_cla)], predict[predict_cla].item())

六. 参考内容

  1. 李沐. (2019). 动手学深度学习. 北京: 人民邮电出版社. [ISBN: 978-7-115-51364-9]
  2. 霹雳吧啦Wz. (202X). 深度学习实战系列 [在线视频]. 哔哩哔哩. URL
  3. PyTorch. (n.d.). PyTorch官方文档和案例 [在线资源]. URL
;