import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image


# Generator definition
class Generator(nn.Module):
    """DCGAN-style generator mapping a noise vector to a 3x128x128 image.

    A linear layer projects the noise to a 1024x4x4 feature map, which five
    stride-2 transposed convolutions upsample to 128x128. The final ``Tanh``
    keeps pixel values in [-1, 1], matching the ``Normalize(0.5, 0.5)`` applied
    to the training images.

    Args:
        noise_dim: Length of the input noise vector.
    """

    def __init__(self, noise_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024 * 4 * 4),
            nn.BatchNorm1d(1024 * 4 * 4),
            nn.ReLU(True),
        )
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 4x4 -> 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 8x8 -> 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16x16 -> 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 32x32 -> 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),  # 64x64 -> 128x128
            nn.Tanh(),
        )

    def forward(self, noise):
        """Map ``noise`` of shape (N, noise_dim) to images of shape (N, 3, 128, 128)."""
        x = self.fc(noise).view(-1, 1024, 4, 4)
        return self.deconv_layers(x)


# Discriminator definition
class Discriminator(nn.Module):
    """Convolutional critic producing one unbounded score per image.

    Four stride-2 convolutions reduce a 128x128 input to 8x8, and a final 8x8
    convolution collapses it to 1x1. There is no sigmoid: the raw scores are
    meant for the hinge losses below. Inputs of another size would not reduce
    to 1x1, so 128x128 is required.

    Args:
        input_channels: Number of channels in the input images.
    """

    def __init__(self, input_channels=3):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),  # 128x128 -> 64x64
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),  # 64x64 -> 32x32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),  # 32x32 -> 16x16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),  # 16x16 -> 8x8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 8, 1, 0),  # 8x8 -> 1x1
        )

    def forward(self, x):
        """Return scores of shape (N, 1) for images ``x`` of shape (N, C, 128, 128)."""
        x = self.conv_layers(x)
        return x.view(-1, 1)


# Dataset definition
class TrafficSignDataset(Dataset):
    """Images listed in ``<root_dir>/labels.txt``, returned without labels.

    Each non-blank line of ``labels.txt`` must be ``<image name> <label>``.
    The label is read but discarded because the GAN trains unconditionally.
    Image names are resolved relative to ``root_dir``.

    Args:
        root_dir: Directory containing ``labels.txt`` and the images.
        transform: Optional callable applied to each RGB ``PIL.Image``.

    Raises:
        ValueError: If a non-blank line does not have at least two fields.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        labels_file_path = os.path.join(root_dir, 'labels.txt')
        with open(labels_file_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an extra empty line at the end
                    # of the file) instead of crashing on tuple unpacking.
                    continue
                # Split off only the trailing label so image names that
                # contain spaces are still accepted.
                parts = line.rsplit(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{labels_file_path}:{line_no}: expected '<image> <label>', "
                        f"got {line!r}"
                    )
                img_name = parts[0]
                self.image_paths.append(os.path.join(root_dir, img_name))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = self.image_paths[idx]
        # ``convert`` returns a new image, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


# Hyperparameters
noise_dim = 100  # generator input noise length
batch_size = 8  # batch size; must be >= 2 because Generator uses BatchNorm1d
lr = 2e-4  # learning rate
num_epochs = 500  # number of training epochs
output_dir = r"C:\Users\sun\Desktop\2024102201\out"  # where samples and checkpoints are saved
root_dir = r"C:\Users\sun\Desktop\2024102201\1"  # dataset root containing labels.txt
# Fall back to CPU so the script still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data preprocessing: resize to 128x128 and scale pixels from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


# Loss functions
def discriminator_hinge_loss(real_outputs, fake_outputs):
    """Hinge loss for D: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    real_loss = torch.mean(F.relu(1.0 - real_outputs))
    fake_loss = torch.mean(F.relu(1.0 + fake_outputs))
    return real_loss + fake_loss


