Bootstrap

U2Net、U2NetP分割模型训练---自定义dataset、训练代码训练自己的数据集

前言

  • 博客很久没有更新了,今天就来更新一篇博客吧,哈哈;
  • 最近在做图像分割相关的任务,因此,写这么一篇博客来简单实现一下分割是怎么做的,内容简单,枯燥,需要耐心看,哈哈;
  • 博客的内容相对简单,比较适合刚接触分割的同学参考学习(这篇博客在算法训练上没有涉及到训练策略、数据增强方法,特意留下余地处给大家自行发挥)

内容简介

  • U2Net算法介绍
  • 本博客训练效果截图展示
  • 本博客代码框架介绍
  • 数据集数据集准备
  • 自定义dataset
  • u2net、u2netp网络结构定义
  • 训练代码
  • 模型推理代码
  • 总结以及博客代码的Github地址

U2Net算法介绍

在这里插入图片描述

  • 关于算法介绍,CSDN上很多大神有详细的解读,大家可自行去搜索阅读学习,本博客目的是实操,所以此处省略上千字,哈哈
  • 官方论文地址:https://arxiv.org/pdf/2005.09007.pdf
  • 官方Github repo 地址:https://github.com/xuebinqin/U-2-Net

本博客代码训练效果截图展示

  • 任务图片分割结果可视化展示
  • 如上图所示,模型在测试集上的推理效果(左上为原始标注mask,左下为预测的mask,右边图像为原始图片)可以看出,模型的效果还是比较理想的;

代码框架介绍

  • 项目的整体框架如下图所示
    在这里插入图片描述
  • 第一个Folder :backup
backup为训练过程模型的保存的folder,在训练过程中,代码会自动在该目录下生成文件夹,并保存训练过程的权重pth文件
  • 第二个Folder: dataset

dataset目录为训练数据集存放的目录包括了参与训练的原始图片、以及对应的标注mask,训练数据集的组成方式由图片由如下的方式组成:
-images
   -train
     -0.jpg
     -1.jpg
     -....
   -test
  	 -0.jpg
     -1.jpg
     -....
   -val
  	 -0.jpg
     -1.jpg
     -....
 -masks
   -train
     -0.jpg
     -1.jpg
     -....
   -test
  	 -0.jpg
     -1.jpg
     -....
   -val
  	 -0.jpg
     -1.jpg
     -....
  • 第三个Folder:src
src文件夹下有两个文件,一个是网络模型的定义文件u2net.py,另一个为自定义的dataset.py
  • train_u2net.py文件: 模型训练代码
  • inference_u2net.py文件: 模型的推理代码

训练数据集准备

  • 请参考上一章节中dataset Folder的描述方式来准备您的训练数据集;
  • 注意:请保持原始图片和mask图片的命名一致,若不一致的话,需自行修改调整dataset代码部分

自定义dataset

  • 一般来说dataset的组成部分有核心的两个
__getitiem__ 方法 (根据索引返回样本数据)
__len__ 方法 (返回数据集中样本的个数)
(注意:本博客中dataset类中,未写数据增强部分,特意给大家留下空间自行学习和发挥)
  • 根据上述描述,接下来我们开始自定义dataset
    src/seg_dataset.py
# coding: utf-8
# author: hxy
# 2022-04-20
"""
数据读取dataset
"""
import os
import cv2
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset


