Bootstrap

昇思25天学习打卡营第18天|Pix2Pix实现图像转换

课程打卡凭证

Pix2Pix模型

Pix2Pix是一种基于条件生成对抗网络(Conditional Generative Adversarial Network,cGAN)的图像转译模型,能够将一个域中的输入图像转换为另一个域中的对应图像。

cGAN是GAN(详见昇思25天学习打卡营第16天|GAN图像生成-CSDN博客)的一种,与GAN不同的是,GAN仅使用随机噪声作为输入,而cGAN除了随机噪声外,还使用条件信息作为输入。总的来说,cGAN在标准GAN的基础上引入了条件信息,使得生成的结果更加可控和定制化,扩展了 GAN 的应用范围和效果。

Pix2Pix模型同样是由生成器(Generator)和判别器(Discriminator)组成。其中,生成器(通常采用U-Net网络结构)负责根据输入的条件图像(如语义标签图或简笔画)生成对应的输出图像(如真实图片),而判别器(通常采用PatchGAN结构)则负责判断生成的图像是否足够真实,以及是否与输入的条件图像相匹配。因此,Pix2Pix的损失函数为生成器损失和判别器损失之和,具体如下图所示。

训练过程

数据加载与展示

模型训练

构建生成器

采用U-Net架构,构建一个用于图像到图像翻译任务的跳跃连接块,通过递归嵌套子模块实现U-Net的编码器-解码器结构。这样的结构可以在保留高分辨率特征的同时,通过下采样和上采样逐步提取和重建图像的多尺度特征。

通过嵌套多个UNet Skip Connection Block来构建U-Net架构的生成器类,用于图像到图像翻译任务,同时,在每个UNet Skip Connection Block中实现跳跃连接,以保留高分辨率特征。

构建判别器

Pix2Pix的判别器采用PatchGAN结构,它通过将输入图像分割成多个小块(patch),并对每个小块进行真假判断来实现局部图像质量的评估。这种结构有助于捕捉图像中的高频信息(如纹理和边缘),从而提高判别器的准确性。

这里定义了一个卷积层、归一化层和激活层的组合模块ConvNormRelu,以及一个多层卷积判别器 Discriminator。判别器接收真实图像和生成图像的拼接,通过多层卷积和激活函数进行特征提取,最终输出图像是真实还是生成的。

初始化生成器和判别器

导入必要的库与模块,定义生成器和判别器的各种参数,包括输入通道数、输出通道数、特征图通道数、层数、激活函数的斜率、初始化方法和增益值。

生成器初始化。

判别器初始化。

定义包含生成器和判别器的Pix2Pix类。

模型训练

导入必要的库与模块,定义训练参数,包括训练的总epoch数、检查点保存路径、数据集大小、学习率、学习率衰减的epoch数等。

学习率调度。

加载训练数据集,并获取每个epoch的步数。定义用于判别器的二元交叉熵损失和用于生成器的L1损失。

定义判别器和生成器的前向传播函数。

定义判别器和生成器的优化器,并使用value_and_grad函数计算损失及其对应的梯度。

定义一次训练步骤,计算损失及梯度,并更新模型参数。

开始模型训练。

模型推理

结果如图所示。

;