Bootstrap

【PyTorch】12 生成对抗网络实战——用GAN生成动漫头像

深度卷积生成对抗网络(DCGAN):Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

1. 获取数据

原来书里的下载链接即原来知乎的何之源分享链接失效了,一篇简书里有下载地址,一共是51223张图片,尺寸是96×96×3,总大小272 MB

利用python查看图片的大小:

法I:

from PIL import Image
image = Image.open(dir + path[0])
imgSize = image.size  #大小
print(imgSize)
(96, 96)

法II:

import cv2
img = cv2.imread(dir + path[0])
sp = img.shape
print(sp)
(96, 96, 3)

2. 用GAN生成

2.1 Generator

CONVTRANSPOSE2D

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros')

官方文档,这是逆卷积,关于卷积可看之前的CNN猫狗二分类

在由多个输入平面组成的输入图像上应用二维转置卷积算子

这个模块可以看作是Conv2d相对于其输入的梯度。它也被称为分式卷积或解卷积(虽然它不是实际的解卷操作)

  • stride控制交叉相关的步幅
  • padding控制两边隐含的零填充量,以便进行dilation * (kernel_size - 1) - padding。详情请看下面的说明
  • output_padding 控制添加到输出形状一侧的额外尺寸。详情请看下面的说明
  • dilation控制核点之间的间距,也就是所谓的à trous算法。它比较难描述,但这个链接有一个很好的可视化的扩张作用
  • groups控制输入和输出之间的连接,in_channels和out_channels都必须被分组所除

对于本实验:

  • 输入维度:noiseSize × 1 × 1
  • kernel_size=4, stride=1, padding=0,第一次变化后:(n_generator_feature * 8) × 4 × 4
  • 当kernel_size=4, stride=2, padding=1时,输入的宽高刚好是第一次的两倍,第二次变化后:(n_generator_feature * 4) × 8 × 8
  • 第三次变化后:(n_generator_feature * 2) × 16 × 16
  • 第四次变化后:n_generator_feature × 32 × 32
  • 最后一层采用kernel_size=5, stride=3, padding=1,为了将32 × 32变为96 × 96
  • 最后用Tanh将输出图片的像素归一化到-1~1,如果希望归一化到0~1,需要使用Sigmoid
Generator = NetGenerator()
x = torch.rand(1, noiseSize, 1, 1)
y = Generator(x)
print(y.size())
torch.Size([1, 3, 96, 96])

2.2 Discriminator

LeakyReLU

官方手册inplace=True表示进行覆盖运算

Discriminator = NetDiscriminator()
x = torch.rand(1, 3, 96, 96)
y = Discriminator(x)
print(y.size())
torch.Size([1])

可以看出判别器和生成器的额网络几乎是对称的,从卷积核大小到padding、stride等设置,需要注意的是生成器的激活函数是ReLU,而判别器使用的是LeakyLeRU,二者并无本质区别,这里的选择更多的是经验总结。每一个样本经过判别器后,输出一个0~1的数,表示的是这个样本是真图片的概率

2.3 其它细节

torch.utils.data.DataLoader官方文档

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)中文文档
betas (Tuple[float, float], 可选) – 平滑常数:用于计算梯度以及梯度平方的运行平均值的系数(默认:0.9,0.999)

torch.nn.BCELoss():计算target 和output 间的二值交叉熵(Binary Cross Entropy)官方文档,计算公式可见

关于(tqdm.tqdm)可见官方文档

2.4 训练思路

  • 训练判别器
    • 对于真图片,输出尽可能是1
    • 对于假图片,输出尽可能是0
  • 训练生成器
    • 对于假图片,输出尽可能是1

这里需要注意以下几点。

  • 训练生成器时,无须调整判别器的参数;训练判别器时,无须调整生成器的参数。
  • 在训练判别器时,需要对生成器生成的图片用detach操作进行计算图截断,避免反向传播将梯度传到生成器中。因为在训练判别器时我们不需要训练生成器,也就不需要生成器的梯度。
  • 在训练判别器时,需要反向传播两次,一次是希望把真图片判为1,一次是希望把假图片判为0。也可以将这两者的数据放到一个batch中,进行一次前向传播和一次反向传播即可。但是人们发现,在一个batch中只包含真图片或只包含假图片的做法最好。
  • 对于假图片,在训练判别器时,我们希望它输出0;而在训练生成器时,我们希望它输出1.因此可以看到一对看似矛盾的代码 error_d_fake = criterion(output, fake_labels)和error_g = criterion(output, true_labels)。其实这也很好理解,判别器希望能够把假图片判别为fake_label,而生成器则希望能把他判别为true_label,判别器和生成器互相对抗提升。

接下来就是一些可视化的代码。每次可视化使用的噪声都是固定的fix_noises,因为这样便于我们比较对于相同的输入,生成器生成的图片是如何一步步提升的。另外,由于我们对输入的图片进行了归一化处理(-1~1),在可视化时则需要将它还原成原来的scale(0~1)

3. 全部代码

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False

# dir = '... your path/faces/'
dir = '/mnt/Data1/ysc/GAN'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223


noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 256
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generator


class NetGenerator(nn.Module):
    def __init__(self):
        super(NetGenerator,self).__init__()
        self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行
            nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(n_generator_feature * 8),
            nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4
            nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 4),
            nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8
            nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 2),
            nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16
            nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature),
            nn.ReLU(True),      # (n_generator_feature) × 32 × 32
            nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()       # 3 * 96 * 96
        )

    def forward(self, input):
        return self.main(input)


