一. ResNet 神经网络介绍
ResNet,即残差神经网络(Residual Neural Network),是由何凯明(Kaiming He)、张祥雨、任少卿和孙剑在 2015 年提出的一种深度学习架构,旨在解决深度神经网络中的退化问题。在传统的深度神经网络中,随着网络层数的增加,训练误差会逐渐增大,这种现象被称为“退化”。ResNet 通过引入一种特殊的残差学习框架来解决这一问题,允许网络学习更深层次的结构,而不遭受退化问题的影响。
二. 概念拓展
1. 残差块
ResNet 的核心创新是残差块(Residual Block)。一个标准的残差块包含两个或更多的卷积层,以及一个从输入直接连接到块输出的“捷径连接”(Shortcut Connection)。这个捷径连接允许输入信号直接传递到块的输出端,与经过卷积层处理后的信号相加。这样,残差块的目标就变成了学习一个残差函数,即输入到输出之间的差异,而不是整个映射函数。
在ResNet中,残差块是构建整个网络的基础单元。残差块主要由两部分构成:一个或多个卷积层和一个捷径连接(直连边,Shortcut Connection)。下面我将详细解释这两种类型的残差块以及它们的工作原理:
1.1 标准残差块(见图左)
在标准残差块中,输入 x x x 通过一系列卷积层(通常是两个3x3的卷积层)进行特征变换,得到输出 F ( x ) F(x) F(x)。与此同时,输入 x x x 直接通过直连边传递,然后与 F ( x ) F(x) F(x) 相加得到最终的输出 y y y。即: y = F ( x , W i ) + x y = F(x, W_i) + x y=F(x,Wi)+x这里的 W i W_i Wi 是残差块中卷积层的权重参数。这种结构只有在输入和输出的维度相同的情况下才能成立,因为只有维度相同,输入和输出才能直接相加。
1.2 降采样残差块(见图右)
然而,在网络的某些层,我们可能需要减小特征图的尺寸或者改变通道数,这时输入和输出的维度就不一致了。为了解决这个问题,ResNet 引入了降采样残差块,它在直连边上增加了一个 1x1 的卷积层,用于调整输入的维度,使之与输出维度匹配,以便进行相加操作。降采样残差块的公式可以表示为: y = F ( x , W i ) + W s ( x ) y = F(x, W_i) + W_s(x) y=F(x,Wi)+Ws(x)这里的 W s W_s Ws 表示 1x1 卷积层的权重参数,它负责对输入 x x x 进行降维或升维,以匹配 F ( x , W i ) F(x, W_i) F(x,Wi) 的维度。
1.3 直连边的作用
直连边在残差块中的作用是至关重要的,它保证了即使当 F ( x ) F(x) F(x) 接近零时,网络仍然能学习到恒等映射,即 y ≈ x y \approx x y≈x。这样,即使网络变得非常深,也不至于出现梯度消失或梯度爆炸的问题,从而能够训练出更深的网络。
1.4 端到端训练
ResNet 的设计使得整个网络可以进行端到端的反向传播训练。通过残差块的这种结构,即使在网络非常深的情况下,反向传播算法也可以有效地更新所有层的权重,因为每一步的梯度计算都不会因为深度的增加而显著衰减。
2. BN层(Batch Normalization)
ResNet 在卷积层之后使用了批量归一化(Batch Normalization,BN)层,这有助于稳定和加速训练过程。BN 层可以减少内部协变量移位,使得网络更容易训练。
Batch Normalization(BN,批量归一化)是深度学习中一项重要的技术,它在训练深度神经网络时能够显著加速收敛速度,并提高模型的泛化能力。BN 层的主要原理和作用如下:
2.1 原理
-
归一化输入:在每一层的前向传播过程中,BN 层对输入的每个 mini-batch 进行归一化,使得该批次数据的分布具有零均值和单位方差。具体来说,对于一个mini-batch 中的每个特征维度,BN 计算该批次的均值和方差,然后使用这些统计量将特征值归一化。
-
缩放和平移:BN 层在归一化后,使用可学习的参数 γ(缩放因子)和 β(偏移量)对数据进行缩放和平移,以恢复或调整数据的分布。这两个参数在训练过程中会被优化,以适应网络的学习需求。
2.2 作用
-
加速训练:BN 层通过减少内部协变量移位(Internal Covariate Shift),即减轻了网络训练过程中输入分布的变化,从而加速了训练速度。这是因为每一层的输入分布保持相对稳定,网络不需要花费额外的精力去适应不断变化的输入分布。
-
允许使用较高学习率:由于 BN 层能够稳定训练过程,因此可以使用较高的学习率,这进一步加快了收敛速度。
-
提高泛化能力:BN 层具有轻微的正则化效果,这有助于提高模型的泛化能力,减少过拟合。
-
简化初始化:BN 层的存在降低了对网络权重初始化的要求,因为 BN 层能够自动调整输入的分布,使得权重的初始化不再那么敏感。
-
替代Dropout:在某些情况下,BN 层的效果可以替代 Dropout 层,因为 BN 层本身就有一定的正则化作用,从而可以减少或消除 Dropout 的使用。
-
打乱样本训练顺序:BN 层可以视为在一定程度上打乱了样本的训练顺序,因为每个 mini-batch 的样本被混合在一起进行归一化,这有助于提高模型的准确性。
2.3 BN 层计算过程
BN 层需要计算一个 minibatch input feature(
x
i
x_i
xi)中所有元素的均值
μ
μ
μ 和方差
σ
σ
σ,然后对
x
i
x_i
xi 减去均值除以标准差,最后利用可学习参数
γ
γ
γ 和
β
β
β 进行仿射变换,即可得到最终的 BN 输出
y
i
y_i
yi
μ
B
←
1
m
∑
i
=
1
m
x
i
σ
B
2
←
1
m
∑
i
=
1
m
(
x
i
−
μ
B
)
2
x
^
i
←
x
i
−
μ
B
σ
B
2
+
ϵ
y
i
←
γ
x
^
i
+
β
≡
B
N
γ
,
β
(
x
i
)
\begin{array}{l} \mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^{m} x_{i} \\ \sigma_{\mathcal{B}}^{2} \leftarrow \frac{1}{m} \sum_{i=1}^{m}\left(x_{i}-\mu_{\mathcal{B}}\right)^{2} \\ \widehat{x}_{i} \leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}} \\ y_{i} \leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{BN}_{\gamma, \beta}\left(x_{i}\right) \end{array}
μB←m1∑i=1mxiσB2←m1∑i=1m(xi−μB)2x
i←σB2+ϵxi−μByi←γx
i+β≡BNγ,β(xi)
具体过程为:
- 1.计算样本均值。
- 2.计算样本方差。
- 3.样本数据标准化处理。
- 4.进行平移和缩放处理。
引入了 γ \gamma γ 和 β \beta β 两个参数。来训练 γ γ γ 和 β β β 两个参数。引入了这个可学习重构参数 γ \gamma γ 和 β \beta β,让我们的网络可以学习恢复出原始网络所要学习的特征分布。需要注意的是,既然 BN 是在 batch 维度上做 normalization,那么 BN 涉及的几个变量的维度就应该是一个 feature map的size,比如 BN layer 的输入是 N × C × W × H,那么 BN 的参数维度就应该是 C × W × H,也就是说 BN 计算的是 batch 内 feature map 每个点的均值,方差,以及 γ \gamma γ 和 β \beta β。
三. ResNet 神经网络结构
ResNet 各配置神经网络结构如下,接下来以 ResNet18 为例,简要介绍网络结构。
ResNet18 是一种轻量级的残差神经网络,它拥有 18 个可训练的卷积层,加上输入层和输出层,一共构成 18 层的深度。ResNet18 是在计算机视觉任务中非常流行的一种模型,尤其是在资源有限的情况下,它能够提供较好的性能。下面是 ResNet18的详细结构:
输入层
- 卷积层:ResNet18 开始于一个 7x7 的卷积层,步幅为 2,通常用于处理 3 通道的 RGB 图像输入。此层输出的特征图大小为输入图像的一半。
最大池化层
- Max Pooling:紧接着卷积层的是一个 3x3 的最大池化层,步幅同样为 2,这进一步减小了特征图的尺寸。
主体部分
ResNet18 的主体部分由四个阶段组成,每个阶段包含一组残差块。这些阶段通常称为layer1、layer2、layer3和layer4,每个阶段的残差块数量和卷积核的数量有所不同。
Layer1
- 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 64。在第一个残差块中,输入和输出的维度相同,因此可以使用简单的捷径连接(Identity Shortcut)。在第二个残差块中,输入和输出的维度相同,故同样使用 Identity Shortcut。
Layer2
- 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 128。在第一个残差块中,因为要改变特征图的尺寸和通道数,所以捷径连接会使用 1x1 的卷积层来匹配输出的维度。第二个残差块使用 Identity Shortcut。
Layer3
- 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 256。类似于 Layer2,第一个残差块会使用 1x1 的卷积层来调整捷径连接的维度,而第二个残差块则使用 Identity Shortcut。
Layer4
- 两个残差块:每个残差块包含两个 3x3 的卷积层,输出通道数为 512。同样地,第一个残差块需要使用 1x1 的卷积层来匹配捷径连接的输出维度,而第二个残差块使用 Identity Shortcut。
输出层
-
全局平均池化:在最后一个残差块之后,使用全局平均池化(Global Average Pooling,GAP)层将特征图转换为固定长度的向量,通常为 512 维。
-
全连接层:GAP 层的输出被馈送到一个全连接层,其节点数等于分类任务的类别数。例如,对于 ImageNet 数据集,该层的输出节点数为 1000。
-
Softmax层:全连接层的输出被送入 Softmax 层,产生每个类别的概率预测。
四. ResNet 模型亮点
ResNet(残差神经网络)是深度学习领域的一项重要创新,其设计亮点和贡献主要包括以下几个方面:
-
残差学习:
ResNet 最核心的创新是引入了残差学习框架,它允许网络学习残差函数而非完整的输入到输出的映射。这通过添加捷径连接(shortcut connections)实现,使得网络能够直接将输入传递到几层之后,与中间层的输出相加。这种设计使得网络可以训练更深的层次,而不会遇到梯度消失或梯度爆炸的问题。 -
深度可扩展性:
通过残差块的设计,ResNet 能够有效地扩展至非常深的层次,例如 ResNet-152 有超过 150 层,这是在 ResNet 提出之前难以实现的。这表明即使在极深的网络中,ResNet 也能保持良好的性能和训练稳定性。 -
批规范化(Batch Normalization):
ResNet 在卷积层之后使用了批规范化(Batch Normalization,BN),这有助于加速训练过程并减少内部协变量移位。BN 层在 ResNet 中扮演了关键角色,它确保了每一层的输入分布保持稳定,从而简化了训练过程。 -
避免过拟合:
ResNet 通过残差块的简洁设计,实际上减少了模型的复杂度,因为网络不需要学习复杂的输入输出映射,而是专注于学习输入和输出之间的差异。这有助于减少过拟合的风险。 -
通用性和灵活性:
ResNet 的模块化设计使得它非常灵活,容易扩展和适应不同的任务和数据集。残差块可以轻松堆叠,以创建不同深度的网络,同时还可以通过改变卷积核的大小和数量来调整网络的宽度。 -
性能提升:
ResNet 在多项计算机视觉任务上表现出色,特别是在大规模数据集如 ImageNet 上的图像分类任务中,它达到了当时最高的准确率,同时也展示了良好的泛化能力。 -
启发后续研究:
ResNet 的成功启发了后续许多深度学习模型的创新,包括更深的网络结构、更高效的残差块变体以及其他类型的捷径连接,如密集连接(DenseNet)和金字塔网络(PyramidNet)等。
总之,ResNet 通过其独特的残差学习框架和模块化设计,不仅解决了深度神经网络训练的难题,还极大地推动了深度学习领域的发展,成为深度学习研究中的一个标志性成果。
五. ResNet 代码实现
开发环境配置说明:本项目使用 Python 3.6.13 和 PyTorch 1.10.2 构建,适用于CPU环境。
- model.py:定义网络模型
- train.py:加载数据集并训练,计算 loss 和 accuracy,保存训练好的网络参数
- predict.py:用自己的数据集进行分类测试
- model.py
import torch.nn as nn
import torch
# 定义18层网络和34层网络的残差结构
class BasicBlock(nn.Module):
# expansion对应残差结构中,主分支的卷积核数有没有发生变化
# 18层和34层的网络没有变化,50层、101层和152层的网络发生变化
expansion = 1
# downsample下采样参数,用于残差分支的尺寸维度缩放
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False)
# BN层
self.bn1 = nn.BatchNorm2d(out_channel)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channel)
# 下采样方法
self.downsample = downsample
def forward(self, x):
# 分支线上的输出
# 将x赋值给identity,捷径上不执行下采样的输出值
identity = x
# 判断downsample=None,对捷径执行下采样操作并输出
if self.downsample is not None:
identity = self.downsample(x)
# 主支线上的输出
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
# 主分支输出加上捷径分支输出
out += identity
out = self.relu(out)
return out
# 定义50层网络、101层网络和152层网络的残差结构
class Bottleneck(nn.Module):
# expansion对应残差结构中,主分支的卷积核数有没有发生变化
# 50层、101层和152层的网络发生变化,其残差结构中第三层的卷积核个数为前两层的四倍,例如64—64—256
expansion = 4
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=1, stride=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channel)
# -----------------------------------------
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channel)
# -----------------------------------------
self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
kernel_size=1, stride=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
self.relu = nn.ReLU(inplace=True)
self.dowmsample = downsample
def forward(self, x):
identity = x
if self.dowmsample is not None:
identity = self.dowmsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += identity
out = self.relu(out)
return out
# 定义ResNet网络
class ResNet(nn.Module):
# block:残差结构 block_num(list):残差结构数量 include_top=True:方便在ResNet上搭建其他网络
def __init__(self, block, block_num, num_classes=1000, include_top=True):
super(ResNet, self).__init__()
self.include_top = include_top
self.in_channel = 64
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.in_channel)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, block_num[0])
self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
# 输出层+全连接层
if self.include_top:
self.avepool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
# 对卷积层初始化
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# 定义生成残差结构的方法
# block:残差结构 channel:第一层卷积核的个数 block_num:残差结构数量
def _make_layer(self, block, channel, block_num, stride=1):
downsample = None
# 判断通道数是否发生变化,来执行下采样操作
if stride != 1 or self.in_channel != channel * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(channel * block.expansion))
layers = []
# 添加第一层残差结构
layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
# 根据expansion来生成实线和虚线的残差结构
self.in_channel = channel * block.expansion
# 残差结构中除了第一层均为实线结构,将其依次添加到layers中
for _ in range(1, block_num): # 从1开始,即实线残差结构从第二层开始
layers.append(block(self.in_channel, channel))
return nn.Sequential(*layers)
# 正向传播过程
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.include_top:
x = self.avepool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def resnet18(num_classes=1000, include_top=True):
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
def resnet34(num_classes=1000, include_top=True):
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
def resnet50(num_classes=1000, include_top=True):
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
def resnet101(num_classes=1000, include_top=True):
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
def resnet152(num_classes=1000, include_top=True):
return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, include_top=include_top)
- train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from model import resnet34
import os
import json
import torchvision.models.resnet
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)
data_transform = {
"train" : transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机翻转
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val" : transforms.Compose([transforms.Resize(256), # 长宽比不变,最小边长缩放到256
transforms.CenterCrop(224), # 中心裁剪到 224x224
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
# 获取数据集所在的根目录
# 通过os.getcwd()获取当前的目录,并将当前目录与".."链接获取上一层目录
data_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# 获取花类数据集路径
image_path = data_root + "/data_set/flower_data/"
# 加载数据集
train_dataset = datasets.ImageFolder(root=image_path + "/train",
transform=data_transform["train"])
# 获取训练集图像数量
train_num = len(train_dataset)
# 获取分类的名称
# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
# 采用遍历方法,将分类名称的key与value反过来
cla_dict = dict((val, key) for key, val in flower_list.items())
# 将字典cla_dict编码为json格式
json_str = json.dumps(cla_dict, indent=4)
with open("class_indices.json", "w") as json_file:
json_file.write(json_str)
batch_size = 16
train_loader = DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = DataLoader(validate_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0)
# 定义模型
net = resnet34() # 实例化模型
net.to(device)
model_weight_path = "./resnet34-pre.pth"
# 载入模型权重
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)
# 定义输入特征矩阵的深度
inchannel = net.fc.in_features
# 重新赋值全连接层
net.fc = nn.Linear(inchannel, 5)
loss_function = nn.CrossEntropyLoss() # 定义损失函数
#pata = list(net.parameters()) # 查看模型参数
optimizer = optim.Adam(net.parameters(), lr=0.0001) # 定义优化器
# 设置存储权重路径
save_path = './resNet34.pth'
best_acc = 0.0
for epoch in range(1):
# train
net.train() # 用来管理Dropout方法:训练时使用Dropout方法,验证时不使用Dropout方法
running_loss = 0.0 # 用来累加训练中的损失
for step, data in enumerate(train_loader, start=0):
# 获取数据的图像和标签
images, labels = data
# 将历史损失梯度清零
optimizer.zero_grad()
# 参数更新
outputs = net(images.to(device)) # 获得网络输出
loss = loss_function(outputs, labels.to(device)) # 计算loss
loss.backward() # 误差反向传播
optimizer.step() # 更新节点参数
# 打印统计信息
running_loss += loss.item()
# 打印训练进度
rate = (step + 1) / len(train_loader)
a = "*" * int(rate * 50)
b = "." * int((1 - rate) * 50)
print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
print()
# validate
net.eval() # 关闭Dropout方法
acc = 0.0
# 验证过程中不计算损失梯度
with torch.no_grad():
for data_test in validate_loader:
test_images, test_labels = data_test
outputs = net(test_images.to(device))
predict_y = torch.max(outputs, dim=1)[1]
# acc用来累计验证集中预测正确的数量
# 对比预测值与真实标签,sum()求出预测正确的累加值,item()获取累加值
acc += (predict_y == test_labels.to(device)).sum().item()
accurate_test = acc / val_num
# 如果当前准确率大于历史最优准确率
if accurate_test > best_acc:
# 更新历史最优准确率
best_acc = accurate_test
# 保存当前权重
torch.save(net.state_dict(), save_path)
# 打印相应信息
print("[epoch %d] train_loss: %.3f test_accuracy: %.3f"%
(epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")
- predict.py
import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
data_transform = transforms.Compose(
[transforms.RandomResizedCrop(224), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机翻转
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 载入图像
img = Image.open("./郁金香.png")
# [N, C, H, W]
# 图像预处理
img = data_transform(img)
# 增加 batch 维度
img = torch.unsqueeze(img, dim=0)
# 读取 class_indict
try:
json_file = open("./class_indices.json", "r")
class_indict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# 创建模型
model = resnet34(num_classes=5)
# 加载权重
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
# 屏蔽Dropout
model.eval()
with torch.no_grad():
# model(img)将图像输入模型得到输出,采用squeeze压缩维度,即将Batch维度压缩掉
output = torch.squeeze(model(img))
# 采用softmax将最终输出转化为概率分布
predict = torch.softmax(output, dim=0)
# 获取概率最大处的索引值
predict_cla = torch.argmax(predict).numpy()
# 打印类别名称及其对应的预测概率
print(class_indict[str(predict_cla)], predict[predict_cla].item())