一、前言
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊|接辅导、项目定制
- 难度:夯实基础⭐⭐
- 语言:Python3、Pytorch3
- 🍺要求:
1.根据本文的Tensorflow代码,编写Pytorch代码
2.了解残差网络
3.是否可以将残差模块融合到C3中
二、论文分析
论文:Deep Residual Learning for Image Recognition
问题的提出:
随着网络层数的增加,更深的网络具有更大的训练误差,从而导致测试误差。
所以提出了一个问题:对叠层数越多是不是训练网络效果越好呢?
这种问题的阻碍是梯度消失或者爆炸,而这种我们的解决办法是:初始化归一和中间层归一化
随着网络深度的增加,精度变得饱和,然后迅速退化,但是这种退化不是由于过度拟合引起的,这也就成为了模型训练退化问题。像适当深度的模型添加更多层会导致更高的训练误差。解决这种误差是这篇论文的主要目的。
解决方案一:添加的层是身份映射,其他层是从学习中较浅的模型复制,但是现有的解释器很难做
解决方案二:引入深度残差学习框架来解决这种退化问题。
将所需的基础映射表示为H(x)
让堆叠的非线性层适合F(x):= H(x)- x的另一个映射。
原始映射为F(x)+ x。
通过快捷连接来实现身份验证。
实验证明:
1)极深的残差网络易于优化,但是当深度增加时,对应的“普通”网络(简单地堆叠层)显示出更高的训练误差;
2)深层残差网络可以通过大大增加深度来轻松享受准确性的提高,所产生的结果比以前的网络要好得多。
Deep Residual Learning
残差学习:
将H(x)视为由一些堆叠层(不一定是整个网络)拟合的基础映射,其中x表示这些层中第一层的输入。如果假设多个非线性层可以渐近逼近复杂函数,那么就可以假设它们可以渐近逼近残差函数,即H(x)-x(假设输入和输出为尺寸相同)。因此,没有让堆叠的层近似为H(x),而是明确地让这些层近似为残差函数F(x):= H(x)-x。因此,原始函数变为F(x)+ x。尽管两种形式都应能够渐近地逼近所需的功能(如假设),但学习的难易程度可能有所不同。
简单来讲:
整个模块除了正常的卷积层输出外,还有一个分支把输入直接连在输出上,该分支输出和卷积的输出做算数相加得到了最终的输出,这种残差结构人为的制造了恒等映射,即F(x)分支中所有参数都是0,H(x)就是一个恒等映射,这样就能让整个结构朝着恒等映射的方向去收敛,确保最终的错误率不会因为深度的变大而越来越差。
假设我们现在已经有了一个N层的网络,现在在尾部加上K个残差模块(M层),
如果说这K个残差会造成网络过深,那么这K个残差模块会向恒等映射方向发展(参数为0),进而解决了网络过深问题
网络框架:
实验结果
可以明显看到在用ResNet之后,随着网络深度的增加,网络的训练效果更好。
三、残差网络(ResNet)介绍
1、残差网络解决了什么
残差网络是为了解决神经网络隐藏层过多时,而引起的网络退化问题。退化(degradation)问题是指:当网络隐藏层变多时,网络的准确度达到饱和然后急剧退化,而且这个退化不是由于过拟合引起的。
拓展:深度神经网络的"两朵乌云"
- 梯度弥散/爆炸
简单来讲就是网络太深了,会导致模型训练难以收敛。这个问题可以被标准初始化和中间层正规化的方法有效控制。
- 网络退化
随着网络深度增加,网络的表现先是逐渐增加至饱和,然后迅速下降,这个退化不是由过拟合引起的。
2、ResNet-50介绍
ResNet-50有两个基本的块,分别名为Conv Block和Identity Block
Conv Block结构:和Identity Block结构:
ResNet-50总体结构:
四、构造ResNet-50模型
1. 设置GPU
如果设备上支持GPU就使用GPU,否则使用CPU
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
import torch
device=torch.device("cuda" if torch.cuda.is_available() else "CPU")
device
运行结果:
device(type='cuda')
2. 导入数据
同时查看数据集中图片的数量
import pathlib
data_dir=r"D:\data\J-series\J1\bird_photos"
data_dir=pathlib.Path(data_dir)
image_count=len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 565
3. 查看数据集分类
data_paths=list(data_dir.glob('*'))
classeNames=[str(path).split("\\")[5] for path in data_paths]
classeNames
运行结果:
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
4. 随机查看图片
随机抽取数据集中的20张图片进行查看
import random,PIL
import matplotlib.pyplot as plt
from PIL import Image
data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(20,4))
for i in range(20):
plt.subplot(2,10,i+1)
plt.axis('off')
image=random.choice(data_paths2) #随机选择一个图片
plt.title(image.parts[-2]) #通过glob对象取出他的文件夹名称,即分类名
plt.imshow(Image.open(str(image))) #显示图片
运行结果:
5. 图片预处理
import torchvision.transforms as transforms
from torchvision import transforms,datasets
train_transforms=transforms.Compose([
transforms.Resize([224,224]), #将图片统一尺寸
transforms.RandomHorizontalFlip(), #将图片随机水平翻转
transforms.ToTensor(), #将图片转换为tensor
transforms.Normalize( #标准化处理—>转换为正态分布,使模型更容易收敛
mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]
)
])
test_transforms=transforms.Compose([
transforms.Resize([224,224]), #将图片统一尺寸
transforms.RandomHorizontalFlip(), #将图片随机水平翻转
transforms.ToTensor(), #将图片转换为tensor
transforms.Normalize( #标准化处理—>转换为正态分布,使模型更容易收敛
mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]
)
])
total_data=datasets.ImageFolder(
r"D:\THE MNIST DATABASE\J-series\J1\bird_photos",
transform=train_transforms
)
total_data
运行结果:
Dataset ImageFolder
Number of datapoints: 565
Root location: D:\THE MNIST DATABASE\J-series\J1\bird_photos
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
RandomHorizontalFlip(p=0.5)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
将数据集分类情况进行映射输出:
total_data.class_to_idx
运行结果:
{'Bananaquit': 0,'Black Skimmer': 1,'Black Throated Bushtiti': 2,'Cockatoo': 3}
6. 划分数据集
train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_dataset,test_dataset=torch.utils.data.random_split(
total_data,
[train_size,test_size]
)
train_dataset,test_dataset
运行结果:
(<torch.utils.data.dataset.Subset at 0x2195b60dd50>,<torch.utils.data.dataset.Subset at 0x219508d5910>)
查看训练集和测试集的数据数量:
train_size,test_size
运行结果:
(452, 113)
7. 加载数据集
batch_size=8
train_dl=torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1
)
test_dl=torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1
)
查看测试集的情况:
for x,y in train_dl:
print("Shape of x [N,C,H,W]:",x.shape)
print("Shape of y:",y.shape,y.dtype)
break
运行结果:
Shape of x [N,C,H,W]: torch.Size([8, 3, 224, 224])
Shape of y: torch.Size([8]) torch.int64
五、模型搭建
1、Tensorflow代码
def identity_block(input_ten,kernel_size,filters):
filters1,filters2,filters3 = filters
x = Conv2D(filters1,(1,1))(input_ten)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(filters2,kernel_size,padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(filters3,(1,1))(x)
x = BatchNormalization()(x)
x = layers.add([x,input_ten])
x = Activation('relu')(x)
return x
def conv_block(input_ten,kernel_size,filters,strides=(2,2)):
filters1,filters2,filters3 = filters
x = Conv2D(filters1,(1,1),strides=strides)(input_ten)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(filters2,kernel_size,padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(filters3,(1,1))(x)
x = BatchNormalization()(x)
shortcut = Conv2D(filters3,(1,1),strides=strides)(input_ten)
shortcut = BatchNormalization()(shortcut)
x = layers.add([x,shortcut])
x = Activation('relu')(x)
return x
def ResNet50(nb_class,input_shape):
input_ten = Input(shape=input_shape)
x = ZeroPadding2D((3,3))(input_ten)
x = Conv2D(64,(7,7),strides=(2,2))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling2D((3,3),strides=(2,2))(x)
x = conv_block(x,3,[64,64,256],strides=(1,1))
x = identity_block(x,3,[64,64,256])
x = identity_block(x,3,[64,64,256])
x = conv_block(x,3,[128,128,512])
x = identity_block(x,3,[128,128,512])
x = identity_block(x,3,[128,128,512])
x = identity_block(x,3,[128,128,512])
x = conv_block(x,3,[256,256,1024])
x = identity_block(x,3,[256,256,1024])
x = identity_block(x,3,[256,256,1024])
x = identity_block(x,3,[256,256,1024])
x = identity_block(x,3,[256,256,1024])
x = identity_block(x,3,[256,256,1024])
x = conv_block(x,3,[512,512,2048])
x = identity_block(x,3,[512,512,2048])
x = identity_block(x,3,[512,512,2048])
x = AveragePooling2D((7,7))(x)
x = tf.keras.layers.Flatten()(x)
output_ten = Dense(nb_class,activation='softmax')(x)
model = Model(input_ten,output_ten)
model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")
return model
model_ResNet50 = ResNet50(24,(img_height,img_width,3))
model_ResNet50.summary()
2、Pytorch代码
from torch import nn
class ConvBlock(nn.Module):
def __init__(self, in_channel, kernel_size, filters, stride):
super(ConvBlock, self).__init__()
filter1, filter2, filter3 = filters
self.stage = nn.Sequential(
nn.Conv2d(in_channel, filter1, 1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(filter1),
nn.RuLU(True),
nn.Conv2d(filter1, filter2, kernel_size, stride=1, padding=True, bias=False),
nn.BatchNorm2d(filter2),
nn.RuLU(True),
nn.Conv2d(filter2, filter3, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(filter3),
)
self.shortcut_1 = nn.Conv2d(in_channel, filter3, 1, stride=stride, padding=0, bias=False)
self.batch_1 = nn.BatchNorm2d(filter3)
self.relu_1 = nn.ReLU(True)
def forward(self, x):
x_shortcut = self.shortcut_1(x)
x_shortcut = self.batch_1(x_shortcut)
x = self.stage(x)
x = x + x_shortcut
x = self.relu_1(x)
return x
class IndentityBlock(nn.Module):
def __init__(self, in_channel, kernel_size, filters):
super(IndentityBlock, self).__init__()
filter1, filter2, filter3 = filters
self.stage = nn.Sequential(
nn.Conv2d(in_channel, filter1, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(filter1),
nn.RuLU(True),
nn.Conv2d(filter1, filter2, kernel_size, padding=True, bias=False),
nn.BatchNorm2d(filter1),
nn.RuLU(True),
nn.Conv2d(filter2, filter3, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(filter3),
)
self.relu_1=nn.ReLU(True)
def forward(self, x):
x_shortcut = x
x = self.stage(x)
x = x + x_shortcut
x = self.relu_1(x)
return x
class ResModel(nn.Module):
def __init__(self, n_class):
super(ResModel, self).__init__()
self.stage1 = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.MaxPool2d(3, 2, padding=1),
)
self.stage2 = nn.Sequential(
ConvBlock(64, f=3, filters=[64, 64, 256], s=2),
IndentityBlock(256, 3, [64, 64, 256]),
IndentityBlock(256, 3, [64, 64, 256]),
)
self.stage3 = nn.Sequential(
ConvBlock(256, f=3, filters=[128, 128, 512], s=3),
IndentityBlock(512, 3, [128, 128, 512]),
IndentityBlock(512, 3, [128, 128, 512]),
IndentityBlock(512, 3, [128, 128, 512]),
)
self.stage4 = nn.Sequential(
ConvBlock(512, f=3, filters=[256, 256, 1024], s=4),
IndentityBlock(1024, 3, [256, 256, 1024]),
IndentityBlock(1024, 3, [256, 256, 1024]),
IndentityBlock(1024, 3, [256, 256, 1024]),
IndentityBlock(1024, 3, [256, 256, 1024]),
IndentityBlock(1024, 3, [256, 256, 1024]),
)
self.stage5 = nn.Sequential(
ConvBlock(1024, f=3, filters=[512, 512, 2048], s=5),
IndentityBlock(2048, 3, [512, 512, 2048]),
IndentityBlock(2048, 3, [512, 512, 2048]),
)
self.pool = nn.AvgPool2d(7, 7, padding=1)
self.fc = nn.Sequential(
nn.Linear(8192, n_class)
)
def forward(self, X):
out = self.stage1(X)
out = self.stage2(out)
out = self.stage3(out)
out = self.stage4(out)
out = self.stage5(out)
out = self.pool(out)
out = out.view(out.size(0), 8192)
out = self.fc(out)
return out
3. 查看模型详情
#显示网络结构
import torchsummary
torchsummary.summary(model,(3,224,224))
运行结果:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 55, 55] 0
Conv2d-5 [-1, 64, 55, 55] 4,096
BatchNorm2d-6 [-1, 64, 55, 55] 128
ReLU-7 [-1, 64, 55, 55] 0
Conv2d-8 [-1, 64, 55, 55] 36,864
BatchNorm2d-9 [-1, 64, 55, 55] 128
ReLU-10 [-1, 64, 55, 55] 0
Conv2d-11 [-1, 256, 55, 55] 16,384
BatchNorm2d-12 [-1, 256, 55, 55] 512
Conv2d-13 [-1, 256, 55, 55] 16,384
BatchNorm2d-14 [-1, 256, 55, 55] 512
ReLU-15 [-1, 256, 55, 55] 0
ConvBlock-16 [-1, 256, 55, 55] 0
Conv2d-17 [-1, 64, 55, 55] 16,384
BatchNorm2d-18 [-1, 64, 55, 55] 128
ReLU-19 [-1, 64, 55, 55] 0
Conv2d-20 [-1, 64, 55, 55] 36,864
BatchNorm2d-21 [-1, 64, 55, 55] 128
ReLU-22 [-1, 64, 55, 55] 0
Conv2d-23 [-1, 256, 55, 55] 16,384
BatchNorm2d-24 [-1, 256, 55, 55] 512
ReLU-25 [-1, 256, 55, 55] 0
IdentityBlock-26 [-1, 256, 55, 55] 0
Conv2d-27 [-1, 64, 55, 55] 16,384
BatchNorm2d-28 [-1, 64, 55, 55] 128
ReLU-29 [-1, 64, 55, 55] 0
Conv2d-30 [-1, 64, 55, 55] 36,864
BatchNorm2d-31 [-1, 64, 55, 55] 128
ReLU-32 [-1, 64, 55, 55] 0
Conv2d-33 [-1, 256, 55, 55] 16,384
BatchNorm2d-34 [-1, 256, 55, 55] 512
ReLU-35 [-1, 256, 55, 55] 0
IdentityBlock-36 [-1, 256, 55, 55] 0
Conv2d-37 [-1, 128, 28, 28] 32,768
BatchNorm2d-38 [-1, 128, 28, 28] 256
ReLU-39 [-1, 128, 28, 28] 0
Conv2d-40 [-1, 128, 28, 28] 147,456
BatchNorm2d-41 [-1, 128, 28, 28] 256
ReLU-42 [-1, 128, 28, 28] 0
Conv2d-43 [-1, 512, 28, 28] 65,536
BatchNorm2d-44 [-1, 512, 28, 28] 1,024
Conv2d-45 [-1, 512, 28, 28] 131,072
BatchNorm2d-46 [-1, 512, 28, 28] 1,024
ReLU-47 [-1, 512, 28, 28] 0
ConvBlock-48 [-1, 512, 28, 28] 0
Conv2d-49 [-1, 128, 28, 28] 65,536
BatchNorm2d-50 [-1, 128, 28, 28] 256
ReLU-51 [-1, 128, 28, 28] 0
Conv2d-52 [-1, 128, 28, 28] 147,456
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
Conv2d-55 [-1, 512, 28, 28] 65,536
BatchNorm2d-56 [-1, 512, 28, 28] 1,024
ReLU-57 [-1, 512, 28, 28] 0
IdentityBlock-58 [-1, 512, 28, 28] 0
Conv2d-59 [-1, 128, 28, 28] 65,536
BatchNorm2d-60 [-1, 128, 28, 28] 256
ReLU-61 [-1, 128, 28, 28] 0
Conv2d-62 [-1, 128, 28, 28] 147,456
BatchNorm2d-63 [-1, 128, 28, 28] 256
ReLU-64 [-1, 128, 28, 28] 0
Conv2d-65 [-1, 512, 28, 28] 65,536
BatchNorm2d-66 [-1, 512, 28, 28] 1,024
ReLU-67 [-1, 512, 28, 28] 0
IdentityBlock-68 [-1, 512, 28, 28] 0
Conv2d-69 [-1, 128, 28, 28] 65,536
BatchNorm2d-70 [-1, 128, 28, 28] 256
ReLU-71 [-1, 128, 28, 28] 0
Conv2d-72 [-1, 128, 28, 28] 147,456
BatchNorm2d-73 [-1, 128, 28, 28] 256
ReLU-74 [-1, 128, 28, 28] 0
Conv2d-75 [-1, 512, 28, 28] 65,536
BatchNorm2d-76 [-1, 512, 28, 28] 1,024
ReLU-77 [-1, 512, 28, 28] 0
IdentityBlock-78 [-1, 512, 28, 28] 0
Conv2d-79 [-1, 256, 14, 14] 131,072
BatchNorm2d-80 [-1, 256, 14, 14] 512
ReLU-81 [-1, 256, 14, 14] 0
Conv2d-82 [-1, 256, 14, 14] 589,824
BatchNorm2d-83 [-1, 256, 14, 14] 512
ReLU-84 [-1, 256, 14, 14] 0
Conv2d-85 [-1, 1024, 14, 14] 262,144
BatchNorm2d-86 [-1, 1024, 14, 14] 2,048
Conv2d-87 [-1, 1024, 14, 14] 524,288
BatchNorm2d-88 [-1, 1024, 14, 14] 2,048
ReLU-89 [-1, 1024, 14, 14] 0
ConvBlock-90 [-1, 1024, 14, 14] 0
Conv2d-91 [-1, 256, 14, 14] 262,144
BatchNorm2d-92 [-1, 256, 14, 14] 512
ReLU-93 [-1, 256, 14, 14] 0
Conv2d-94 [-1, 256, 14, 14] 589,824
BatchNorm2d-95 [-1, 256, 14, 14] 512
ReLU-96 [-1, 256, 14, 14] 0
Conv2d-97 [-1, 1024, 14, 14] 262,144
BatchNorm2d-98 [-1, 1024, 14, 14] 2,048
ReLU-99 [-1, 1024, 14, 14] 0
IdentityBlock-100 [-1, 1024, 14, 14] 0
Conv2d-101 [-1, 256, 14, 14] 262,144
BatchNorm2d-102 [-1, 256, 14, 14] 512
ReLU-103 [-1, 256, 14, 14] 0
Conv2d-104 [-1, 256, 14, 14] 589,824
BatchNorm2d-105 [-1, 256, 14, 14] 512
ReLU-106 [-1, 256, 14, 14] 0
Conv2d-107 [-1, 1024, 14, 14] 262,144
BatchNorm2d-108 [-1, 1024, 14, 14] 2,048
ReLU-109 [-1, 1024, 14, 14] 0
IdentityBlock-110 [-1, 1024, 14, 14] 0
Conv2d-111 [-1, 256, 14, 14] 262,144
BatchNorm2d-112 [-1, 256, 14, 14] 512
ReLU-113 [-1, 256, 14, 14] 0
Conv2d-114 [-1, 256, 14, 14] 589,824
BatchNorm2d-115 [-1, 256, 14, 14] 512
ReLU-116 [-1, 256, 14, 14] 0
Conv2d-117 [-1, 1024, 14, 14] 262,144
BatchNorm2d-118 [-1, 1024, 14, 14] 2,048
ReLU-119 [-1, 1024, 14, 14] 0
IdentityBlock-120 [-1, 1024, 14, 14] 0
Conv2d-121 [-1, 256, 14, 14] 262,144
BatchNorm2d-122 [-1, 256, 14, 14] 512
ReLU-123 [-1, 256, 14, 14] 0
Conv2d-124 [-1, 256, 14, 14] 589,824
BatchNorm2d-125 [-1, 256, 14, 14] 512
ReLU-126 [-1, 256, 14, 14] 0
Conv2d-127 [-1, 1024, 14, 14] 262,144
BatchNorm2d-128 [-1, 1024, 14, 14] 2,048
ReLU-129 [-1, 1024, 14, 14] 0
IdentityBlock-130 [-1, 1024, 14, 14] 0
Conv2d-131 [-1, 256, 14, 14] 262,144
BatchNorm2d-132 [-1, 256, 14, 14] 512
ReLU-133 [-1, 256, 14, 14] 0
Conv2d-134 [-1, 256, 14, 14] 589,824
BatchNorm2d-135 [-1, 256, 14, 14] 512
ReLU-136 [-1, 256, 14, 14] 0
Conv2d-137 [-1, 1024, 14, 14] 262,144
BatchNorm2d-138 [-1, 1024, 14, 14] 2,048
ReLU-139 [-1, 1024, 14, 14] 0
IdentityBlock-140 [-1, 1024, 14, 14] 0
Conv2d-141 [-1, 512, 7, 7] 524,288
BatchNorm2d-142 [-1, 512, 7, 7] 1,024
ReLU-143 [-1, 512, 7, 7] 0
Conv2d-144 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-145 [-1, 512, 7, 7] 1,024
ReLU-146 [-1, 512, 7, 7] 0
Conv2d-147 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-148 [-1, 2048, 7, 7] 4,096
Conv2d-149 [-1, 2048, 7, 7] 2,097,152
BatchNorm2d-150 [-1, 2048, 7, 7] 4,096
ReLU-151 [-1, 2048, 7, 7] 0
ConvBlock-152 [-1, 2048, 7, 7] 0
Conv2d-153 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-154 [-1, 512, 7, 7] 1,024
ReLU-155 [-1, 512, 7, 7] 0
Conv2d-156 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-157 [-1, 512, 7, 7] 1,024
ReLU-158 [-1, 512, 7, 7] 0
Conv2d-159 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-160 [-1, 2048, 7, 7] 4,096
ReLU-161 [-1, 2048, 7, 7] 0
IdentityBlock-162 [-1, 2048, 7, 7] 0
Conv2d-163 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-164 [-1, 512, 7, 7] 1,024
ReLU-165 [-1, 512, 7, 7] 0
Conv2d-166 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-167 [-1, 512, 7, 7] 1,024
ReLU-168 [-1, 512, 7, 7] 0
Conv2d-169 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-170 [-1, 2048, 7, 7] 4,096
ReLU-171 [-1, 2048, 7, 7] 0
IdentityBlock-172 [-1, 2048, 7, 7] 0
AvgPool2d-173 [-1, 2048, 1, 1] 0
Linear-174 [-1, 4] 8,196
================================================================
Total params: 23,516,228
Trainable params: 23,516,228
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 270.43
Params size (MB): 89.71
Estimated Total Size (MB): 360.71
----------------------------------------------------------------
3、 训练模型
- 编写训练函数
def train(dataloader,model,loss_fn,optimizer):
size=len(dataloader.dataset) #训练集的大小
num_batches=len(dataloader) #批次数目
train_loss,train_acc=0,0 #初始化训练损失和正确率
for x,y in dataloader: #获取图片及其标签
x,y=x.to(device),y.to(device)
#计算预测误差
pred=model(x) #网络输出
loss=loss_fn(pred,y) #计算网络输出和真实值之间的差距,二者差值即为损失
#反向传播
optimizer.zero_grad() #grad属性归零
loss.backward() #反向传播
optimizer.step() #每一步自动更新
#记录acc与loss
train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
train_loss+=loss.item()
train_acc/=size
train_loss/=num_batches
return train_acc,train_loss
- 编写测试函数
测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
#测试函数
def test(dataloader,model,loss_fn):
size=len(dataloader.dataset) #测试集的大小
num_batches=len(dataloader) #批次数目
test_loss,test_acc=0,0
#当不进行训练时,停止梯度更新,节省计算内存消耗
with torch.no_grad():
for imgs,target in dataloader:
imgs,target=imgs.to(device),target.to(device)
#计算loss
target_pred=model(imgs)
loss=loss_fn(target_pred,target)
test_loss+=loss.item()
test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()
test_acc/=size
test_loss/=num_batches
return test_acc,test_loss
- 正式训练
import copy
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4) #创建优化器,并设置学习率
loss_fn=nn.CrossEntropyLoss() #创建损失函数
epochs=10
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0 #设置一个最佳准确率,作为最佳模型的判别指标
for epoch in range(epochs):
model.train()
epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
#保存最佳模型到J1_model
if epoch_test_acc>best_acc:
best_acc=epoch_test_acc
J1_model=copy.deepcopy(model)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
#获取当前学习率
lr=optimizer.state_dict()['param_groups'][0]['lr']
template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,
epoch_test_acc*100,epoch_test_loss,lr))
#保存最佳模型到文件中
PATH=r'D:\data\J-series\J1_model.pth'
torch.save(model.state_dict(),PATH)
运行结果:
Epoch: 1,Train_acc:52.0%,Train_loss:1.178,Test_acc:43.4%,Test_loss:2.180,Lr:1.00E-04
Epoch: 2,Train_acc:68.1%,Train_loss:0.836,Test_acc:76.1%,Test_loss:0.952,Lr:1.00E-04
Epoch: 3,Train_acc:78.5%,Train_loss:0.664,Test_acc:77.9%,Test_loss:0.635,Lr:1.00E-04
Epoch: 4,Train_acc:80.5%,Train_loss:0.513,Test_acc:58.4%,Test_loss:1.794,Lr:1.00E-04
Epoch: 5,Train_acc:84.7%,Train_loss:0.416,Test_acc:75.2%,Test_loss:0.755,Lr:1.00E-04
Epoch: 6,Train_acc:82.5%,Train_loss:0.555,Test_acc:78.8%,Test_loss:0.734,Lr:1.00E-04
Epoch: 7,Train_acc:85.0%,Train_loss:0.399,Test_acc:64.6%,Test_loss:1.196,Lr:1.00E-04
Epoch: 8,Train_acc:88.1%,Train_loss:0.372,Test_acc:43.4%,Test_loss:4.219,Lr:1.00E-04
Epoch: 9,Train_acc:87.8%,Train_loss:0.319,Test_acc:90.3%,Test_loss:0.375,Lr:1.00E-04
Epoch:10,Train_acc:95.1%,Train_loss:0.166,Test_acc:92.0%,Test_loss:0.321,Lr:1.00E-04
Epoch:11,Train_acc:90.5%,Train_loss:0.263,Test_acc:84.1%,Test_loss:0.422,Lr:1.00E-04
Epoch:12,Train_acc:88.9%,Train_loss:0.310,Test_acc:90.3%,Test_loss:0.404,Lr:1.00E-04
Epoch:13,Train_acc:93.1%,Train_loss:0.190,Test_acc:89.4%,Test_loss:0.489,Lr:1.00E-04
Epoch:14,Train_acc:90.5%,Train_loss:0.282,Test_acc:81.4%,Test_loss:0.456,Lr:1.00E-04
Epoch:15,Train_acc:93.6%,Train_loss:0.181,Test_acc:85.8%,Test_loss:0.512,Lr:1.00E-04
Epoch:16,Train_acc:96.9%,Train_loss:0.100,Test_acc:92.0%,Test_loss:0.256,Lr:1.00E-04
Epoch:17,Train_acc:97.8%,Train_loss:0.096,Test_acc:89.4%,Test_loss:0.294,Lr:1.00E-04
Epoch:18,Train_acc:91.4%,Train_loss:0.260,Test_acc:85.8%,Test_loss:0.641,Lr:1.00E-04
Epoch:19,Train_acc:95.8%,Train_loss:0.139,Test_acc:90.3%,Test_loss:0.534,Lr:1.00E-04
Epoch:20,Train_acc:95.4%,Train_loss:0.157,Test_acc:89.4%,Test_loss:0.459,Lr:1.00E-04
5、 结果可视化
- Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif']=['SimHei'] #正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #正常显示负号
plt.rcParams['figure.dpi']=300 #分辨率
epochs_range=range(epochs)
plt.figure(figsize=(12,3))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
运行结果:
可以看出中间有明显的起伏波动,再次修改 batch_size=16,尝试后结果如下所示:
- 指定图片进行预测
from PIL import Image
classes=list(total_data.class_to_idx)
def predict_one_image(image_path,model,transform,classes):
test_img=Image.open(image_path).convert('RGB')
plt.imshow(test_img) #展示预测的图片
test_img=transform(test_img)
img=test_img.to(device).unsqueeze(0)
model.eval()
output=model(img)
_,pred=torch.max(output,1)
pred_class=classes[pred]
print(f'预测结果是:{pred_class}')
预测图片:
#预测训练集中的某张照片
predict_one_image(image_path=r'D:\data\J-series\J1\bird_photos\Black Skimmer\001.jpg',
model=model,transform=train_transforms,classes=classes)
运行结果:
预测结果是:Black Skimmer
- 模型评估
J1_model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,J1_model,loss_fn)
epoch_test_acc,epoch_test_loss
运行结果:
(0.7787610619469026, 0.8548152595758438)
六、心得体会
本周项目训练中,在pytorch环境下手动搭建了resnet50模型,与上节课相比,更加深层的理解了该模型的构造原理,对该模型有了更深层次的感悟。但模型的训练结果中,测试集的acc和loss都出现了较大的震荡,虽然多次修改batch_size,也通过图形旋转、翻转等方法对数据进行增强,但结果仍然不尽人意。猜引起该结果可能是由于数据集过小造成的,留待今后验证。