def generator_hinge_loss(fake_outputs):
    """Hinge loss for G: -mean(D(fake))."""
    return -torch.mean(fake_outputs)


def _train_step(G, D, optimizer_G, optimizer_D, real_images):
    """Run one discriminator update followed by one generator update.

    Returns:
        Tuple ``(d_loss, g_loss)`` of detached scalar tensors. They are not
        converted with ``.item()`` here to avoid a device sync every batch.
    """
    noise = torch.randn(real_images.size(0), noise_dim, device=real_images.device)
    fake_images = G(noise)

    # Discriminator update: fakes are detached so no gradient reaches G.
    optimizer_D.zero_grad()
    d_loss = discriminator_hinge_loss(D(real_images), D(fake_images.detach()))
    d_loss.backward()
    optimizer_D.step()

    # Generator update: re-score the same fakes with the just-updated D.
    # D's parameters also receive gradients here; optimizer_D.zero_grad()
    # clears them before D's next step.
    optimizer_G.zero_grad()
    g_loss = generator_hinge_loss(D(fake_images))
    g_loss.backward()
    optimizer_G.step()

    return d_loss.detach(), g_loss.detach()


def _save_samples(G, fixed_noise, epoch):
    """Save an 8-column grid of ``G(fixed_noise)`` as ``epoch_<epoch>.png``."""
    with torch.no_grad():
        samples = G(fixed_noise).cpu()
    # Map the Tanh range [-1, 1] to [0, 1] with a fixed affine transform rather
    # than normalize=True. That option min-max rescales each grid, which makes
    # brightness and contrast incomparable from one epoch to the next.
    save_image((samples + 1) / 2, os.path.join(output_dir, f"epoch_{epoch}.png"), nrow=8)


def main():
    """Train the GAN on ``root_dir`` and write samples/checkpoints to ``output_dir``.

    Raises:
        ValueError: If ``batch_size`` < 2 or fewer than two images are listed,
            since BatchNorm1d in the generator cannot train on one sample.
    """
    if batch_size < 2:
        raise ValueError("batch_size must be >= 2 because Generator uses BatchNorm1d")
    os.makedirs(output_dir, exist_ok=True)

    # Load the dataset
    dataset = TrafficSignDataset(root_dir=root_dir, transform=transform)
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 images listed in {os.path.join(root_dir, 'labels.txt')}, "
            f"found {len(dataset)}"
        )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Models
    G = Generator(noise_dim=noise_dim).to(device)
    D = Discriminator(input_channels=3).to(device)

    # Optimizers and learning-rate schedulers (halve the LR every 50 epochs)
    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.0, 0.9))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.0, 0.9))
    scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=50, gamma=0.5)
    scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=50, gamma=0.5)

    # Fixed noise so sample grids are comparable across epochs
    fixed_noise = torch.randn(64, noise_dim, device=device)

    for epoch in range(1, num_epochs + 1):
        for real_images in dataloader:
            if real_images.size(0) < 2:
                # Occurs when len(dataset) % batch_size == 1: BatchNorm1d in the
                # generator raises on a one-sample batch in training mode.
                continue
            d_loss, g_loss = _train_step(
                G, D, optimizer_G, optimizer_D, real_images.to(device)
            )

        scheduler_G.step()
        scheduler_D.step()

        # Losses reported are those of the last trained batch of the epoch. The
        # checks above guarantee the first batch has >= 2 images, so they exist.
        print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item():.4f}, "
              f"G Loss: {g_loss.item():.4f}")

        _save_samples(G, fixed_noise, epoch)

        # Save a checkpoint every 50 epochs
        if epoch % 50 == 0:
            torch.save(G.state_dict(), os.path.join(output_dir, f'generator_epoch_{epoch}.pth'))
            torch.save(D.state_dict(),
                       os.path.join(output_dir, f'discriminator_epoch_{epoch}.pth'))


if __name__ == "__main__":
    main()