class NetDiscriminator(nn.Module):
    def __init__(self):
        super(NetDiscriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32
            nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16
            nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8
            nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4
            nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()        # 输出一个概率
        )

    def forward(self, input):
        return self.main(input).view(-1)


def train():
    for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96
        real_image = Variable(image)
        real_image = real_image.cuda()

        if (i + 1) % d_every == 0:
            optimizer_d.zero_grad()
            output = Discriminator(real_image)      # 尽可能把真图片判为True
            error_d_real = criterion(output, true_labels)
            error_d_real.backward()

            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises).detach()       # 根据噪声生成假图
            fake_output = Discriminator(fake_img)       # 尽可能把假图片判为False
            error_d_fake = criterion(fake_output, fake_labels)
            error_d_fake.backward()
            optimizer_d.step()

        if (i + 1) % g_every == 0:
            optimizer_g.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises)        # 这里没有detach
            fake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为True
            error_g = criterion(fake_output, true_labels)
            error_g.backward()
            optimizer_g.step()


def show(num):
    fix_fake_imags = Generator(fix_noises)
    fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5

    # x = torch.rand(64, 3, 96, 96)
    fig = plt.figure(1)

    i = 1
    for image in fix_fake_imags:
        ax = fig.add_subplot(8, 8, eval('%d' % i))
        # plt.xticks([]), plt.yticks([])  # 去除坐标轴
        plt.axis('off')
        plt.imshow(image.permute(1, 2, 0))
        i += 1
    plt.subplots_adjust(left=None,  # the left side of the subplots of the figure
                        right=None,  # the right side of the subplots of the figure
                        bottom=None,  # the bottom of the subplots of the figure
                        top=None,  # the top of the subplots of the figure
                        wspace=0.05,  # the amount of width reserved for blank space between subplots
                        hspace=0.05)  # the amount of height reserved for white space between subplots)
    plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)
    plt.show()


if __name__ == '__main__':
    transform = tv.transforms.Compose([
        tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecated
        tv.transforms.CenterCrop(96),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数
    ])

    dataset = tv.datasets.ImageFolder(dir, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'

    print('数据加载完毕!')

    Generator = NetGenerator()
    Discriminator = NetDiscriminator()

    optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()

    true_labels = Variable(torch.ones(batch_size))     # batch_size
    fake_labels = Variable(torch.zeros(batch_size))
    fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))
    noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布

    if torch.cuda.is_available() == True:
        print('Cuda is available!')
        Generator.cuda()
        Discriminator.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()


    plot_epoch = [1,5,10,20,50,100,199]

    for i in range(200):        # 最大迭代次数
        train()
        print('迭代次数:{}'.format(i))
        if i in plot_epoch:
            show(i)

4. 结果展示与分析

在第1,5,10,20,50,100,199分别打印结果如下所示,这里第0代没有打印:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 刚开始训练的图像比较模糊(1个epoch),但是可以看出图像已经有面部轮廓
  • 继续训练数个epoch之后,生成的图多了很多细节信息,包括头发、颜色等,但是总体还是模糊
  • 训练数个epoch之后,细节继续完善,包括头发的纹理、眼睛的细节等,但还是有不少涂抹的痕迹
  • 训练数个epoch时,已经能看出明显的面部轮廓和细节,但还是有涂抹现象,并且有些细节不够合理,例如眼睛一大一小,面部轮廓扭曲严重
  • 当训练到最大epoch会后,图片的细节已经十分完善,线条更加流畅,轮廓更清晰,虽然还有一些不合理之处,但是已经有不少图片能够以假乱真了

类似的生成动漫头像的项目还有《用DRGAN生成高清的动漫头像》,效果很好,但遗憾的是,由于论文中使用的数据涉及版权问题,未能公开。这篇论文主要改进包括使用了更高质量的图片和更深、更复杂的模型

GAN可以应用到不同的生成图片场景中,只要将训练图片改成其他类型的图片即可,例如LSUN房客图片集、MNIST手写数据集或CIFAR10数据集等。事实上,上述模型还有很大的改进空间。在这里,我们使用的全卷积网络只有四层,模型比较浅,而在ResNet的论文发表之后,也有不少研究者尝试在GAN的网络结构中引入Residual Block结构,并取得了不错的视觉效果。感兴趣可以尝试将示例代码中的单层卷积改为Residual Block,相信可以取得不错的效果

今年来,GAN的一个重大突破在于理论研究。论文《Towards Principled Methods for Training Generative Adversarial Networks》从理论的角度分析了GAN为何难以训练,作者随后在另一篇论文《Wasserstein GAN》中针对性地提出了一个更好的解决方案。但是这篇论文在部分技术细节上的实现过于随意,所以随后又有人有针对性地提出了《Improved Training of Wasserstein GANs》,更好地训练WGAN。后面两篇论文分别用PyTorch和TensorFlow实现,代码可以在GitHub上搜索到。笔者当初也尝试用100行左右的代码实现了Wasserstein GAN,该兴趣可以去了解

随着GAN研究的逐渐成熟,人们也尝试把GAN用于工业实际问题之中,而在众多相关论文中,最令人深刻的就是《Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks》,论文中提出了一种新的GAN结构称为CycleGAN。CycleGAN利用GAN实现风格迁移、黑白图像彩色化,以及马和斑马互相转化等,效果十分出众。论文的作者用PyTorch实现了所有的代码,并开源在GitHub上,感兴趣可以自行查阅

小结

GAN生成的结果还是比较理想吧,就是一个简单的GAN的结构,其中的Generator的反卷积与Discriminator的卷积可以琢磨一下,模型只是简单的训练,并没有保存和测试

;