Bootstrap

深度学习实验之GAN生成动漫人物图像

 数据集格式:

images:

tags.csv:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd

#定义数据集包装类,用于读入数据和包装方法
class AnimeFacesDataset(Dataset):
    def __init__(self, img_folder, tags_file, transform=None):
        self.img_folder = img_folder
        self.transform = transform

        # Load tags
        self.tags_df = pd.read_csv(tags_file, delimiter='\t', header=None)

    def __len__(self):
        return len(os.listdir(self.img_folder))

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_folder, f"{idx}.jpg")
        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

# 定义生成器网络
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, img_shape),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img

# 定义判别器网络
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(img_shape, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# 定义训练参数
img_shape = 64 * 64 * 3
latent_dim = 100
lr = 0.0002
batch_size = 64
epochs = 100

# 读入数据
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = AnimeFacesDataset(img_folder="C:\\Users\\lx\\Desktop\\extra_data\\images", tags_file="C:\\Users\\lx\\Desktop\\extra_data\\tags.csv", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#初始化G和D
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)

#定义loss函数等等
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
  1. class AnimeFacesDataset(Dataset):: 这是一个自定义的 PyTorch 数据集类,用于加载动漫头像图像数据集。

    • def __init__(self, img_folder, tags_file, transform=None):: 类的初始化函数,接受图像文件夹路径 img_folder、标签文件路径 tags_file 和可选的图像转换函数 transform

      • self.img_folder = img_folder: 存储图像文件夹路径。

      • self.transform = transform: 存储图像转换函数。

      • self.tags_df = pd.read_csv(tags_file, delimiter='\t', header=None): 从标签文件中读取标签数据,并存储在 tags_df 属性中。

    • def __len__(self):: 返回数据集的长度,即图像文件夹中图像的数量。

    • def __getitem__(self, idx):: 根据给定的索引 idx 加载对应位置的图像。

      • img_name = os.path.join(self.img_folder, f"{idx}.jpg"): 构建图像文件路径。

      • image = Image.open(img_name).convert("RGB"): 使用 PIL 库打开图像,并将其转换为 RGB 模式。

      • if self.transform:: 如果存在图像转换函数,则应用该函数。

      • return image: 返回转换后的图像。

  2. class Generator(nn.Module):: 这是生成器网络的定义,用于生成图像。

    • def __init__(self, latent_dim, img_shape):: 类的初始化函数,接受隐变量维度 latent_dim 和图像形状 img_shape

      • 在 self.model 中定义了一个神经网络模型,该模型接受一个隐变量作为输入,并输出一个图像。
    • def forward(self, z):: 前向传播函数,接受隐变量 z 作为输入,并生成相应的图像。

      • img = self.model(z): 通过神经网络模型生成图像。

      • return img: 返回生成的图像。

  3. class Discriminator(nn.Module):: 这是判别器网络的定义,用于判断图像的真实性。

    • def __init__(self, img_shape):: 类的初始化函数,接受图像形状 img_shape

      • 在 self.model 中定义了一个神经网络模型,该模型接受一个图像作为输入,并输出一个标量,表示图像的真实性。
    • def forward(self, img):: 前向传播函数,接受图像 img 作为输入,并输出图像的真实性。

      • img_flat = img.view(img.size(0), -1): 将输入图像展平成一维向量。

      • validity = self.model(img_flat): 通过神经网络模型判断图像的真实性。

      • return validity: 返回判断结果。

  4. 定义了一些训练参数,包括图像形状、隐变量维度、学习率、批量大小和训练周期等。

  5. 设置了图像预处理的转换函数,包括了将图像调整大小、转换为张量并进行归一化等操作。

  6. 创建了数据集实例 dataset 和数据加载器 dataloader,用于加载训练数据。

  7. 初始化了生成器 generator 和判别器 discriminator 网络实例。

  8. 定义了生成对抗网络的损失函数 adversarial_loss 和优化器 optimizer_Goptimizer_D

#训练!
for epoch in range(epochs):
    for i, imgs in enumerate(dataloader):

        # Adversarial ground truths
        valid = torch.ones(imgs.size(0), 1)
        fake = torch.zeros(imgs.size(0), 1)

        # Configure input
        real_imgs = imgs

        # -----------------
        #  Train 生成器
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.randn(imgs.size(0), latent_dim)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train 判别器
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Print training info
        batches_done = epoch * len(dataloader) + i
        if batches_done % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

# Save models
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")
  1. for epoch in range(epochs):: 外层循环遍历训练周期(epochs)。

  2. for i, imgs in enumerate(dataloader):: 内层循环遍历数据加载器中的每个批次。

    • imgs 是一个批次的图像数据。

    • i 是当前批次的索引。

  3. valid = torch.ones(imgs.size(0), 1): 创建一个由1组成的张量,表示真实图像的标签为真实(1)。

    fake = torch.zeros(imgs.size(0), 1): 创建一个由0组成的张量,表示生成的图像的标签为假(0)。

  4. real_imgs = imgs: 将当前批次的真实图像赋值给 real_imgs

  5. 生成器训练部分:

    • optimizer_G.zero_grad(): 清除生成器梯度。

    • z = torch.randn(imgs.size(0), latent_dim): 从标准正态分布中生成随机噪声作为生成器的输入。

    • gen_imgs = generator(z): 使用生成器生成一批图像。

    • g_loss = adversarial_loss(discriminator(gen_imgs), valid): 计算生成器损失,即判别器对生成图像的判别结果与真实标签之间的差异。

    • g_loss.backward(): 反向传播生成器的损失。

    • optimizer_G.step(): 更新生成器的参数。

  6. 判别器训练部分:

    • optimizer_D.zero_grad(): 清除判别器梯度。

    • real_loss = adversarial_loss(discriminator(real_imgs), valid): 计算真实图像的判别器损失。

    • fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake): 计算生成图像的判别器损失。

    • d_loss = (real_loss + fake_loss) / 2: 计算判别器总体损失,即真实图像和生成图像损失的平均值。

    • d_loss.backward(): 反向传播判别器的损失。

    • optimizer_D.step(): 更新判别器的参数。

  7. 打印训练信息:

    • batches_done = epoch * len(dataloader) + i: 计算已完成的批次数量。

    • if batches_done % 100 == 0:: 每训练完成100个批次时打印一次信息。

      • 打印当前训练周期、当前批次、总批次数以及判别器和生成器的损失值。
  8. 循环结束后,保存生成器和判别器的参数到文件中,以便之后的预测或继续训练使用。

