Bootstrap

GAN论文分析与代码实现

1 Generative Adversarial Nets

GAN原文:https://arxiv.org/pdf/1406.2661.pdf
GAN

1.1 Generative Model

机器学习模型可分为判别模型(Discriminative Model)与生成模型(Generative Model),判别模型一般构建一个决策面对数据进行分类,而生成模型则需要学习原始样本的数据分布。
GAN是一种生成模型,给定无标签的训练数据集 X = { x 1 , x 2 , … , x n } X=\left\{x_1,x_2,\ldots,x_n\right\} X={x1,x2,,xn}。他可以学习训练数据中的数据分布,生成预测数据 X ′ = { x 1 ′ , x 2 ′ , … , x n ′ } X^\prime=\left\{x_1^\prime,x_2^\prime,\ldots,x_n^\prime\right\} X={x1,x2,,xn}

1.2 Introduction of GAN

GAN的组成包括生成器(Generator)和判别器(Discriminator),Generator从噪声中生成数据,将生成数据送入Discriminator,判断生成数据与真实数据的相似程度,从而进一步优化生成器的生成能力和判别器的判断能力。

论文中给出训练算法:

algorithm
训练判别器时,算法中采用了梯度上升,是因为对交叉熵损失去掉了负号。在我们写代码时,由于交叉熵默认带负号,因此等价于梯度下降。

训练生成器时,作者告诉我们,由于一开始生成器的生成能力较弱,因此 D ( G ( z ) ) D(G(z)) D(G(z))容易为0,导致 log ⁡ ( 1 − D ( G ( z ) ) ) \log{(1-D(G(z)))} log(1D(G(z)))趋向于0,造成梯度消失。作者考虑到了这个问题,并提供了一个方法:

在这里插入图片描述
换句话说,我们不去最大化假图和假标签的距离,而是最小化假图和真标签之间的距离。
对于生成器 G G G,我们令 y t r u e = 1 y_{true}=1 ytrue=1 y f a k e = 0 y_{fake}=0 yfake=0,则我们的目标是:
min ⁡ [ L C r o s s E n t r o p y ( D ( G ( z ) ) , y t r u e ) ] = min ⁡ [ − log ⁡ ( D ( G ( z ) ) ) ]    ⟺    max ⁡ [ log ⁡ ( D ( G ( z ) ) ) ] \min{[L_{CrossEntropy}(D(G(z)), y_{true})]}=\min[-\log{(D(G(z)))}]\iff\max[\log(D(G(z)))] min[LCrossEntropy(D(G(z)),ytrue)]=min[log(D(G(z)))]max[log(D(G(z)))]

1.3 a simple theory

从交叉熵看起:
在这里插入图片描述在这里插入图片描述
在这里插入图片描述

2 代码实现

2.1 实验目标

搭建GAN,实现对MNIST数据集的生成,测试效果

2.2 GAN网络结构

Generator生成器,Discriminator判别器。为简化实验,采用简单的三层全连接网络实现。

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_features=100, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=28*28),
            nn.Sigmoid()
        )

    def forward(self, X):
        return self.network(X)
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=1),
            nn.Sigmoid()
        )

    def forward(self, X):
        return self.network(X)

2.3 对抗训练

注意,针对Generator和Discriminator分别进行训练,两个网络的训练在一个循环中,即分别更新D和G。常见的误区是先把D更新好再去更新G,或者先把G更新好再去更新D,这样都是错的,因为GAN是一个对抗训练的过程,D和G水平应该相当,任何一方太强都不行。网络的训练代码如下:

def train():
    for epoch in range(epochs):
        print(f"epoch {epoch+1}\n-----------------")
        for i, (X_real, _) in enumerate(real_dataloader):
            
            length = X_real.shape[0]
            
            y_real = torch.ones(length, 1)
            y_fake = torch.zeros(length, 1)

            X_real, y_real, y_fake = X_real.to(device), y_real.to(device), y_fake.to(device)

            # 1 ------------更新Discriminator--------------

            z_batch = torch.randn(length, 100).to(device)
            X_fake = G(z_batch)

            output_real = D(X_real.reshape(length, 28*28))
            loss_real = bce_loss(output_real, y_real)
            output_fake = D(X_fake)
            loss_fake = bce_loss(output_fake, y_fake)

            loss_D = loss_real + loss_fake

            D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            # 2 -------------更新Generator---------------

            z_batch = torch.randn(length, 100).to(device)
            X_fake = G(z_batch)
            fake_output = D(X_fake)
            loss_G = bce_loss(fake_output, y_real)

            G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            if i % 100 == 0:
                print(f"loss_G: {loss_G.item()}, loss_D: {loss_D.item()}, D(x): {loss_real.item()}, D(G(z)): {loss_fake.item()}")
                loss_G_list.append(loss_G.item())
                loss_D_list.append(loss_D.item())

                Dx_list.append(loss_real.item())
                DGz_list.append(loss_fake.item())
                global iter
                iter += 1
        # 保存最后的权重文件
        torch.save(G.state_dict(), 'model_G.pth')

2.4 生成结果

loss
D
在这里插入图片描述

3 完整代码

完整代码已上传至github,地址:https://github.com/gwcrepo/GAN-MNIST,有帮助记得star

;