Pix2Pix实现图像转换
Pix2Pix概述
Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器和判别器。
传统上,尽管此类任务的目标都是相同的从像素预测像素,但每项都是用单独的专用机器来处理的。而Pix2Pix使用的网络作为一个通用框架,使用相同的架构和目标,只在不同的数据上进行训练,即可得到令人满意的结果,鉴于此许多人已经使用此网络发布了他们自己的艺术作品。
基础原理
cGAN的生成器与传统GAN的生成器在原理上有一些区别,cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成,这是cGAN和GAN的在图像翻译任务中的差异。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。
在教程开始前,首先定义一些在整个过程中需要用到的符号:
- 𝑥:代表观测图像的数据。
- 𝑧:代表随机噪声的数据。
- 𝑦=𝐺(𝑥,𝑧):生成器网络,给出由观测图像𝑥与随机噪声𝑧生成的“假”图片,其中𝑥来自于训练数据而非生成器。
- 𝐷(𝑥,𝐺(𝑥,𝑧)):判别器网络,给出图像判定为真实图像的概率,其中𝑥来自于训练数据,𝐺(𝑥,𝑧)来自于生成器。
cGAN的目标可以表示为:
该公式是cGAN的损失函数,D
想要尽最大努力去正确分类真实图像与“假”图像,也就是使参数𝑙𝑜𝑔𝐷(𝑥,𝑦)最大化;而G
则尽最大努力用生成的“假”图像𝑦欺骗D
,避免被识破,也就是使参数𝑙𝑜𝑔(1−𝐷(𝐺(𝑥,𝑧)))最小化。cGAN的目标可简化为:
为了对比cGAN和GAN的不同,我们将GAN的目标也进行了说明:
从公式可以看出,GAN直接由随机噪声𝑧�生成“假”图像,不借助观测图像𝑥�的任何信息。过去的经验告诉我们,GAN与传统损失混合使用是有好处的,判别器的任务不变,依旧是区分真实图像与“假”图像,但是生成器的任务不仅要欺骗判别器,还要在传统损失的基础上接近训练数据。假设cGAN与L1正则化混合使用,那么有:
进而得到最终目标:
图像转换问题本质上其实就是像素到像素的映射问题,Pix2Pix使用完全一样的网络结构和目标函数,仅更换不同的训练数据集就能分别实现以上的任务。本任务将借助MindSpore框架来实现Pix2Pix的应用。
准备环节¶
配置环境文件
本案例在GPU,CPU和Ascend平台的动静态模式都支持。
python版本:Python 3.9.19
依赖环境安装
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
完整的环境
pip list
Package Version
------------------------------ --------------
absl-py 2.1.0
aiofiles 22.1.0
aiosqlite 0.20.0
altair 5.3.0
annotated-types 0.7.0
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
astroid 3.2.2
asttokens 2.0.5
astunparse 1.6.3
attrs 23.2.0
auto-tune 0.1.0
autopep8 1.5.5
Babel 2.15.0
backcall 0.2.0
beautifulsoup4 4.12.3
black 24.4.2
bleach 6.1.0
certifi 2024.6.2
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.1
contextlib2 21.6.0
contourpy 1.2.1
cycler 0.12.1
dataflow 0.0.1
debugpy 1.6.7
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.8
dnspython 2.6.1
download 0.3.5
easydict 1.13
email_validator 2.2.0
entrypoints 0.4
exceptiongroup 1.2.0
executing 0.8.3
fastapi 0.111.0
fastapi-cli 0.0.4
fastjsonschema 2.20.0
ffmpy 0.3.2
filelock 3.15.3
flake8 3.8.4
fonttools 4.53.0
fqdn 1.5.1
fsspec 2024.6.0
gitdb 4.0.11
GitPython 3.1.43
gradio 4.26.0
gradio_client 0.15.1
h11 0.14.0
hccl 0.1.0
hccl-parser 0.1
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.4
idna 3.7
importlib-metadata 7.0.1
importlib_resources 6.4.0
iniconfig 2.0.0
ipykernel 6.28.0
ipympl 0.9.4
ipython 8.15.0
ipython-genutils 0.2.0
ipywidgets 8.1.3
isoduration 20.11.0
isort 5.13.2
jedi 0.17.2
Jinja2 3.1.4
joblib 1.4.2
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
jupyter_client 7.4.9
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter-resource-usage 0.7.2
jupyter_server 2.14.1
jupyter_server_fileid 0.9.2
jupyter-server-mathjax 0.2.6
jupyter_server_terminals 0.5.3
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.2.5
jupyterlab 3.6.7
jupyterlab_code_formatter 2.2.1
jupyterlab_git 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp 4.3.0
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.2
jupyterlab-system-monitor 0.8.0
jupyterlab-topbar 0.6.1
jupyterlab_widgets 3.0.11
kiwisolver 1.4.5
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.0
matplotlib-inline 0.1.6
mccabe 0.6.1
mdurl 0.1.2
mindspore 2.2.14
mindvision 0.1.0
mistune 3.0.2
ml_collections 0.1.1
mpmath 1.3.0
msadvisor 1.0.0
mypy-extensions 1.0.0
nbclassic 1.1.0
nbclient 0.10.0
nbconvert 7.16.4
nbdime 4.0.1
nbformat 5.10.4
nest-asyncio 1.6.0
notebook 6.5.7
notebook_shim 0.2.4
numpy 1.26.4
op-compile-tool 0.1.0
op-gen 0.1
op-test-frame 0.1
opc-tool 0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
orjson 3.10.5
overrides 7.7.0
packaging 23.2
pandas 2.2.2
pandocfilters 1.5.1
parso 0.7.1
pathlib2 2.3.7.post1
pathspec 0.12.1
pexpect 4.8.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.1
platformdirs 4.2.2
pluggy 1.5.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
protobuf 5.27.1
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
pycodestyle 2.6.0
pycparser 2.22
pydantic 2.7.4
pydantic_core 2.18.4
pydocstyle 6.3.0
pydub 0.25.1
pyflakes 2.2.0
Pygments 2.15.1
pylint 3.2.3
pyparsing 3.1.2
pytest 8.0.0
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-json-logger 2.0.7
python-jsonrpc-server 0.4.0
python-language-server 0.36.2
python-multipart 0.0.9
pytoolconfig 1.3.1
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
referencing 0.35.1
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rope 1.13.0
rpds-py 0.18.1
ruff 0.4.10
schedule-search 0.0.1
scikit-learn 1.5.0
scipy 1.13.1
semantic-version 2.10.0
Send2Trash 1.8.3
setuptools 69.5.1
shellingham 1.5.4
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
snowballstemmer 2.2.0
soupsieve 2.5
stack-data 0.2.0
starlette 0.37.2
sympy 1.12.1
synr 0.5.0
te 0.4.0
terminado 0.18.1
threadpoolctl 3.5.0
tinycss2 1.3.0
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.1
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
typer 0.12.3
types-python-dateutil 2.9.0.20240316
typing_extensions 4.11.0
tzdata 2024.1
ujson 5.10.0
uri-template 1.3.0
urllib3 2.2.2
uvicorn 0.30.1
uvloop 0.19.0
watchfiles 0.22.0
wcwidth 0.2.5
webcolors 24.6.0
webencodings 0.5.1
websocket-client 1.8.0
websockets 11.0.3
wheel 0.43.0
widgetsnbextension 4.0.11
y-py 0.6.2
yapf 0.40.2
ypy-websocket 0.8.4
zipp 3.17.0
准备数据
在本教程中,我们将使用指定数据集,该数据集是已经经过处理的外墙(facades)数据,可以直接使用mindspore.dataset的方法读取。
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"
download(url, "./dataset", kind="tar", replace=True)
数据展示
调用Pix2PixDataset
和create_train_dataset
读取训练集,这里我们直接下载已经处理好的数据集。
from mindspore import dataset as ds
import matplotlib.pyplot as plt
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
plt.subplot(3, 10, i)
plt.axis("off")
plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()
创建网络
当处理完数据后,就可以来进行网络的搭建了。网络搭建将逐一详细讨论生成器、判别器和损失函数。生成器G用到的是U-Net结构,输入的轮廓图𝑥编码再解码成真是图片,判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图 𝑥的条件下,对于生成的图片𝐺(𝑥)判断为假,对于真实判断为真。
生成器G结构
U-Net是德国Freiburg大学模式识别和图像处理组提出的一种全卷积结构。它分为两个部分,其中左侧是由卷积和降采样操作组成的压缩路径,右侧是由卷积和上采样组成的扩张路径,扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成。网络模型整体是一个U形的结构,因此被叫做U-Net。和常见的先降采样到低维度,再升采样到原始分辨率的编解码结构的网络相比,U-Net的区别是加入skip-connection,对应的feature maps和decode之后的同样大小的feature maps按通道拼一起,用来保留不同分辨率下像素级的细节信息。
定义UNet Skip Connection Block
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
class UNetSkipConnectionBlock(nn.Cell):
def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
super(UNetSkipConnectionBlock, self).__init__()
down_norm = nn.BatchNorm2d(inner_nc)
up_norm = nn.BatchNorm2d(outer_nc)
use_bias = False
if norm_mode == 'instance':
down_norm = nn.BatchNorm2d(inner_nc, affine=False)
up_norm = nn.BatchNorm2d(outer_nc, affine=False)
use_bias = True
if in_planes is None:
in_planes = outer_nc
down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
down_relu = nn.LeakyReLU(alpha)
up_relu = nn.ReLU()
if outermost:
up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, pad_mode='pad')
down = [down_conv]
up = [up_relu, up_conv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, has_bias=use_bias, pad_mode='pad')
down = [down_relu, down_conv]
up = [up_relu, up_conv, up_norm]
model = down + up
else:
up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, has_bias=use_bias, pad_mode='pad')
down = [down_relu, down_conv, down_norm]
up = [up_relu, up_conv, up_norm]
model = down + [submodule] + up
if dropout:
model.append(nn.Dropout(p=0.5))
self.model = nn.SequentialCell(model)
self.skip_connections = not outermost
def construct(self, x):
out = self.model(x)
if self.skip_connections:
out = ops.concat((out, x), axis=1)
return out
基于UNet的生成器
class UNetGenerator(nn.Cell):
def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
super(UNetGenerator, self).__init__()
unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
norm_mode=norm_mode, innermost=True)
for _ in range(n_layers - 5):
unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
norm_mode=norm_mode, dropout=dropout)
unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
outermost=True, norm_mode=norm_mode)
def construct(self, x):
return self.model(x)
原始cGAN的输入是条件x和噪声z两种信息,这里的生成器只使用了条件信息,因此不能生成多样性的结果。因此Pix2Pix在训练和测试时都使用了dropout,这样可以生成多样性的结果。
基于PatchGAN的判别器
判别器使用的PatchGAN结构,可看做卷积。生成的矩阵中的每个点代表原图的一小块区域(patch)。通过矩阵中的各个值来判断原图中对应每个Patch的真假。
import mindspore.nn as nn
class ConvNormRelu(nn.Cell):
def __init__(self,
in_planes,
out_planes,
kernel_size=4,
stride=2,
alpha=0.2,
norm_mode='batch',
pad_mode='CONSTANT',
use_relu=True,
padding=None):
super(ConvNormRelu, self).__init__()
norm = nn.BatchNorm2d(out_planes)
if norm_mode == 'instance':
norm = nn.BatchNorm2d(out_planes, affine=False)
has_bias = (norm_mode == 'instance')
if not padding:
padding = (kernel_size - 1) // 2
if pad_mode == 'CONSTANT':
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
has_bias=has_bias, padding=padding)
layers = [conv, norm]
else:
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
pad = nn.Pad(paddings=paddings, mode=pad_mode)
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
layers = [pad, conv, norm]
if use_relu:
relu = nn.ReLU()
if alpha > 0:
relu = nn.LeakyReLU(alpha)
layers.append(relu)
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class Discriminator(nn.Cell):
def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
super(Discriminator, self).__init__()
kernel_size = 4
layers = [
nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
nn.LeakyReLU(alpha)
]
nf_mult = ndf
for i in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2 ** i, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
self.features = nn.SequentialCell(layers)
def construct(self, x, y):
x_y = ops.concat((x, y), axis=1)
output = self.features(x_y)
return output
Pix2Pix的生成器和判别器初始化
实例化Pix2Pix生成器和判别器。
import mindspore.nn as nn
from mindspore.common import initializer as init
g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'
net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
if init_type == 'normal':
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
elif init_type == 'xavier':
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
elif init_type == 'constant':
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
if init_type == 'normal':
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
elif init_type == 'xavier':
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
elif init_type == 'constant':
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
class Pix2Pix(nn.Cell):
"""Pix2Pix模型网络"""
def __init__(self, discriminator, generator):
super(Pix2Pix, self).__init__(auto_prefix=True)
self.net_discriminator = discriminator
self.net_generator = generator
def construct(self, reala):
fakeb = self.net_generator(reala)
return fakeb
训练
训练分为两个主要部分:训练判别器和训练生成器。训练判别器的目的是最大程度地提高判别图像真伪的概率。训练生成器是希望能产生更好的虚假图像。在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计。
下面进行训练:
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100
def get_lr():
lrs = [lr] * dataset_size * n_epochs
lr_epoch = 0
for epoch in range(n_epochs_decay):
lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
lrs += [lr_epoch] * dataset_size
lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
return Tensor(np.array(lrs).astype(np.float32))
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
def forword_dis(reala, realb):
lambda_dis = 0.5
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
pred1 = net_discriminator(reala, realb)
loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
loss_dis = loss_d * lambda_dis
return loss_dis
def forword_gan(reala, realb):
lambda_gan = 0.5
lambda_l1 = 100
fakeb = net_generator(reala)
pred0 = net_discriminator(reala, fakeb)
loss_1 = loss_f(pred0, ops.ones_like(pred0))
loss_2 = l1_loss(fakeb, realb)
loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
return loss_gan
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
beta1=0.5, beta2=0.999, loss_scale=1)
grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())
def train_step(reala, realb):
loss_dis, d_grads = grad_d(reala, realb)
loss_gan, g_grads = grad_g(reala, realb)
d_opt(d_grads)
g_opt(g_grads)
return loss_dis, loss_gan
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)
for epoch in range(epoch_num):
for i, data in enumerate(data_loader):
start_time = datetime.datetime.now()
input_image = Tensor(data["input_images"])
target_image = Tensor(data["target_images"])
dis_loss, gen_loss = train_step(input_image, target_image)
end_time = datetime.datetime.now()
delta = (end_time - start_time).microseconds
if i % 2 == 0:
print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
d_losses.append(dis_loss.asnumpy())
g_losses.append(gen_loss.asnumpy())
if (epoch + 1) == epoch_num:
mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
推理¶
获取上述训练过程完成后的ckpt文件,通过load_checkpoint和load_param_into_net将ckpt中的权重参数导入到模型中,获取数据进行推理并对推理的效果图进行演示(由于时间问题,训练过程只进行了3个epoch,可根据需求调整epoch)。
from mindspore import load_checkpoint, load_param_into_net
param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
plt.subplot(2, 10, i + 1)
plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
plt.axis("off")
plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.subplot(2, 10, i + 11)
plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
plt.axis("off")
plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
引用
[1] Phillip Isola,Jun-Yan Zhu,Tinghui Zhou,Alexei A. Efros. Image-to-Image Translation with Conditional Adversarial Networks.[J]. CoRR,2016,abs/1611.07004.