#测试网络
import torch
from torchvision.utils import save_image

# 加载保存的生成器模型
generator = Generator(latent_dim, img_shape)
generator.load_state_dict(torch.load("generator.pth"))
generator.eval()  # 设置为评估模式,关闭 dropout 和 batch normalization

# 生成随机噪声向量
num_samples = 10  # 生成图像的数量
z = torch.randn(num_samples, latent_dim)

# 通过生成器生成图像
with torch.no_grad():
    generated_images = generator(z)

# 将生成的图像保存到文件中
save_image(generated_images, "generated_images.png", nrow=5, normalize=True)

# 可选:显示生成的图像
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i in range(num_samples):
    axes[i // 5, i % 5].imshow(generated_images[i].permute(1, 2, 0).cpu().numpy())
    axes[i // 5, i % 5].axis("off")
plt.show()

这段代码是用来测试生成器网络的效果,并将生成的图像保存到文件中。我来解释一下:

  1. 首先,代码加载了保存的生成器模型:

    generator = Generator(latent_dim, img_shape) generator.load_state_dict(torch.load("generator.pth")) generator.eval() # 设置为评估模式,关闭 dropout 和 batch normalization

    这里假设了存在一个名为 "generator.pth" 的文件,其中保存了已经训练好的生成器模型的参数。加载后,通过 generator.eval() 将模型设置为评估模式,这会关闭 dropout 和 batch normalization,以确保生成的图像稳定。

  2. 生成随机噪声向量:

    num_samples = 10 # 生成图像的数量 z = torch.randn(num_samples, latent_dim)

    这里创建了一个大小为 (10, latent_dim) 的随机噪声向量,用于生成图像。latent_dim 是生成器网络的输入向量的维度。

  3. 通过生成器生成图像:

    with torch.no_grad(): generated_images = generator(z)

    使用 torch.no_grad() 上下文管理器来关闭梯度计算,以节省内存并加快计算速度。然后,通过将随机噪声向量 z 传递给生成器,生成一批图像。

  4. 将生成的图像保存到文件中:

    save_image(generated_images, "generated_images.png", nrow=5, normalize=True)

    save_image() 函数将生成的图像保存到名为 "generated_images.png" 的文件中,nrow=5 指定每行显示 5 张图像,normalize=True 将图像像素值标准化到 0 到 1 之间。

  5. 可选:显示生成的图像:

    import matplotlib.pyplot as plt fig, axes = plt.subplots(2, 5, figsize=(15, 6)) for i in range(num_samples): axes[i // 5, i % 5].imshow(generated_images[i].permute(1, 2, 0).cpu().numpy()) axes[i // 5, i % 5].axis("off") plt.show()

    这段代码用 Matplotlib 在图形界面中显示生成的图像。将生成的图像排列成 2 行 5 列的网格,并在每个子图中显示一张图像。

;