之前介绍完了图像生成网络GAN和VAE,终于来到了Diffusion。stable diffusion里比较复杂,同时用到了diffusion,VAE,CLIP等模型,这里我们主要着重介绍diffusion网络本身。
2.原理
Diffusion扩散模型从字面上来理解,就是对噪声进行扩散。它一共有两个扩散步骤:
正向扩散:根据预先设定的噪声进度在图像中添加高斯噪声,直到数据分布趋于先验分布
反向扩撒:去除图像中的噪声,本质上是学习逐步恢复原数据分布
正向扩散过程很好理解,每次step都在之前step的图像基础上加上随机的高斯噪声,这样经过多个step之后,图像将会变成完全的一个噪声图像。
反向扩散过程其实就是用UNet网络去预测逆向的高斯噪声,从而使图像去噪。在噪声微小的前提下,逆向的去噪过程也可以等同于预测高斯噪声。
根据马尔可夫链式推导法则,用表示第t个step的图像,表示该数据的概率分布,所以前向扩散方程可以表示为:
其中,表示第t个step的噪声系数。在第二个公式中,为高斯函数的输出,为高斯函数的输入,而为高斯函数的均值,为高斯函数的方差。换言之,第t个step的图像可以从第t-1个step的图像再加上一个均值为,方差为的高斯噪声得到。
为什么均值和方差要设置为和,这一切都是为了后面的参数重整化技巧。
因为在训练中,要对图像加不同step的噪声,由此让网络学习到不同噪声程度的数据。然而在实际训练中,当需要得到step较大的加噪图像时,我们不可能每次都从step=0开始重新加噪,这样时间成本太大。同时在上一篇文章VAE中我们可以知道,必须要将随机数限制在正态函数中,不然没法去反向推导梯度,因此采用了高斯函数被分解为特定均值和方差的正态函数这一方法。在扩散模型中,我们需要对多个高斯噪声进行叠加,有没有一种方法可以把叠加的高斯函数也分解为特定均值和方差的正态函数呢?
参数重整化:高斯函数可以被分解为特定均值和方差的正态函数,公式可以表达为:
因此,第t个step的图像可以表示为:
需要注意的是,当两个高斯分布相加时,满足如下规律:
因此第四行公式可以直接转换为第五行公式。
所以现在我们直接把各个step的算出来就可以了,不用再每个step进行迭代。
在反向噪声扩散中,由于每次加的噪声很小,所以也可以视为高斯分布,使用神经网络UNet进行拟合。这里推导公式比较复杂,可以参考原论文2006.11239.pdf (arxiv.org)
最后,通过KL散度来让正向分布和反向分布尽可能接近。训练和采样流程如下:
3.代码
接下来我们用pytorch来实现Diffusion在MNIST数据集上的生成。
3.1模型
Unet中上采样和下采样模块都基于resblock,同时还有对step进行embedding的全连接层。数据进行下采样之后,再使上采样输出与step embeding向量进行相加,再输入进下一层上采样层中。
class ResidualConvBlock(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, is_res: bool = False
) -> None:
super().__init__()
'''
standard ResNet style convolutional block
'''
self.same_channels = in_channels==out_channels
self.is_res = is_res
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.GELU(),
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.GELU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.is_res:
x1 = self.conv1(x)
x2 = self.conv2(x1)
# this adds on correct residual in case channels have increased
if self.same_channels:
out = x + x2
else:
out = x1 + x2
return out / 1.414
else:
x1 = self.conv1(x)
x2 = self.conv2(x1)
return x2
class UnetDown(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetDown, self).__init__()
'''
process and downscale the image feature maps
'''
layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class UnetUp(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetUp, self).__init__()
'''
process and upscale the image feature maps
'''
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
ResidualConvBlock(out_channels, out_channels),
ResidualConvBlock(out_channels, out_channels),
]
self.model = nn.Sequential(*layers)
def forward(self, x, skip):
x = torch.cat((x, skip), 1)
x = self.model(x)
return x
class EmbedFC(nn.Module):
def __init__(self, input_dim, emb_dim):
super(EmbedFC, self).__init__()
'''
generic one layer FC NN for embedding things
'''
self.input_dim = input_dim
layers = [
nn.Linear(input_dim, emb_dim),
nn.GELU(),
nn.Linear(emb_dim, emb_dim),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
x = x.view(-1, self.input_dim)
return self.model(x)
class Unet(nn.Module):
def __init__(self, in_channels, n_feat=256):
super(Unet, self).__init__()
self.in_channels = in_channels
self.n_feat = n_feat
self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
self.down1 = UnetDown(n_feat, n_feat)
self.down2 = UnetDown(n_feat, 2 * n_feat)
self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
self.timeembed1 = EmbedFC(1, 2 * n_feat)
self.timeembed2 = EmbedFC(1, 1 * n_feat)
self.up0 = nn.Sequential(
# nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
nn.GroupNorm(8, 2 * n_feat),
nn.ReLU(),
)
self.up1 = UnetUp(4 * n_feat, n_feat)
self.up2 = UnetUp(2 * n_feat, n_feat)
self.out = nn.Sequential(
nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
nn.GroupNorm(8, n_feat),
nn.ReLU(),
nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
)
def forward(self, x, t):
'''
输入加噪图像和对应的时间step,预测反向噪声的正态分布
:param x: 加噪图像
:param t: 对应step
:return: 正态分布噪声
'''
x = self.init_conv(x)
down1 = self.down1(x)
down2 = self.down2(down1)
hiddenvec = self.to_vec(down2)
# embed time step
temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
# 将上采样输出与step编码相加,输入到下一个上采样层
up1 = self.up0(hiddenvec)
up2 = self.up1(up1 + temb1, down2)
up3 = self.up2(up2 + temb2, down1)
out = self.out(torch.cat((up3, x), 1))
return out
3.2训练
训练时随机选择step和随机生成正态分布噪声,通过叠加后得到加噪图像,然后将加噪图像和step一起输入进Unet中,得到当前step的预测正态分布噪声,并与真实正态分布噪声计算loss。
def forward(self, x):
"""
训练过程中, 随机选择step和生成噪声
"""
# 随机选择step
_ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device) # t ~ Uniform(0, n_T)
# 随机生成正态分布噪声
noise = torch.randn_like(x) # eps ~ N(0, 1)
# 加噪后的图像x_t
x_t = (
self.sqrtab[_ts, None, None, None] * x
+ self.sqrtmab[_ts, None, None, None] * noise
)
# 将unet预测的对应step的正态分布噪声与真实噪声做对比
return self.loss_mse(noise, self.model(x_t, _ts / self.n_T))
3.3推理&可视化
推理的时候从随机的初始噪声开始,预测当前噪声的上一个step的正态分布噪声,然后根据采样公式得到反向扩散的均值和方差,最后根据重整化公式计算出上一个step的图像。重复多个step后得到最终的去噪图像。
def sample(self, n_sample, size, device):
# 随机生成初始噪声图片 x_T ~ N(0, 1)
x_i = torch.randn(n_sample, *size).to(device)
for i in range(self.n_T, 0, -1):
t_is = torch.tensor([i / self.n_T]).to(device)
t_is = t_is.repeat(n_sample, 1, 1, 1)
z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
eps = self.model(x_i, t_is)
x_i = x_i[:n_sample]
x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
return x_i
@torch.no_grad()
def visualize_results(self, epoch):
self.sampler.eval()
# 保存结果路径
output_path = 'results/Diffusion'
if not os.path.exists(output_path):
os.makedirs(output_path)
tot_num_samples = self.sample_num
image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
out = self.sampler.sample(tot_num_samples, (1, 28, 28), self.device)
save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)
完整代码如下:
import torch, time, os
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F
class ResidualConvBlock(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, is_res: bool = False
) -> None:
super().__init__()
'''
standard ResNet style convolutional block
'''
self.same_channels = in_channels==out_channels
self.is_res = is_res
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.GELU(),
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.GELU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.is_res:
x1 = self.conv1(x)
x2 = self.conv2(x1)
# this adds on correct residual in case channels have increased
if self.same_channels:
out = x + x2
else:
out = x1 + x2
return out / 1.414
else:
x1 = self.conv1(x)
x2 = self.conv2(x1)
return x2
class UnetDown(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetDown, self).__init__()
'''
process and downscale the image feature maps
'''
layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class UnetUp(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetUp, self).__init__()
'''
process and upscale the image feature maps
'''
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
ResidualConvBlock(out_channels, out_channels),
ResidualConvBlock(out_channels, out_channels),
]
self.model = nn.Sequential(*layers)
def forward(self, x, skip):
x = torch.cat((x, skip), 1)
x = self.model(x)
return x
class EmbedFC(nn.Module):
def __init__(self, input_dim, emb_dim):
super(EmbedFC, self).__init__()
'''
generic one layer FC NN for embedding things
'''
self.input_dim = input_dim
layers = [
nn.Linear(input_dim, emb_dim),
nn.GELU(),
nn.Linear(emb_dim, emb_dim),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
x = x.view(-1, self.input_dim)
return self.model(x)
class Unet(nn.Module):
def __init__(self, in_channels, n_feat=256):
super(Unet, self).__init__()
self.in_channels = in_channels
self.n_feat = n_feat
self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
self.down1 = UnetDown(n_feat, n_feat)
self.down2 = UnetDown(n_feat, 2 * n_feat)
self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
self.timeembed1 = EmbedFC(1, 2 * n_feat)
self.timeembed2 = EmbedFC(1, 1 * n_feat)
self.up0 = nn.Sequential(
# nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
nn.GroupNorm(8, 2 * n_feat),
nn.ReLU(),
)
self.up1 = UnetUp(4 * n_feat, n_feat)
self.up2 = UnetUp(2 * n_feat, n_feat)
self.out = nn.Sequential(
nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
nn.GroupNorm(8, n_feat),
nn.ReLU(),
nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
)
def forward(self, x, t):
'''
输入加噪图像和对应的时间step,预测反向噪声的正态分布
:param x: 加噪图像
:param t: 对应step
:return: 正态分布噪声
'''
x = self.init_conv(x)
down1 = self.down1(x)
down2 = self.down2(down1)
hiddenvec = self.to_vec(down2)
# embed time step
temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
# 将上采样输出与step编码相加,输入到下一个上采样层
up1 = self.up0(hiddenvec)
up2 = self.up1(up1 + temb1, down2)
up3 = self.up2(up2 + temb2, down1)
out = self.out(torch.cat((up3, x), 1))
return out
class DDPM(nn.Module):
def __init__(self, model, betas, n_T, device):
super(DDPM, self).__init__()
self.model = model.to(device)
# register_buffer 可以提前保存alpha相关,节约时间
for k, v in self.ddpm_schedules(betas[0], betas[1], n_T).items():
self.register_buffer(k, v)
self.n_T = n_T
self.device = device
self.loss_mse = nn.MSELoss()
def ddpm_schedules(self, beta1, beta2, T):
'''
提前计算各个step的alpha,这里beta是线性变化
:param beta1: beta的下限
:param beta2: beta的下限
:param T: 总共的step数
'''
assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1 # 生成beta1-beta2均匀分布的数组
sqrt_beta_t = torch.sqrt(beta_t)
alpha_t = 1 - beta_t
log_alpha_t = torch.log(alpha_t)
alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() # alpha累乘
sqrtab = torch.sqrt(alphabar_t) # 根号alpha累乘
oneover_sqrta = 1 / torch.sqrt(alpha_t) # 1 / 根号alpha
sqrtmab = torch.sqrt(1 - alphabar_t) # 根号下(1-alpha累乘)
mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
return {
"alpha_t": alpha_t, # \alpha_t
"oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t}
"sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t}
"alphabar_t": alphabar_t, # \bar{\alpha_t}
"sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}} # 加噪标准差
"sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}} # 加噪均值
"mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
}
def forward(self, x):
"""
训练过程中, 随机选择step和生成噪声
"""
# 随机选择step
_ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device) # t ~ Uniform(0, n_T)
# 随机生成正态分布噪声
noise = torch.randn_like(x) # eps ~ N(0, 1)
# 加噪后的图像x_t
x_t = (
self.sqrtab[_ts, None, None, None] * x
+ self.sqrtmab[_ts, None, None, None] * noise
)
# 将unet预测的对应step的正态分布噪声与真实噪声做对比
return self.loss_mse(noise, self.model(x_t, _ts / self.n_T))
def sample(self, n_sample, size, device):
# 随机生成初始噪声图片 x_T ~ N(0, 1)
x_i = torch.randn(n_sample, *size).to(device)
for i in range(self.n_T, 0, -1):
t_is = torch.tensor([i / self.n_T]).to(device)
t_is = t_is.repeat(n_sample, 1, 1, 1)
z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
eps = self.model(x_i, t_is)
x_i = x_i[:n_sample]
x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
return x_i
class ImageGenerator(object):
def __init__(self):
'''
初始化,定义超参数、数据集、网络结构等
'''
self.epoch = 20
self.sample_num = 100
self.batch_size = 256
self.lr = 0.0001
self.n_T = 400
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.init_dataloader()
self.sampler = DDPM(model=Unet(in_channels=1), betas=(1e-4, 0.02), n_T=self.n_T, device=self.device).to(self.device)
self.optimizer = optim.Adam(self.sampler.model.parameters(), lr=self.lr)
def init_dataloader(self):
'''
初始化数据集和dataloader
'''
tf = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = MNIST('./data/',
train=True,
download=True,
transform=tf)
self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
val_dataset = MNIST('./data/',
train=False,
download=True,
transform=tf)
self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
def train(self):
self.sampler.train()
print('训练开始!!')
for epoch in range(self.epoch):
self.sampler.model.train()
loss_mean = 0
for i, (images, labels) in enumerate(self.train_dataloader):
images, labels = images.to(self.device), labels.to(self.device)
# 将latent和condition拼接后输入网络
loss = self.sampler(images)
loss_mean += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
train_loss = loss_mean / len(self.train_dataloader)
print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
self.visualize_results(epoch)
@torch.no_grad()
def visualize_results(self, epoch):
self.sampler.eval()
# 保存结果路径
output_path = 'results/Diffusion'
if not os.path.exists(output_path):
os.makedirs(output_path)
tot_num_samples = self.sample_num
image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
out = self.sampler.sample(tot_num_samples, (1, 28, 28), self.device)
save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)
if __name__ == '__main__':
generator = ImageGenerator()
generator.train()
4. condition代码及结果
如果我们要生成condition条件下的图像,我们需要对condition向量进行embedding后再拼接到unet输入中。
import torch, time, os
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F
class ResidualConvBlock(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, is_res: bool = False
) -> None:
super().__init__()
'''
standard ResNet style convolutional block
'''
self.same_channels = in_channels==out_channels
self.is_res = is_res
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.GELU(),
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.GELU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.is_res:
x1 = self.conv1(x)
x2 = self.conv2(x1)
# this adds on correct residual in case channels have increased
if self.same_channels:
out = x + x2
else:
out = x1 + x2
return out / 1.414
else:
x1 = self.conv1(x)
x2 = self.conv2(x1)
return x2
class UnetDown(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetDown, self).__init__()
'''
process and downscale the image feature maps
'''
layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class UnetUp(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetUp, self).__init__()
'''
process and upscale the image feature maps
'''
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
ResidualConvBlock(out_channels, out_channels),
ResidualConvBlock(out_channels, out_channels),
]
self.model = nn.Sequential(*layers)
def forward(self, x, skip):
x = torch.cat((x, skip), 1)
x = self.model(x)
return x
class EmbedFC(nn.Module):
def __init__(self, input_dim, emb_dim):
super(EmbedFC, self).__init__()
'''
generic one layer FC NN for embedding things
'''
self.input_dim = input_dim
layers = [
nn.Linear(input_dim, emb_dim),
nn.GELU(),
nn.Linear(emb_dim, emb_dim),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
x = x.view(-1, self.input_dim)
return self.model(x)
class Unet(nn.Module):
def __init__(self, in_channels, n_feat=256, n_classes=10):
super(Unet, self).__init__()
self.in_channels = in_channels
self.n_feat = n_feat
self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
self.down1 = UnetDown(n_feat, n_feat)
self.down2 = UnetDown(n_feat, 2 * n_feat)
self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
self.timeembed1 = EmbedFC(1, 2 * n_feat)
self.timeembed2 = EmbedFC(1, 1 * n_feat)
self.conditionembed1 = EmbedFC(n_classes, 2 * n_feat)
self.conditionembed2 = EmbedFC(n_classes, 1 * n_feat)
self.up0 = nn.Sequential(
# nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
nn.GroupNorm(8, 2 * n_feat),
nn.ReLU(),
)
self.up1 = UnetUp(4 * n_feat, n_feat)
self.up2 = UnetUp(2 * n_feat, n_feat)
self.out = nn.Sequential(
nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
nn.GroupNorm(8, n_feat),
nn.ReLU(),
nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
)
def forward(self, x, c, t):
'''
输入加噪图像和对应的时间step,预测反向噪声的正态分布
:param x: 加噪图像
:param c: contition向量
:param t: 对应step
:return: 正态分布噪声
'''
x = self.init_conv(x)
down1 = self.down1(x)
down2 = self.down2(down1)
hiddenvec = self.to_vec(down2)
# embed time step
temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
cemb1 = self.conditionembed1(c).view(-1, self.n_feat * 2, 1, 1)
cemb2 = self.conditionembed2(c).view(-1, self.n_feat, 1, 1)
# 将上采样输出与step编码相加,输入到下一个上采样层
up1 = self.up0(hiddenvec)
up2 = self.up1(cemb1 * up1 + temb1, down2)
up3 = self.up2(cemb2 * up2 + temb2, down1)
out = self.out(torch.cat((up3, x), 1))
return out
class DDPM(nn.Module):
def __init__(self, model, betas, n_T, device):
super(DDPM, self).__init__()
self.model = model.to(device)
# register_buffer 可以提前保存alpha相关,节约时间
for k, v in self.ddpm_schedules(betas[0], betas[1], n_T).items():
self.register_buffer(k, v)
self.n_T = n_T
self.device = device
self.loss_mse = nn.MSELoss()
def ddpm_schedules(self, beta1, beta2, T):
'''
提前计算各个step的alpha,这里beta是线性变化
:param beta1: beta的下限
:param beta2: beta的下限
:param T: 总共的step数
'''
assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1 # 生成beta1-beta2均匀分布的数组
sqrt_beta_t = torch.sqrt(beta_t)
alpha_t = 1 - beta_t
log_alpha_t = torch.log(alpha_t)
alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp() # alpha累乘
sqrtab = torch.sqrt(alphabar_t) # 根号alpha累乘
oneover_sqrta = 1 / torch.sqrt(alpha_t) # 1 / 根号alpha
sqrtmab = torch.sqrt(1 - alphabar_t) # 根号下(1-alpha累乘)
mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
return {
"alpha_t": alpha_t, # \alpha_t
"oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t}
"sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t}
"alphabar_t": alphabar_t, # \bar{\alpha_t}
"sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}} # 加噪标准差
"sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}} # 加噪均值
"mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
}
def forward(self, x, c):
"""
训练过程中, 随机选择step和生成噪声
"""
# 随机选择step
_ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device) # t ~ Uniform(0, n_T)
# 随机生成正态分布噪声
noise = torch.randn_like(x) # eps ~ N(0, 1)
# 加噪后的图像x_t
x_t = (
self.sqrtab[_ts, None, None, None] * x
+ self.sqrtmab[_ts, None, None, None] * noise
)
# 将unet预测的对应step的正态分布噪声与真实噪声做对比
return self.loss_mse(noise, self.model(x_t, c, _ts / self.n_T))
def sample(self, n_sample, c, size, device):
# 随机生成初始噪声图片 x_T ~ N(0, 1)
x_i = torch.randn(n_sample, *size).to(device)
for i in range(self.n_T, 0, -1):
t_is = torch.tensor([i / self.n_T]).to(device)
t_is = t_is.repeat(n_sample, 1, 1, 1)
z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
eps = self.model(x_i, c, t_is)
x_i = x_i[:n_sample]
x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
return x_i
class ImageGenerator(object):
def __init__(self):
'''
初始化,定义超参数、数据集、网络结构等
'''
self.epoch = 20
self.sample_num = 100
self.batch_size = 256
self.lr = 0.0001
self.n_T = 400
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.init_dataloader()
self.sampler = DDPM(model=Unet(in_channels=1), betas=(1e-4, 0.02), n_T=self.n_T, device=self.device).to(self.device)
self.optimizer = optim.Adam(self.sampler.model.parameters(), lr=self.lr)
def init_dataloader(self):
'''
初始化数据集和dataloader
'''
tf = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = MNIST('./data/',
train=True,
download=True,
transform=tf)
self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
val_dataset = MNIST('./data/',
train=False,
download=True,
transform=tf)
self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
def train(self):
self.sampler.train()
print('训练开始!!')
for epoch in range(self.epoch):
self.sampler.model.train()
loss_mean = 0
for i, (images, labels) in enumerate(self.train_dataloader):
images, labels = images.to(self.device), labels.to(self.device)
labels = F.one_hot(labels, num_classes=10).float()
# 将latent和condition拼接后输入网络
loss = self.sampler(images, labels)
loss_mean += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
train_loss = loss_mean / len(self.train_dataloader)
print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
self.visualize_results(epoch)
@torch.no_grad()
def visualize_results(self, epoch):
self.sampler.eval()
# 保存结果路径
output_path = 'results/Diffusion'
if not os.path.exists(output_path):
os.makedirs(output_path)
tot_num_samples = self.sample_num
image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
labels = F.one_hot(torch.Tensor(np.repeat(np.arange(10), 10)).to(torch.int64), num_classes=10).to(self.device).float()
out = self.sampler.sample(tot_num_samples, labels, (1, 28, 28), self.device)
save_image(out, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)
if __name__ == '__main__':
generator = ImageGenerator()
generator.train()
但是如果我们只用condition条件,网络可能过拟合后就无法生成非condition条件下的图像。为了同时满足condition和非condition生成,可以采用classifier free guide的方法,即将condition和非condition同时输入进网络同时训练。代码后续有机会补上~