数据集格式:
images:
tags.csv:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
#定义数据集包装类,用于读入数据和包装方法
class AnimeFacesDataset(Dataset):
def __init__(self, img_folder, tags_file, transform=None):
self.img_folder = img_folder
self.transform = transform
# Load tags
self.tags_df = pd.read_csv(tags_file, delimiter='\t', header=None)
def __len__(self):
return len(os.listdir(self.img_folder))
def __getitem__(self, idx):
img_name = os.path.join(self.img_folder, f"{idx}.jpg")
image = Image.open(img_name).convert("RGB")
if self.transform:
image = self.transform(image)
return image
# 定义生成器网络
class Generator(nn.Module):
def __init__(self, latent_dim, img_shape):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Linear(128, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, img_shape),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
return img
# 定义判别器网络
class Discriminator(nn.Module):
def __init__(self, img_shape):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_shape, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# 定义训练参数
img_shape = 64 * 64 * 3
latent_dim = 100
lr = 0.0002
batch_size = 64
epochs = 100
# 读入数据
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = AnimeFacesDataset(img_folder="C:\\Users\\lx\\Desktop\\extra_data\\images", tags_file="C:\\Users\\lx\\Desktop\\extra_data\\tags.csv", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
#初始化G和D
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)
#定义loss函数等等
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
class AnimeFacesDataset(Dataset):
: 这是一个自定义的 PyTorch 数据集类,用于加载动漫头像图像数据集。
def __init__(self, img_folder, tags_file, transform=None):
: 类的初始化函数,接受图像文件夹路径img_folder
、标签文件路径tags_file
和可选的图像转换函数transform
。
self.img_folder = img_folder
: 存储图像文件夹路径。
self.transform = transform
: 存储图像转换函数。
self.tags_df = pd.read_csv(tags_file, delimiter='\t', header=None)
: 从标签文件中读取标签数据,并存储在tags_df
属性中。
def __len__(self):
: 返回数据集的长度,即图像文件夹中图像的数量。
def __getitem__(self, idx):
: 根据给定的索引idx
加载对应位置的图像。
img_name = os.path.join(self.img_folder, f"{idx}.jpg")
: 构建图像文件路径。
image = Image.open(img_name).convert("RGB")
: 使用 PIL 库打开图像,并将其转换为 RGB 模式。
if self.transform:
: 如果存在图像转换函数,则应用该函数。
return image
: 返回转换后的图像。
class Generator(nn.Module):
: 这是生成器网络的定义,用于生成图像。
def __init__(self, latent_dim, img_shape):
: 类的初始化函数,接受隐变量维度latent_dim
和图像形状img_shape
。
- 在
self.model
中定义了一个神经网络模型,该模型接受一个隐变量作为输入,并输出一个图像。
def forward(self, z):
: 前向传播函数,接受隐变量z
作为输入,并生成相应的图像。
img = self.model(z)
: 通过神经网络模型生成图像。
return img
: 返回生成的图像。
class Discriminator(nn.Module):
: 这是判别器网络的定义,用于判断图像的真实性。
def __init__(self, img_shape):
: 类的初始化函数,接受图像形状img_shape
。
- 在
self.model
中定义了一个神经网络模型,该模型接受一个图像作为输入,并输出一个标量,表示图像的真实性。
def forward(self, img):
: 前向传播函数,接受图像img
作为输入,并输出图像的真实性。
img_flat = img.view(img.size(0), -1)
: 将输入图像展平成一维向量。
validity = self.model(img_flat)
: 通过神经网络模型判断图像的真实性。
return validity
: 返回判断结果。定义了一些训练参数,包括图像形状、隐变量维度、学习率、批量大小和训练周期等。
设置了图像预处理的转换函数,包括了将图像调整大小、转换为张量并进行归一化等操作。
创建了数据集实例
dataset
和数据加载器dataloader
,用于加载训练数据。初始化了生成器
generator
和判别器discriminator
网络实例。定义了生成对抗网络的损失函数
adversarial_loss
和优化器optimizer_G
、optimizer_D
。
#训练!
for epoch in range(epochs):
for i, imgs in enumerate(dataloader):
# Adversarial ground truths
valid = torch.ones(imgs.size(0), 1)
fake = torch.zeros(imgs.size(0), 1)
# Configure input
real_imgs = imgs
# -----------------
# Train 生成器
# -----------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = torch.randn(imgs.size(0), latent_dim)
# Generate a batch of images
gen_imgs = generator(z)
# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train 判别器
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# Print training info
batches_done = epoch * len(dataloader) + i
if batches_done % 100 == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)
# Save models
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")
for epoch in range(epochs):
: 外层循环遍历训练周期(epochs)。
for i, imgs in enumerate(dataloader):
: 内层循环遍历数据加载器中的每个批次。
imgs
是一个批次的图像数据。
i
是当前批次的索引。
valid = torch.ones(imgs.size(0), 1)
: 创建一个由1组成的张量,表示真实图像的标签为真实(1)。
fake = torch.zeros(imgs.size(0), 1)
: 创建一个由0组成的张量,表示生成的图像的标签为假(0)。
real_imgs = imgs
: 将当前批次的真实图像赋值给real_imgs
。生成器训练部分:
optimizer_G.zero_grad()
: 清除生成器梯度。
z = torch.randn(imgs.size(0), latent_dim)
: 从标准正态分布中生成随机噪声作为生成器的输入。
gen_imgs = generator(z)
: 使用生成器生成一批图像。
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
: 计算生成器损失,即判别器对生成图像的判别结果与真实标签之间的差异。
g_loss.backward()
: 反向传播生成器的损失。
optimizer_G.step()
: 更新生成器的参数。判别器训练部分:
optimizer_D.zero_grad()
: 清除判别器梯度。
real_loss = adversarial_loss(discriminator(real_imgs), valid)
: 计算真实图像的判别器损失。
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
: 计算生成图像的判别器损失。
d_loss = (real_loss + fake_loss) / 2
: 计算判别器总体损失,即真实图像和生成图像损失的平均值。
d_loss.backward()
: 反向传播判别器的损失。
optimizer_D.step()
: 更新判别器的参数。打印训练信息:
batches_done = epoch * len(dataloader) + i
: 计算已完成的批次数量。
if batches_done % 100 == 0:
: 每训练完成100个批次时打印一次信息。
- 打印当前训练周期、当前批次、总批次数以及判别器和生成器的损失值。
循环结束后,保存生成器和判别器的参数到文件中,以便之后的预测或继续训练使用。
#测试网络
import torch
from torchvision.utils import save_image
# 加载保存的生成器模型
generator = Generator(latent_dim, img_shape)
generator.load_state_dict(torch.load("generator.pth"))
generator.eval() # 设置为评估模式,关闭 dropout 和 batch normalization
# 生成随机噪声向量
num_samples = 10 # 生成图像的数量
z = torch.randn(num_samples, latent_dim)
# 通过生成器生成图像
with torch.no_grad():
generated_images = generator(z)
# 将生成的图像保存到文件中
save_image(generated_images, "generated_images.png", nrow=5, normalize=True)
# 可选:显示生成的图像
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i in range(num_samples):
axes[i // 5, i % 5].imshow(generated_images[i].permute(1, 2, 0).cpu().numpy())
axes[i // 5, i % 5].axis("off")
plt.show()
这段代码是用来测试生成器网络的效果,并将生成的图像保存到文件中。我来解释一下:
首先,代码加载了保存的生成器模型:
generator = Generator(latent_dim, img_shape) generator.load_state_dict(torch.load("generator.pth")) generator.eval() # 设置为评估模式,关闭 dropout 和 batch normalization
这里假设了存在一个名为 "generator.pth" 的文件,其中保存了已经训练好的生成器模型的参数。加载后,通过
generator.eval()
将模型设置为评估模式,这会关闭 dropout 和 batch normalization,以确保生成的图像稳定。生成随机噪声向量:
num_samples = 10 # 生成图像的数量 z = torch.randn(num_samples, latent_dim)
这里创建了一个大小为 (10, latent_dim) 的随机噪声向量,用于生成图像。
latent_dim
是生成器网络的输入向量的维度。通过生成器生成图像:
with torch.no_grad(): generated_images = generator(z)
使用
torch.no_grad()
上下文管理器来关闭梯度计算,以节省内存并加快计算速度。然后,通过将随机噪声向量z
传递给生成器,生成一批图像。将生成的图像保存到文件中:
save_image(generated_images, "generated_images.png", nrow=5, normalize=True)
save_image()
函数将生成的图像保存到名为 "generated_images.png" 的文件中,nrow=5
指定每行显示 5 张图像,normalize=True
将图像像素值标准化到 0 到 1 之间。可选:显示生成的图像:
import matplotlib.pyplot as plt fig, axes = plt.subplots(2, 5, figsize=(15, 6)) for i in range(num_samples): axes[i // 5, i % 5].imshow(generated_images[i].permute(1, 2, 0).cpu().numpy()) axes[i // 5, i % 5].axis("off") plt.show()
这段代码用 Matplotlib 在图形界面中显示生成的图像。将生成的图像排列成 2 行 5 列的网格,并在每个子图中显示一张图像。