Bootstrap

在手写数字识别MNIST上实现模型固定间隔的保存以及意外中断的恢复训练!

前言:本文没有过多注释,注释都写在代码里了。

作者作为深度学习半脚入门玩家,初心是想实现1.每训练一定的epoch保存一下模型权重;2.训练意外中断时从保存的最近检查点恢复训练;这两个目的。所以就从最简单的CNN手写数字识别任务上做了尝试,这份代码应该是直接copy就能跑通的,大家可以训一半epoch自己中断一下(ctrl+c)再重新训试试,应该是可以从保存的最近检查点(ckpt)恢复训练的,看在作业这么诚恳地交流心得且代码完全copy下来就能跑通的前提下点赞收藏关注下吧!

祝愿我们都越来越好!

【关于恢复训练参考的是:PyTorch模型保存深入理解 - 简书 (jianshu.com)

'''pytorch深度学习框架中模型权重的保存与恢复训练 '''

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from functools import partial
from tqdm import tqdm
from pathlib import Path
import argparse
import re
import os


# 搭建LeNet5模型网络
class LeNet5(nn.Module):
    # 初始化方法
    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        # 卷积层1
        self.layer_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(num_features=6),
            nn.ReLU(),
        )
        # 下采样层1
        self.subsample_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 卷积层2
        self.layer_2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
        )
        # 下采样层2
        self.subsample_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 全连接层与激活层
        self.fc_1 = nn.Linear(in_features=400, out_features=120)
        self.relu_1 = nn.ReLU()
        self.fc_2 = nn.Linear(in_features=120, out_features=84)
        self.relu_2 = nn.ReLU()
        self.fc_3 = nn.Linear(in_features=84, out_features=n_classes)

    # 前向传播方法
    def forward(self, x):
        x = self.layer_1(x)
        x = self.subsample_1(x)
        x = self.layer_2(x)
        x = self.subsample_2(x)

        # 将上一步输出的16个5×5特征图中的400个像素展平成一维向量,以便下一步全连接
        x = x.reshape(x.shape[0], -1)

        # 进入全连接层
        x = self.fc_1(x)
        x = self.relu_1(x)
        x = self.fc_2(x)
        x = self.relu_2(x)
        out = self.fc_3(x)
        return out


# 加载手写数字图像训练集和测试集
train_dataset = torchvision.datasets.MNIST(root="./MNIST_data",
                                           train=True,
                                           transform=transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1307,), std = (0.3081,))]),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root="./MNIST_data",
                                          train=False,
                                          transform=transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1325,),
                                                  std = (0.3105,))]),
                                          download=True)


# 定义训练时的命令函数
def train_opt():
    # 创建一个命令对象
    parser = argparse.ArgumentParser()
    parser.add_argument("--project", default="train_MNIST", help="project/name")
    parser.add_argument("--exp_name", default="exp", help="ckpt_save_dir")
    parser.add_argument("--batch_size", default=80, help="batch_size")
    parser.add_argument("--epochs", default=20, help="epochs")
    parser.add_argument("--learning_rate", default=1e-3, help="learning_rate")
    parser.add_argument("--save_interval", type=int, default=1, help="save ckpt every save_interval epoch")
    opt = parser.parse_args()
    return opt


def train(opt, epoch, w_dir, train=True):
    # 定义设备 mostly for cuda
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 定义模型
    model = LeNet5(n_classes=10)
    # train_loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)
    # 定义损失函数
    loss_fuc = nn.CrossEntropyLoss()
    # 定义优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

    # 预备训练
    load_loop = partial(tqdm, position=1, desc="Batch")

    # 初始化loss为0
    train_loss = 0.0
    for data, target in load_loop(train_data_loader):
        # data, target = data.to(device), target.to(device)
        # 清除优化器累计梯度信息
        optimizer.zero_grad()
        output = model(data)
        # 更新loss
        loss = loss_fuc(output, target)
        loss.backward()   # 反向传播
        optimizer.step()  # 更新网络参数
        train_loss += loss

    # 每opt.save_interval个保存一下权重
    w_dir.mkdir(parents=True, exist_ok=True)
    if (epoch % opt.save_interval) == 0:
        ckpt = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        torch.save(ckpt, os.path.join(w_dir, f"train-{epoch}.pt"))

    train_loss = train_loss / len(train_data_loader.dataset)
    if train:
        print("Epoch: {} \tTraining loss: {:.6f}".format(epoch, train_loss.item()))
        test(opt, model)  # 将更新参数的model传入test()方法进行测试

def test(opt, model):
    # 定义设备 mostly for cuda
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # test data loader
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=False,
                                  num_workers=2)
    # 初始化准确率为0
    correct = 0
    total = 0
    # 禁用梯度运算
    with torch.no_grad():
        for image, label in test_data_loader:
            # image, label = image.to(device), label.to(device)
            output = model(image)
            _, predict = torch.max(output.data, 1)
            total += label.shape[0]
            correct += (predict == label).sum().item()
            error = 1 - correct / total
        # 测试完毕,输出正确率
        print("Accuracy of the network on the test images: {:.4f}%".format(100.0000 * correct / total))
        # 返回误差
        return error


# 定义一个函数,用于提取字符串中的数字,并从小到大排序
def extract_number(s):
    return int(re.search(r'\d+', s).group())


def check_pt_files(directory):
    # 遍历给定目录下的所有文件及子目录,判断给定目录的文件是否非空,当非空返回True以及最近检查点的路径组成的列表
    for root, dirs, files in os.walk(directory):
        start_pt_path = os.path.join(root, sorted(files, key=extract_number)[-1])
        if os.path.getsize(start_pt_path) > 0:
            list = [True, start_pt_path]
            return list
        else:
            return False


def main(opt):
    save_dir = str((Path(opt.project) / opt.exp_name))
    save_dir = Path(save_dir)
    w_dir = save_dir / "weights"
    if os.path.isdir(w_dir):
        if check_pt_files(w_dir)[0]:
            start_epoch_path = check_pt_files(w_dir)[1]
            start_epoch_idx = str(os.path.splitext(os.path.basename(start_epoch_path))[0]).split("-")[-1]
            start_epoch_idx = int(start_epoch_idx)
            # 从保存的最近检查点恢复训练,其中ckpt的键值对要跟torch.save()保存的一致!
            print(f"Recovery training from the {start_epoch_idx}-th epoch!")
            start_ckpt = torch.load(start_epoch_path)
            model = LeNet5(n_classes=10)
            model.load_state_dict(start_ckpt['model_state_dict'])
            optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)
            optimizer.load_state_dict(start_ckpt["optimizer_state_dict"])
            # 打印模型参数总量
            print("model has {} parameters".format(sum(y.numel() for y in model.parameters())))
            for epoch in range(start_epoch_idx, opt.epochs+1):
                train(opt, epoch, w_dir=w_dir)
    else:
        print("Training from scratch!")
        # 打印模型参数总量
        model = LeNet5(n_classes=10)
        print("model has {} parameters".format(sum(y.numel() for y in model.parameters())))
        # 从头开始训练
        for epoch in range(1, opt.epochs+1):
            train(opt, epoch, w_dir=w_dir)


if __name__ == "__main__":
    opt = train_opt()
    main(opt)


参考文献:PyTorch模型保存深入理解 - 简书 (jianshu.com)

;