Bootstrap

LSGAN代码-Pytorch-训练自己的数据集

论文地址

LSGAN的论文地址:https://arxiv.org/abs/1611.04076v3

全部代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 简易生成器架构
class Generator(nn.Module):
    def __init__(self, latentdim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latentdim, 256 * 7 * 7),  # 256 * 7 * 7 输入,确保有足够的特征来生成128x128的图像
            nn.ReLU(),
            nn.Unflatten(1, (256, 7, 7)),  # 将 [batchsize, 256 * 7 * 7] 变为 [batch_size, 256, 7, 7]

            # 第一层反卷积
            nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            # 第二层反卷积
;