Bootstrap

用Pytorch实现图像分类

概述

   本文记录使用pytorch深度学习框架来实现一个图像分类任务的过程,搭建一个图像分类的工程主要包括以下步骤:
1.定义一个加载数据的类
2.构建网络结构
3.编写训练代码
4.编写验证/测试代码

一、 定义数据类

   pytorch中提供了两个类用于训练数据的加载,分别是torch.utils.data.Dataset和 torch.utils.data.DataLoader。
   Dataset类代表了数据集,任何我们自己设计的数据集类都应该是这个类的子类,我们需要继承这个类,并重写 len() 方法,这个方法是用来获得数据集的大小,和__getitem__()方法,这个方法用来返回数据集中索引值为0到len(dataset)的元素。
def getitem(self, index): 实现这个函数,就可以通过索引值来返回训练样本数据
def len(self): 实现这个函数,返回数据集的大小

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

   这里我们需要做的事情就是重写“getitem(self, index):” 和 “len(self): ”方法,我的实现如下:

import torch
from torch.utils.data import DataLoader, Dataset
import glob
import os
from PIL import Image
from torchvision import models, transforms

class MyDataSet(Dataset):
    def __init__(self, root_dir='./data', train_val='train', transform=None):
        self.data_path = os.path.join(root_dir, train_val)
        self.image_names = glob.glob(self.data_path + '/*/*.jpg')
        self.data_transform = transform
        self.train_val = train_val
        #print(self.image_names[0])

    def __len__(self):
        return(len(self.image_names))

    def __getitem__(self, item):
        img_path = self.image_names[item]
        #print(img_path)
        img = Image.open(img_path)
        # print(img.size)
        image = img
        label = img_path.split('/')[-2]
        label = int(label)
        if self.data_transform is not None:
            try:
                image = self.data_path
                image = self.data_transform[self.train_val](img)
            except:
                print('can not load image:{}'.format(img_path))
        return image, label

   初始化函数里面一般需要读入所有训练样本的路径和类别信息(根据自己数据实际存放的方式读取);“getitem(self,item)”方法中"item"表示索引值,这个方法需要实返回的是索引对应的一个样本(包括图像和标签),这里可以选择"PIL.Image.open()"或是"cv2.imread()"等方法读取图像;另外我的代码中还包含了“transforms”,这个类是用来进行数据增广,我的“transforms”定义如下:

data_transforms = {
        'train': transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            # transforms.Scale(256),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

   其中用到了“随机裁剪”、“镜像”、“归一化”等操作。
   仅仅有通过索引返回训练数据数不够的,我们还需要DataLoader类提供拓展功能。

dataloaders_train = DataLoader(
        image_datasets['train'],
        batch_size=32,
        shuffle=True
    )

   这个类我们不需要实现代码,设置好参数直接调用就行了,在模型训练之前可以测试下以上过程是否有误:

data = iter(dataloaders_train)
for i in range(1):
    print(next(data))

   如果可以成功输出一个batch的数据,则可以继续后面的步骤了。

二、定义网络结构

   由于本文的主要目的是利用pytorch框架快速搭建一个图像分类的工程,在网络的构建上没有自己实现,而是直接采用了“torchvison.models”里面自带的Resnet50的模型,直接调用这种模型除了方便我们快速搭建工程以外,还有一个好处是可以加载它的预训练参数,在此基础上fine-tune会比从头训练效果更好,尤其是在我们的样本量不足的情况下。

model = models.resnet50(pretrained=True)# 这里加载预训练参数来fine-tune
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)#由于我们需要的类别数量可能不一样,所以需要重新定义一下最后用于分类的全连层

三、实现训练过程

   有了前面定义好的数据类和模型以外,要实现模型的训练过程我们还需要定义loss函数、优化器等。

def train_model(model, data_sizes, num_epochs, scheduler, dataloaders,criterion, optimizer, ):
"""
data_sizes 为样本总量
num_epochs 为总共迭代的轮数,一个epoch表示把整个训练集走一遍
scheduler 用来控制学习率
dataloaders 加载数据
criterion loss函数
optimizer 优化器
"""
    device1 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    since = time.time()
    best_model_wts = model.state_dict()
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        begin_time = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('*'*20)

        for phase in ['train', 'val']:
            count_batch = 0
            if phase == 'train':
                scheduler.step()
                model.train(True)
            else:
                model.train(False)
            running_loss = 0.0
            running_corrects = 0.0
            for i, data in enumerate(dataloaders[phase]):
                count_batch += 1
                inputs, labels = data
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
                optimizer.zero_grad()
                outputs = model(inputs)
                out = torch.argmax(outputs.data, 1)
                #print(torch.argmax(outputs.data, 1))
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                running_loss += loss.data
                running_corrects += torch.sum(preds == labels.data).to(torch.float32)

                if count_batch % 10==0:
                    batch_loss = running_loss / (batch_size * count_batch)
                    batch_acc = running_corrects / (batch_size * count_batch)
                    print('{} Epoch [{}] Batch Loss: {:.4f} Acc:{:.4f} Time: {:.4f}s'.format(
                        phase, epoch, batch_loss, batch_acc, time.time()-begin_time
                    ))
                    begin_time = time.time()
        epoch_loss = running_loss / data_sizes[phase]
        epoch_acc = running_corrects / data_sizes[phase]
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        if phase == 'train':
            if not os.path.exists(model_path):
                os.mkdir(model_path)
            torch.save(model, os.path.join(model_path, 'resnet_epoch{}.pkl').format(epoch))

        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = model.state_dict()

        time_elapsed = time.time() - since
        print('Training completed in {:.0f}mins {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:.4f}'.format(best_acc))

        model.load_state_dict(best_model_wts)
    return(model)

四、 预测

   预测的demo很简单,加载训练好的模型,读取输入图像并转化为tensor送入模型计算。

model = torch.load(model_path)
use_cuda = True if torch.cuda.is_available else False
if use_cuda:
    model.cuda()
label, score = predict(model, img)

def predict(use_cuda, model, image_name):
    test_image = Image.open(image_name)
    test_image_tensor = transform(test_image)
    if use_cuda:
        test_image_tensor = test_image_tensor.cuda()
    else:
        test_image_tensor = test_image_tensor
    test_image_tensor = Variable(torch.unsqueeze(test_image_tensor, dim=0).float(), requires_grad=False)

    with torch.no_grad():
        model.eval()
        #print(model)
        out = model(test_image_tensor)
        ps = torch.exp(out)  #Softmax操作,转换成概率分布
        ps = ps / torch.sum(ps)
        topk, topclass = ps.topk(1, dim=1)
        return(idx_to_class[topclass.cpu().numpy()[0][0]], topk.cpu().numpy()[0][0])

其中“idx_to_class”为用于index到label名的映射。
完整的项目代码详见github:Image_Classify_pytorch

;