
MindSpore 25-Day Learning Camp, Day 11 | FCN for Image Semantic Segmentation

FCN for Image Semantic Segmentation


Fully Convolutional Networks (FCN) is a framework for image semantic segmentation proposed in 2015 by Jonathan Long et al. of UC Berkeley in the paper Fully Convolutional Networks for Semantic Segmentation[1].

FCN was the first fully convolutional network to perform end-to-end, pixel-level prediction.

fcn-1

Semantic Segmentation

Before introducing FCN in detail, let us first explain what semantic segmentation is:

Image semantic segmentation is a key part of image understanding in image processing and machine vision, and an important branch of AI. It is commonly applied in areas such as face recognition, object detection, medical imaging, satellite image analysis, and autonomous-driving perception.

The goal of semantic segmentation is to classify each pixel in an image. Unlike an ordinary classification task, which outputs only a class label, a semantic segmentation task outputs an image of the same size as the input, in which every pixel holds the class of the corresponding input pixel. In the image domain, "semantics" refers to the content of an image, i.e., an understanding of its meaning. The figure below shows some examples of semantic segmentation:

fcn-2
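As a toy illustration of the input/output relationship described above (random arrays, no real model; the shapes are hypothetical):

```python
import numpy as np

# A semantic segmentation model maps an H x W x 3 image to an H x W map
# of per-pixel class indices, the same spatial size as the input.
H, W, num_classes = 4, 4, 3
image = np.random.rand(H, W, 3).astype(np.float32)

# A (hypothetical) model would emit per-pixel class scores of shape
# (num_classes, H, W); the label map is the argmax over the class axis.
scores = np.random.rand(num_classes, H, W).astype(np.float32)
label_map = scores.argmax(axis=0)

print(image.shape)      # (4, 4, 3)
print(label_map.shape)  # (4, 4) -- one class index per input pixel
```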

Model Overview

FCN is mainly used in the field of image segmentation. It is an end-to-end segmentation method and the seminal work applying deep learning to image semantic segmentation. Through pixel-level prediction it directly produces a label map of the same size as the original image. Because FCN discards the fully connected layers and replaces them with convolutions, making every layer in the network convolutional, it is called a fully convolutional network.

A fully convolutional network relies mainly on the following three techniques:

  1. Convolutionalization

    VGG-16 serves as the backbone of FCN. It takes a 224*224 RGB image as input and outputs 1000 prediction scores. VGG-16 accepts only fixed-size inputs, discards spatial coordinates, and produces non-spatial outputs. It contains three fully connected layers, and a fully connected layer can be viewed as a convolution whose kernel covers the entire input region. Converting the fully connected layers into convolutional layers turns the network's one-dimensional, non-spatial output into a two-dimensional matrix, from which a heatmap mapped onto the input image can be generated.

    fcn-3

  2. Upsampling

    The convolution and pooling operations in the convolutional stage shrink the feature maps. To obtain a dense prediction at the original image size, the resulting feature maps must be upsampled. The parameters of the upsampling deconvolution are initialized from bilinear interpolation and then refined through backpropagation, so the upsampling itself becomes learnable. Performing upsampling inside the network enables end-to-end learning by backpropagating the pixel-wise loss.

    fcn-4

  3. Skip structure (Skip Layer)

    Upsampling only the last feature map back to the original image size yields a segmentation with an effective prediction stride of 32 pixels, called FCN-32s. Because that last feature map is very small, too much detail is lost, so a skip structure combines the more global prediction from the last layer with predictions from shallower layers, giving the result more local detail. The stride-32 prediction (FCN-32s) is upsampled 2x and fused (added element-wise) with the prediction from the pool4 layer (stride 16); this network is called FCN-16s. The fused prediction is then upsampled 2x again and fused with the prediction obtained from the pool3 layer; this network is called FCN-8s. The skip structure thus combines deep, global information with shallow, local information.

    fcn-5
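To make the convolutionalization step above concrete, here is a small numpy sketch (toy shapes, not VGG-16): a fully connected neuron over a C x k x k feature map is numerically identical to a k x k "valid" convolution holding the same weights, and on a larger input the same kernel slides spatially and produces a 2-D heatmap instead of a single score.

```python
import numpy as np

C, k = 2, 3
rng = np.random.default_rng(0)
w_fc = rng.standard_normal((C, k, k))          # weights of one FC neuron

def conv_valid(x, w):
    """Minimal single-output-channel 'valid' cross-correlation."""
    C_, k_, _ = w.shape
    H, W = x.shape[1] - k_ + 1, x.shape[2] - k_ + 1
    out = np.empty((H, W))
    for i in range(H):
        for j in range(W):
            out[i, j] = np.sum(x[:, i:i + k_, j:j + k_] * w)
    return out

x_small = rng.standard_normal((C, k, k))       # input exactly kernel-sized
fc_out = np.dot(x_small.ravel(), w_fc.ravel()) # FC neuron output (scalar)
conv_out = conv_valid(x_small, w_fc)           # 1 x 1 spatial output
assert np.allclose(fc_out, conv_out[0, 0])     # FC == whole-region conv

x_large = rng.standard_normal((C, k + 4, k + 4))
print(conv_valid(x_large, w_fc).shape)         # (5, 5) spatial heatmap
```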
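The bilinear initialization of the upsampling deconvolution mentioned above can be sketched with the commonly used kernel construction (this is a generic sketch, not code from this tutorial):

```python
import numpy as np

def bilinear_kernel(size):
    """2-D bilinear interpolation kernel, commonly used to initialize the
    weights of an upsampling transposed convolution."""
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    og = np.ogrid[:size, :size]
    return ((1 - abs(og[0] - center) / factor) *
            (1 - abs(og[1] - center) / factor))

# A 4x4 kernel suits stride-2 upsampling, matching FCN's 2x deconv layers.
print(bilinear_kernel(4))
```

The kernel is symmetric and peaks at the center, so before any training the transposed convolution behaves like plain bilinear interpolation; backpropagation then refines it.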
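The FCN-16s/FCN-8s fusion described above can be checked with simple shape arithmetic (toy numbers for a 512 x 512 input; shapes only, no tensors):

```python
# Stride-32/16/8 score maps are upsampled and summed element-wise
# before the final 8x upsampling back to the input resolution.
H = W = 512
s32 = (H // 32, W // 32)   # conv7 scores  -> (16, 16)
s16 = (H // 16, W // 16)   # pool4 scores  -> (32, 32)
s8  = (H // 8,  W // 8)    # pool3 scores  -> (64, 64)

up2 = (s32[0] * 2, s32[1] * 2)   # 2x upsample of the stride-32 scores
assert up2 == s16                # shapes match -> element-wise add (FCN-16s)
up4 = (s16[0] * 2, s16[1] * 2)   # 2x upsample of the fused map
assert up4 == s8                 # fuse with pool3 scores (FCN-8s)
final = (s8[0] * 8, s8[1] * 8)   # final 8x upsample
assert final == (H, W)           # back to the input size
print(s32, s16, s8, final)
```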

Network Characteristics

  1. A fully convolutional (fully conv) network without fully connected (fc) layers, which adapts to inputs of any size.
  2. Deconvolution (deconv) layers that enlarge the feature maps, enabling fine-grained outputs.
  3. A skip structure that combines results from layers of different depths, ensuring both robustness and precision.

Data Processing

Before starting the experiment, make sure a local Python environment with MindSpore is installed.

# Check the current MindSpore version
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: [email protected]
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"

download(url, "./dataset", kind="tar", replace=True)
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar (537.2 MB)

file_sizes: 100%|█████████████████████████████| 563M/563M [00:03<00:00, 177MB/s]
Extracting tar file...
Successfully downloaded / unzipped to ./dataset


'./dataset'

Data Preprocessing

Because the images in the PASCAL VOC 2012 dataset mostly differ in resolution, they cannot be stacked into a single tensor, so the inputs must be standardized before being fed to the network.
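A minimal sketch of the per-channel standardization applied in the pipeline below, using the same channel means/stds the tutorial defines later (the random image is a stand-in):

```python
import numpy as np

IMAGE_MEAN = np.array([103.53, 116.28, 123.675], dtype=np.float32)
IMAGE_STD = np.array([57.375, 57.120, 58.395], dtype=np.float32)

# A dummy 4 x 4 image in place of a decoded dataset image.
image = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float32)
normalized = (image - IMAGE_MEAN) / IMAGE_STD  # broadcasts over H x W
print(normalized.shape)  # (4, 4, 3)
```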

Data Loading

Mix the PASCAL VOC 2012 dataset with the SBD dataset.

import numpy as np
import cv2
import mindspore.dataset as ds

class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):

        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        assert max_scale > min_scale  # sanity check: the random scale range must be valid

    def preprocess_dataset(self, image, label):
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset


# Parameters for creating the dataset
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Model training parameters
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# Instantiate the dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)

dataset = dataset.get_dataset()

Visualizing the Training Set

Run the following code to inspect the images loaded from the dataset (normalization was already applied during data processing).

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))

# Display samples from the training set; create the iterator once so each
# subplot shows a different batch
show_iter = dataset.create_dict_iterator()
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(show_iter)
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    # Convert the image to HWC format before display
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

Network Construction

Network Flow

The flow of the FCN network is shown in the figure below:

  1. The input image passes through pool1, and its size shrinks to 1/2 of the original.
  2. After pool2, the size becomes 1/4 of the original.
  3. pool3, pool4, and pool5 reduce it further to 1/8, 1/16, and 1/32 of the original.
  4. The conv6-7 convolutions leave the size at 1/32 of the original.
  5. FCN-32s applies a final deconvolution so that the output image matches the input size.
  6. FCN-16s deconvolves the output of conv7, doubling it to 1/16 of the original, fuses it with the pool4 feature map, and then deconvolves the result up to the original size.
  7. FCN-8s deconvolves the output of conv7 by 4x, deconvolves the pool4 feature map by 2x, takes the pool3 feature map, fuses all three, and deconvolves the result up to the original size.

fcn-6
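The resolutions in the steps above can be verified with a few lines of arithmetic for a 512 x 512 input, the crop size used in this tutorial:

```python
# Each of the five poolings halves the spatial size.
size = 512
for name in ["pool1", "pool2", "pool3", "pool4", "pool5"]:
    size //= 2
    print(name, size)   # 256, 128, 64, 32, 16
# conv6-7 keep the 1/32 resolution; upsampling by 32x recovers the input size.
assert size * 32 == 512
```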

Use the following code to build the FCN-8s network.

import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        u4 = self.upscore_pool4(f4)
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        out = self.upscore8(f3)
        return out

Training Preparation

Loading Partial VGG-16 Pretrained Weights

FCN uses VGG-16 as its backbone to encode the image. Use the code below to load part of the pretrained VGG-16 weights.

from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)
def load_vgg16():
    """Load the pretrained VGG-16 weights into the global `net`
    (the FCN8s instance created in the training section below)."""
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)

Loss Function

Semantic segmentation classifies each pixel of an image, so it is still a classification problem, and the cross-entropy loss is chosen to measure the loss between the FCN output and the mask. Here we use mindspore.nn.CrossEntropyLoss() as the loss function.
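As a toy plain-numpy illustration (not MindSpore) of per-pixel cross entropy with an ignored label: pixels labeled 255 are simply masked out of the average, which is what passing ignore_index=255 accomplishes in the training section below.

```python
import numpy as np

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 1.5, 0.3],
                   [0.1, 0.2, 3.0]])   # 3 pixels x 3 classes (toy values)
labels = np.array([0, 255, 2])         # the middle pixel is ignored

# Log-softmax over the class axis, then average the negative log-likelihood
# over the valid (non-ignored) pixels only.
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
valid = labels != 255
loss = -log_probs[valid, labels[valid]].mean()
print(loss)
```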

Custom Evaluation Metrics

This section evaluates the trained model. For ease of explanation, assume there are k+1 classes (from L_0 to L_k, including an empty class or background), and let p_{ij} denote the number of pixels that belong to class i but are predicted as class j. Thus p_{ii} is the number of true positives, while p_{ij} and p_{ji} are interpreted as false positives and false negatives respectively, although both are in fact sums of false positives and false negatives.

  • Pixel Accuracy (PA): the simplest metric, defined as the proportion of correctly labeled pixels among all pixels.

PA = \frac{\sum_{i=0}^{k}p_{ii}}{\sum_{i=0}^{k} \sum_{j=0}^{k}p_{ij}}

  • Mean Pixel Accuracy (MPA): a simple improvement on PA, computed by taking the proportion of correctly classified pixels within each class and then averaging over all classes.

MPA = \frac{1}{k+1} \sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k}p_{ij}}

  • Mean Intersection over Union (MIoU): the standard metric for semantic segmentation. It computes the ratio of the intersection to the union of two sets, which in semantic segmentation are the ground truth and the predicted segmentation. The ratio can be rewritten as the number of true positives (the intersection) over the sum of true positives, false negatives, and false positives (the union). The IoU is computed on each class and then averaged.

MIoU = \frac{1}{k+1} \sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k}p_{ij} + \sum_{j=0}^{k}p_{ji} - p_{ii}}

  • Frequency Weighted Intersection over Union (FWIoU): an improvement on MIoU that weights each class by its frequency of appearance.

FWIoU = \frac{1}{\sum_{i=0}^{k}\sum_{j=0}^{k}p_{ij}} \sum_{i=0}^{k}\frac{\sum_{j=0}^{k}p_{ij}\, p_{ii}}{\sum_{j=0}^{k}p_{ij} + \sum_{j=0}^{k}p_{ji} - p_{ii}}

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train

class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy


class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy


class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou


class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou
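The metric classes above all share one trick in _generate_matrix: encode each (gt, pred) pair as gt * num_class + pred, histogram the codes with np.bincount, and reshape into a num_class x num_class confusion matrix. A small demonstration on toy data (3 classes, 6 pixels):

```python
import numpy as np

num_class = 3
gt = np.array([0, 0, 1, 1, 2, 2])     # ground-truth pixel labels
pred = np.array([0, 1, 1, 1, 2, 0])   # predicted pixel labels

codes = num_class * gt + pred
cm = np.bincount(codes, minlength=num_class ** 2).reshape(num_class, num_class)
print(cm)  # rows: ground truth, columns: prediction

# Pixel accuracy and per-class IoU then fall out of the matrix:
pa = np.diag(cm).sum() / cm.sum()
iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
print(pa, iou)
```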

Model Training

After loading the pretrained VGG-16 parameters, instantiate the loss function and optimizer, compile the network with the Model API, and train the FCN-8s network.

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model

device_target = "GPU"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21
# Initialize the model structure
net = FCN8s(n_class=21)
# Load the pretrained VGG-16 parameters
load_vgg16()
# Compute the learning rate schedule
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])

# Define the loss function
loss = nn.CrossEntropyLoss(ignore_index=255)
# Define the optimizer
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# Define the dynamic loss scale manager
scale_factor = 4
scale_window = 3000
loss_scale_manager = mindspore.amp.DynamicLossScaleManager(scale_factor, scale_window)
# Initialize the model
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Configure checkpoint saving
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=save_steps,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                                directory="./ckpt",
                                config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)
epoch: 1 step: 1, loss is 3.0527685
epoch: 1 step: 2, loss is 3.0048614
epoch: 1 step: 3, loss is 2.981427
epoch: 1 step: 4, loss is 2.8206217
epoch: 1 step: 5, loss is 2.6268272
epoch: 1 step: 6, loss is 2.3679323
epoch: 1 step: 7, loss is 2.4530792
epoch: 1 step: 8, loss is 1.5295638
epoch: 1 step: 9, loss is 2.4946434
epoch: 1 step: 10, loss is 1.5108395
epoch: 1 step: 11, loss is 1.9560806
epoch: 1 step: 12, loss is 1.5405258
epoch: 1 step: 13, loss is 1.1865275
epoch: 1 step: 14, loss is 1.6255631
epoch: 1 step: 15, loss is 0.92430675
epoch: 1 step: 16, loss is 0.9158874
epoch: 1 step: 17, loss is 2.5531042
epoch: 1 step: 18, loss is 1.2532341
epoch: 1 step: 19, loss is 1.4346621
epoch: 1 step: 20, loss is 0.9460135
epoch: 1 step: 21, loss is 1.2414681
epoch: 1 step: 22, loss is 2.2611444
epoch: 1 step: 23, loss is 1.6960633
epoch: 1 step: 24, loss is 2.0330138
epoch: 1 step: 25, loss is 1.4799057
epoch: 1 step: 26, loss is 1.5416082
epoch: 1 step: 27, loss is 1.0028516
epoch: 1 step: 28, loss is 1.1851429
epoch: 1 step: 29, loss is 1.4031541
epoch: 1 step: 30, loss is 1.1314307
epoch: 1 step: 31, loss is 1.8798188
epoch: 1 step: 32, loss is 1.4333348
epoch: 1 step: 33, loss is 0.7920459
epoch: 1 step: 34, loss is 2.2017825
epoch: 1 step: 35, loss is 1.1045558
epoch: 1 step: 36, loss is 1.0527202
epoch: 1 step: 37, loss is 2.361278
epoch: 1 step: 38, loss is 2.2758758
epoch: 1 step: 39, loss is 2.708242
epoch: 1 step: 40, loss is 1.2672068
epoch: 1 step: 41, loss is 1.6704427
epoch: 1 step: 42, loss is 1.8233486
epoch: 1 step: 43, loss is 1.7580253
epoch: 1 step: 44, loss is 1.7874086
epoch: 1 step: 45, loss is 1.5946689
epoch: 1 step: 46, loss is 1.9213799
epoch: 1 step: 47, loss is 1.4156342
epoch: 1 step: 48, loss is 0.9919097
epoch: 1 step: 49, loss is 1.1712832
epoch: 1 step: 50, loss is 1.386591
epoch: 1 step: 51, loss is 1.1723616
epoch: 1 step: 52, loss is 0.895816
epoch: 1 step: 53, loss is 1.3006569
epoch: 1 step: 54, loss is 1.172828
epoch: 1 step: 55, loss is 0.84370124
epoch: 1 step: 56, loss is 1.5304729
epoch: 1 step: 57, loss is 1.6243601
epoch: 1 step: 58, loss is 1.6902919
epoch: 1 step: 59, loss is 2.1011014
epoch: 1 step: 60, loss is 1.6675043
epoch: 1 step: 61, loss is 2.6059213
epoch: 1 step: 62, loss is 2.0156097
epoch: 1 step: 63, loss is 1.6353211
epoch: 1 step: 64, loss is 1.165773
epoch: 1 step: 65, loss is 1.1785283
epoch: 1 step: 66, loss is 1.238415
epoch: 1 step: 67, loss is 0.77280235
epoch: 1 step: 68, loss is 1.5867974
epoch: 1 step: 69, loss is 1.6434702
epoch: 1 step: 70, loss is 1.0781606
epoch: 1 step: 71, loss is 2.586124
epoch: 1 step: 72, loss is 1.2035401
epoch: 1 step: 73, loss is 1.2931656
epoch: 1 step: 74, loss is 1.3493009
epoch: 1 step: 75, loss is 1.832552
epoch: 1 step: 76, loss is 1.4155468
epoch: 1 step: 77, loss is 1.3250738
epoch: 1 step: 78, loss is 0.93689483
epoch: 1 step: 79, loss is 0.8123223
epoch: 1 step: 80, loss is 2.6534646
epoch: 1 step: 81, loss is 0.96996516
epoch: 1 step: 82, loss is 2.1042116
epoch: 1 step: 83, loss is 1.1147116
epoch: 1 step: 84, loss is 1.589953
epoch: 1 step: 85, loss is 1.3308922
epoch: 1 step: 86, loss is 0.9746961
epoch: 1 step: 87, loss is 1.511168
epoch: 1 step: 88, loss is 2.2226815
epoch: 1 step: 89, loss is 1.5064403
epoch: 1 step: 90, loss is 1.1910138
epoch: 1 step: 91, loss is 1.3216316
epoch: 1 step: 92, loss is 1.3255843
epoch: 1 step: 93, loss is 1.2393374
epoch: 1 step: 94, loss is 1.0219425
epoch: 1 step: 95, loss is 0.8613696
epoch: 1 step: 96, loss is 1.5482918
epoch: 1 step: 97, loss is 1.5028028
epoch: 1 step: 98, loss is 2.5139964
epoch: 1 step: 99, loss is 1.9670683
epoch: 1 step: 100, loss is 3.7016795
epoch: 1 step: 101, loss is 1.3040333
epoch: 1 step: 102, loss is 1.6422337
epoch: 1 step: 103, loss is 2.5483966
epoch: 1 step: 104, loss is 1.7969916
epoch: 1 step: 105, loss is 1.8140471
epoch: 1 step: 106, loss is 1.9153649
epoch: 1 step: 107, loss is 1.7879149
epoch: 1 step: 108, loss is 2.0165818
epoch: 1 step: 109, loss is 1.6542164
epoch: 1 step: 110, loss is 1.0112859
epoch: 1 step: 111, loss is 1.2590146
epoch: 1 step: 112, loss is 1.3340175
epoch: 1 step: 113, loss is 1.7140545
epoch: 1 step: 114, loss is 1.0996683
epoch: 1 step: 115, loss is 1.5423781
epoch: 1 step: 116, loss is 1.6943505
epoch: 1 step: 117, loss is 1.581251
epoch: 1 step: 118, loss is 1.1688832
epoch: 1 step: 119, loss is 1.0259211
epoch: 1 step: 120, loss is 1.5394892
epoch: 1 step: 130, loss is 1.4948845
epoch: 1 step: 131, loss is 2.1155372
epoch: 1 step: 132, loss is 0.9271115
epoch: 1 step: 133, loss is 1.6069859
epoch: 1 step: 134, loss is 2.1225526
epoch: 1 step: 135, loss is 1.0335933
epoch: 1 step: 136, loss is 1.4472002
epoch: 1 step: 137, loss is 1.5058426
epoch: 1 step: 138, loss is 1.0474901
epoch: 1 step: 139, loss is 1.0177362
epoch: 1 step: 140, loss is 1.6175524
epoch: 1 step: 141, loss is 2.1220052
epoch: 1 step: 142, loss is 1.694429
epoch: 1 step: 143, loss is 1.8369793
epoch: 1 step: 144, loss is 1.3995504
epoch: 1 step: 145, loss is 1.7555605
epoch: 1 step: 146, loss is 1.9345758
epoch: 1 step: 147, loss is 1.0972852
epoch: 1 step: 148, loss is 2.3470957
epoch: 1 step: 149, loss is 1.2701917
epoch: 1 step: 150, loss is 1.4058386
epoch: 1 step: 151, loss is 1.2972242
epoch: 1 step: 152, loss is 1.3441492
epoch: 1 step: 153, loss is 1.563546
epoch: 1 step: 154, loss is 1.0316793
epoch: 1 step: 155, loss is 0.85918385
epoch: 1 step: 156, loss is 1.1022626
epoch: 1 step: 157, loss is 1.3310162
epoch: 1 step: 158, loss is 2.2302642
epoch: 1 step: 159, loss is 1.1926919
epoch: 1 step: 160, loss is 0.99002105
epoch: 1 step: 161, loss is 1.2781047
epoch: 1 step: 162, loss is 0.9535142
epoch: 1 step: 163, loss is 1.451594
epoch: 1 step: 164, loss is 1.1031244
epoch: 1 step: 165, loss is 0.9925471
epoch: 1 step: 166, loss is 1.200205
epoch: 1 step: 167, loss is 2.1479788
epoch: 1 step: 168, loss is 1.4817247
epoch: 1 step: 169, loss is 1.7980807
epoch: 1 step: 170, loss is 2.4411955
epoch: 1 step: 171, loss is 1.5032545
epoch: 1 step: 172, loss is 1.4797682
epoch: 1 step: 173, loss is 1.548252
epoch: 1 step: 174, loss is 1.492634
epoch: 1 step: 175, loss is 1.4916027
epoch: 1 step: 176, loss is 1.6407192
epoch: 1 step: 177, loss is 1.0971035
epoch: 1 step: 178, loss is 1.2219903
epoch: 1 step: 179, loss is 1.7273251
epoch: 1 step: 180, loss is 1.6422912
epoch: 1 step: 181, loss is 0.9990282
epoch: 1 step: 182, loss is 2.3068888
epoch: 1 step: 183, loss is 1.9660761
epoch: 1 step: 184, loss is 1.7996448
epoch: 1 step: 185, loss is 1.6307634
epoch: 1 step: 186, loss is 1.3286101
epoch: 1 step: 187, loss is 1.2952386
epoch: 1 step: 188, loss is 1.9838817
epoch: 1 step: 189, loss is 1.4275842
epoch: 1 step: 190, loss is 1.4513435
epoch: 1 step: 191, loss is 1.3864361
epoch: 1 step: 192, loss is 1.1794121
epoch: 1 step: 193, loss is 1.4535242
epoch: 1 step: 194, loss is 1.6198533
epoch: 1 step: 195, loss is 1.7515683
epoch: 1 step: 196, loss is 2.2098973
epoch: 1 step: 197, loss is 0.80396754
epoch: 1 step: 198, loss is 1.2999126
epoch: 1 step: 199, loss is 1.0943366
epoch: 1 step: 200, loss is 1.3179741
epoch: 1 step: 201, loss is 1.0918612
epoch: 1 step: 202, loss is 1.2819842
epoch: 1 step: 203, loss is 1.7914909
epoch: 1 step: 204, loss is 0.998963
epoch: 1 step: 205, loss is 2.5051603
epoch: 1 step: 206, loss is 1.8936038
epoch: 1 step: 207, loss is 2.4194045
epoch: 1 step: 208, loss is 1.1255649
epoch: 1 step: 209, loss is 1.3642749
epoch: 1 step: 210, loss is 1.4351572
epoch: 1 step: 211, loss is 1.6005664
epoch: 1 step: 212, loss is 1.4872969
epoch: 1 step: 213, loss is 1.4303253
epoch: 1 step: 214, loss is 1.7317516
epoch: 1 step: 215, loss is 2.413413
epoch: 1 step: 216, loss is 1.2426609
epoch: 1 step: 217, loss is 2.4082165
epoch: 1 step: 218, loss is 1.884935
epoch: 1 step: 219, loss is 1.5667516
epoch: 1 step: 220, loss is 2.344247
epoch: 1 step: 221, loss is 1.3106698
epoch: 1 step: 222, loss is 1.1881089
epoch: 1 step: 223, loss is 1.0340558
epoch: 1 step: 224, loss is 1.235635
epoch: 1 step: 225, loss is 0.7287971
epoch: 1 step: 226, loss is 1.2853668
epoch: 1 step: 227, loss is 1.0981282
epoch: 1 step: 228, loss is 1.713914
epoch: 1 step: 229, loss is 1.1456641
epoch: 1 step: 230, loss is 1.6192797
epoch: 1 step: 231, loss is 1.0149568
epoch: 1 step: 232, loss is 1.1124762
epoch: 1 step: 233, loss is 1.6938072
epoch: 1 step: 234, loss is 1.2327411
epoch: 1 step: 235, loss is 1.3397502
epoch: 1 step: 236, loss is 1.2113299
epoch: 1 step: 237, loss is 1.6674987
epoch: 1 step: 238, loss is 2.0675793
epoch: 1 step: 239, loss is 3.0707836
epoch: 1 step: 240, loss is 0.9131917
epoch: 1 step: 241, loss is 1.9954805
epoch: 1 step: 242, loss is 1.0335884
epoch: 1 step: 243, loss is 1.3639531
epoch: 1 step: 244, loss is 2.2480338
...
epoch: 1 step: 1141, loss is 1.2512453
epoch: 1 step: 1142, loss is 0.67773414
epoch: 1 step: 1143, loss is 0.9644441
Train epoch time: 751072.223 ms, per step time: 657.106 ms

Because the FCN network needs a large amount of training data and many training epochs, only a single epoch on a small dataset is run here to demonstrate how the loss converges. Below, a pre-trained weight file is used for model evaluation and to showcase inference results.
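
To visualize that convergence, the printed log above can be turned into plottable data with a small parser. A minimal sketch (the `parse_losses` helper is hypothetical, not part of the tutorial):

```python
import re

def parse_losses(log_text):
    """Extract (step, loss) pairs from lines like 'epoch: 1 step: 1083, loss is 1.034'."""
    pattern = re.compile(r"step:\s*(\d+),\s*loss is\s*([\d.]+)")
    return [(int(s), float(l)) for s, l in pattern.findall(log_text)]

log = """epoch: 1 step: 1083, loss is 1.0340477
epoch: 1 step: 1084, loss is 1.562695"""
points = parse_losses(log)
# points → [(1083, 1.0340477), (1084, 1.562695)]
```

The resulting pairs can be passed directly to `matplotlib.pyplot.plot` to chart the loss curve.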

Model Evaluation

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Download the pre-trained weight file
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)

ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(),
                           "mean pixel accuracy": PixelAccuracyClass(),
                           "mean IoU": MeanIntersectionOverUnion(),
                           "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(),
                           "mean pixel accuracy": PixelAccuracyClass(),
                           "mean IoU": MeanIntersectionOverUnion(),
                           "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Instantiate the Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt (1.00 GB)

file_sizes: 100%|███████████████████████████| 1.08G/1.08G [00:05<00:00, 181MB/s]
Successfully downloaded file to FCN8s.ckpt


{'pixel accuracy': 0.9727996527021614,
 'mean pixel accuracy': 0.9392028431280385,
 'mean IoU': 0.8921996881458918,
 'frequency weighted IoU': 0.9475711324126546}
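
The four metrics reported above can all be derived from a class confusion matrix accumulated over the evaluation set. A minimal NumPy sketch of the underlying formulas (the MindSpore metric classes used above are the authoritative implementations; the `segmentation_metrics` helper and the toy 2-class matrix are illustrative):

```python
import numpy as np

def segmentation_metrics(confusion):
    """Compute (pixel acc, mean pixel acc, mean IoU, frequency weighted IoU)
    from a (num_classes, num_classes) confusion matrix whose rows are
    ground-truth classes and columns are predicted classes."""
    confusion = confusion.astype(np.float64)
    total = confusion.sum()
    per_class_gt = confusion.sum(axis=1)     # ground-truth pixels per class
    per_class_pred = confusion.sum(axis=0)   # predicted pixels per class
    diag = np.diag(confusion)                # correctly classified pixels

    pixel_acc = diag.sum() / total
    mean_pixel_acc = np.nanmean(diag / per_class_gt)
    iou = diag / (per_class_gt + per_class_pred - diag)  # intersection / union
    mean_iou = np.nanmean(iou)
    freq = per_class_gt / total
    fw_iou = (freq * iou)[freq > 0].sum()    # IoU weighted by class frequency
    return pixel_acc, mean_pixel_acc, mean_iou, fw_iou

# Toy 2-class example: 10 pixels of class 0 mispredicted as class 1.
cm = np.array([[90, 10], [0, 100]])
pa, mpa, miou, fwiou = segmentation_metrics(cm)
# pa → 0.95
```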

Model Inference

Use the trained network to display the model's inference results.

import cv2
import matplotlib.pyplot as plt

net = FCN8s(n_class=num_classes)
# Set hyperparameters
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# Display inference results (top row: input images, bottom row: predicted masks)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([eval_batch_size, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
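
`plt.imshow(res[i])` above renders the raw class-index map with a default colormap. For PASCAL-VOC-style color-coded masks, the indices can be mapped through a palette via NumPy fancy indexing. A minimal sketch (the `PALETTE` shown lists only the first four standard VOC colors; `colorize` is a hypothetical helper):

```python
import numpy as np

# First four colors of the standard PASCAL VOC palette (one RGB triple per class).
PALETTE = np.array([
    [0, 0, 0],        # background
    [128, 0, 0],      # class 1
    [0, 128, 0],      # class 2
    [128, 128, 0],    # class 3
], dtype=np.uint8)

def colorize(mask):
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return PALETTE[mask]

mask = np.array([[0, 1], [2, 3]])
rgb = colorize(mask)  # shape (2, 2, 3)
```

The returned array can be passed to `plt.imshow` directly in place of the raw index map.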

Summary

FCN's core contribution is the use of fully convolutional layers to achieve end-to-end image segmentation through learning. Compared with traditional CNN-based segmentation approaches, FCN has two clear advantages: first, it accepts input images of arbitrary size, without requiring all training and test images to share a fixed size; second, it is more efficient, avoiding the redundant storage and repeated convolution computation caused by operating on pixel patches.

At the same time, the FCN network leaves room for improvement:

First, the results are still not fine-grained enough. Although 8x upsampling performs much better than 32x, the upsampled output remains blurry and overly smooth, especially at object boundaries, and the network is insensitive to fine details in the image. Second, it classifies each pixel independently, without fully modeling the relationships between pixels (such as discontinuity and similarity); it omits the spatial regularization step used in conventional pixel-wise segmentation methods and therefore lacks spatial consistency.

References

[1] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully Convolutional Networks for Semantic Segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