# dataset for u2net
class U2netSegDataset(Dataset):
    def __init__(self, img_dir, mask_dir, input_size=(320, 320)):
        """
        :param img_dir: 数据集图片文件夹路径
        :param mask_dir: 数据集mask文件夹路径
        :param input_size: 图片输入的尺寸
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.input_size = input_size
        self.samples = list()
        self.gt_mask = list()
        self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
        self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
        self.load_data()

    def __len__(self):
        return len(self.samples)

    def load_data(self):
        img_dir_full_path = self.img_dir
        mask_dir_full_path = self.mask_dir
        img_files = os.listdir(img_dir_full_path)

        for img_name in tqdm(img_files):
            img_full_path = os.path.join(img_dir_full_path, img_name)
            mask_full_path = os.path.join(mask_dir_full_path, img_name)

            img = cv2.imread(img_full_path)
            img = cv2.resize(img, self.input_size)
            img2rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img2norm = (img2rgb - self.mean) / self.std
            # 图像格式改为nchw
            img2nchw = np.transpose(img2norm, [2, 0, 1]).astype(np.float32)

            gt_mask = cv2.imread(mask_full_path)
            gt_mask = cv2.resize(gt_mask, self.input_size)
            gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
            gt_mask = gt_mask / 255.
            gt_mask = np.expand_dims(gt_mask, axis=0)

            self.samples.append(img2nchw)
            self.gt_mask.append(gt_mask)

        return self.samples, self.gt_mask

    def __getitem__(self, index):
        img = self.samples[index]
        mask = self.gt_mask[index]

        return img, mask

上面的代码块简单描述一下: 用os模块遍历文件夹,获取所有文件的名字,并将他们的全部路径拼接起来,opencv读取,然后对读取的照片array做预处理(resize、归一化、通道转换),最后将预处理好的图片append到对应的list中去即可;

u2net、u2netp网络结构定义

  • 网络结构的定义, 该部分代码是直接从源repo中copy过来的,所以直接贴在下来供大家参考使用;
    src/u2net.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
    # src = F.upsample(src, size=tar.shape[2:], mode='bilinear') # old version torch
    src = F.upsample(src, size=tar.shape[2:], mode='bilinear', align_corners=True)

    return src


### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)


### U^2-Net small ###
class U2NETP(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)

训练代码

  • 训练代码
深度学习训练代码的一般流程是: 模型定义 -> 数据加载 -> 模型训练 ->模型验证
本博客中训练代码的实现逻辑如下:
	1 定义网络
	2 加载数据
	3 定义损失函数和优化器
	 4 开始训练
	   - 训练网络
	   - 将梯度置为0
	   - 求loss
	   - 反向传播
	   - 更新参数
(在本博客的训练代码中未写验证部分代码,留给各位同学自行实现)

** train_u2net.py**

# coding: utf-8
# author: hxy
# 20220420
"""
训练代码:u2net、u2netp
train it from scratch.
"""
import os
import datetime
import torch
import numpy as np
from tqdm import tqdm
from src.u2net import U2NET, U2NETP
from src.seg_dataset import U2netSegDataset
from torch.utils.data import DataLoader

# 参考u2net源码loss的设定
bce_loss = torch.nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    # print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    # loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
    # loss6.data.item()))

    return loss0, loss


def load_data(img_folder, mask_folder, batch_size, num_workers, input_size):
    """
    :param img_folder: 图片保存的fodler
    :param mask_folder: mask保存的fodler
    :param batch_size: batch_size的设定
    :param num_workers: 数据加载cpu核心数
    :param input_size: 模型输入尺寸
    :return:
    """
    train_dataset = U2netSegDataset(img_dir=os.path.join(img_folder, 'train'),
                                    mask_dir=os.path.join(mask_folder, 'train'),
                                    input_size=input_size)

    val_dataset = U2netSegDataset(img_dir=os.path.join(img_folder, 'val'),
                                  mask_dir=os.path.join(mask_folder, 'val'),
                                  input_size=input_size)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader


def train_model(epoch_nums, cuda_device, model_save_dir):
    """
    :param epoch_nums: 训练总的epoch
    :param cuda_device: 指定gpu训练
    :param model_save_dir: 模型保存folder
    :return:
    """
    current_time = datetime.datetime.now()
    current_time = datetime.datetime.strftime(current_time, '%Y-%m-%d-%H:%M')
    model_save_dir = os.path.join(model_save_dir, current_time)
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    else:
        pass

    device = torch.device(cuda_device)
    train_loader, val_loader = load_data(img_folder='dataset',
                                         mask_folder='dataset',
                                         batch_size=32,
                                         num_workers=10,
                                         input_size=(160, 160))

    # input 3-channels, output 1-channels
    net = U2NET(3, 1)
    #net = U2NETP(3, 1)
    
    # if torch.cuda.device_count() > 1:
    #     net = torch.nn.DataParallel(net, device_ids=[6, 7])
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    for epoch in range(0, epoch_nums):
        run_loss = list()
        run_tar_loss = list()

        net.train()
        for i, (inputs, gt_masks) in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            inputs = inputs.type(torch.FloatTensor)
            gt_masks = gt_masks.type(torch.FloatTensor)
            inputs, gt_masks = inputs.to(device), gt_masks.to(device)

            d0, d1, d2, d3, d4, d5, d6 = net(inputs)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, gt_masks)

            loss.backward()
            optimizer.step()

            run_loss.append(loss.item())
            run_tar_loss.append(loss2.item())
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

        print("--Train Epoch:{}--".format(epoch))
        print("--Train run_loss:{:.4f}--".format(np.mean(run_loss)))
        print("--Train run_tar_loss:{:.4f}--\n".format(np.mean(run_tar_loss)))

        if epoch % 20 == 0:
            checkpoint_name = 'checkpoint_' + str(epoch) + '_' + str(np.mean(run_loss)) + '.pth'
            torch.save(net.state_dict(), os.path.join(model_save_dir, checkpoint_name))
            print("--model saved:{}--".format(checkpoint_name))


if __name__ == '__main__':
    train_model(epoch_nums=500, cuda_device='cuda:7',
                model_save_dir='backup')

在这部分训练代码中, 并没有出现很多训练策略,如各种学习率调整策略、多阶段学习等等…该代码实现的为最基础的训练代码,因此,您有足够的空间去自行发挥;

模型推理程序

  • 算法模型推理
推理程序的编写逻辑一般是: 加载模型-> 读取图片 —>图片预处理(需要保持和训练过程中的图片预处理一致) ->模型推理 ->获取结果,进行后处理 ->保存图片,可视化查看结果

inference_u2net.py

# coding: utf-8
# author: hxy
# 20220420
"""
u2net/u2netP模型推理程序
"""

import os
import cv2
import torch
import numpy as np
from time import time
from tqdm import tqdm
from src.u2net import U2NET, U2NETP

"""
初始化模型加载
"""
try:
    print('===loading model===')
    current_project_path = os.getcwd()
    net = U2NET(3, 1)
    # net = U2NETP(3, 1)
    checkpoint_path = os.path.join(current_project_path,
                                   'backup/*****.pth')
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(checkpoint_path, map_location='cuda:1'))
    else:
        net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    net.eval()
    print('===model lode sucessed===')

except Exception as e:
    print('===model load error:{}==='.format(e))


# 计算dice
def dice_coef(output, target):  # output为预测结果 target为真实结果
    smooth = 1e-5  # 防止0除
    intersection = (output * target).sum()
    return (2. * intersection + smooth) / \
           (output.sum() + target.sum() + smooth)


# 图像归一化操作
def img2norm(img_array, input_size):
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]
    _std = np.array(std).reshape((1, 1, 3))
    _mean = np.array(mean).reshape((1, 1, 3))

    img_array = cv2.resize(img_array, input_size)
    norm_img = (img_array - _mean) / _std

    return norm_img


# 归一化预测结果
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)

    return dn


# 推理
def inference1folder(img_folder, mask_folder, input_size):
    total_times = list()
    total_dices = list()
    img_files = os.listdir(img_folder)
    for img_file in tqdm(img_files):
        img_full_path = os.path.join(img_folder, img_file)
        mask_full_path = os.path.join(mask_folder, img_file)
        img = cv2.imread(img_full_path)
        gt_mask = cv2.imread(mask_full_path)
        gt_mask = cv2.resize(gt_mask, input_size)
        gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
        gt_mask = gt_mask / 255.

        ori_h, ori_w = img.shape[:2]
        img2rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        norm_img = img2norm(img2rgb, input_size)

        x_tensor = torch.from_numpy(norm_img).permute(2, 0, 1).float()
        x_tensor = torch.unsqueeze(x_tensor, 0)

        start_t = time()
        d1, d2, d3, d4, d5, d6, d7 = net(x_tensor)
        end_t = time()

        total_times.append(end_t - start_t)
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        pred = pred.squeeze().cpu().data.numpy()

        dice_value = dice_coef(pred, gt_mask)
        total_dices.append(dice_value)

        # pred[pred>=0.3]=255
        # pred[pred<0.3]=0
        # pred_res = pred
        pred_res = pred * 255
        pred_res = cv2.resize(pred_res, (ori_w, ori_h))

        cv2.imwrite(os.path.join(current_project_path, 'infer_output/', img_file), pred_res)

    print('==inference 1 pic avg cost:{:.4f}ms=='.format(np.mean(total_times) * 1000))
    print('==inference avg dice:{:.4f}=='.format(np.mean(total_dices)))

    return None


if __name__ == '__main__':
    test_img_folder = os.path.join(os.getcwd(), 'dataset/images/test')
    test_gt_mask_folder = os.path.join(os.getcwd(), 'dataset/masks/test')
    inference1folder(img_folder=test_img_folder, mask_folder=test_gt_mask_folder, input_size=(160, 160))

着一部分代码没什么好说的,仔细看就完事,当然我只写了针对于一个folder的推理代码,您可以尝试推理视频file;或者你也可以加一些更加炫酷的后处理让你的推理结果看起来更加具有美观;

总结以及博客代码的Github地址

  • 一篇博客写完总归还是要来点总结才完美的!
  1. 本篇博客实现的是最基础的训练过程和训练代码,所以你有很多的发挥空间;
  2. 例如:尝试使用不同的loss函数(dice loss、bce dice loss、iou loss等等)
  3. 添加数据增强操作(建议使用albumentation库,torchversion也行)
  4. 使用不同的调参策略训练模型(不同的学习率衰减策略、多阶段训练等等)
  5. 尝试使用不同的优化器训练模型等等。。。。。
  6. 等你上述尝试都做过了,你可尝试使用不同的网络,src文件夹内不断丰富不同网络结构
  7. 优化一下代码的编写,封装一下之类的,哈哈。。
  8. 总是很多实验可以做,可学习的东西也很多。。
  9. 最后,希望本篇博客能够给你带来帮助~互相学习~文章代码有不知之处多多包涵!
  • 本博客代码Github地址: https://github.com/YingXiuHe/u2net-pytorch.git/
;