Bootstrap

DDPM代码解读

了解了DDPM优雅的理论推导后,开始手撕代码,这样可以更深入理解DDPM背后的思想。
Code:链接

导入必要的包

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F
from PIL import Image

定义辅助函数和类,用以构建Unet

def exists(x):#
    return x is not None #检查x是不是None,如果x不是None,返回True,反之,返回False

def default(val, d):
    if exists(val):
        return val #如果val存在(即不是None),函数就返回val
    return d() if callable(d) else d #如果val不存在(即val是None),函数将尝试返回d。如果d是一个可调用的对象(例如函数或实现了__call__方法的类实例),则调用d()并返回其结果。反之,直接返回d


def num_to_groups(num, divisor): #将一个数 num 分成若干组,每组包含 divisor 个单位,然后返回一个数组,其中包含这些组的值
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder) #如果 remainder 大于0,表示还有剩余的部分,函数会在 arr 的末尾追加一个元素,其值为 remainder。
    return arr

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out = None): #使用最近邻插值方法将输入特征图的尺寸放大。
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'), #scale_factor=2 表示将每个维度的尺寸放大两倍。
        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
    )

def Downsample(dim, dim_out = None): #下采样,将hw缩小压入深度即通道维度
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )

解读sinusoidal positional embeds

class SinusoidalPositionEmbeddings(nn.Module): #由于神经网络的参数在时间(噪声级别)上是共享的,作者采用正弦位置嵌入来编码t
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :] #根据pytorch中的广播机制生成时间嵌入矩阵
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

正弦位置编码公式:
P E t , i ( 1 ) = sin ⁡ ( t 1000 0 i d − 1 ) P E t , i ( 2 ) = cos ⁡ ( t 1000 0 i d − 1 ) \begin{aligned} & P E_{t, i}^{(1)}=\sin \left(\frac{t}{10000^{\frac{i}{d-1}}}\right) \\ & P E_{t, i}^{(2)}=\cos \left(\frac{t}{10000^\frac{i}{d-1}}\right) \end{aligned} PEt,i(1)=sin(10000d1it)PEt,i(2)=cos(10000d1it)
其中 d d dhalf_dim
以下是上面代码块的公式表达:
emb = ln ⁡ ( 10000 ) d − 1 = ln ⁡ ( 1000 0 1 d − 1 ) \text{emb} = \frac{\ln(\text{10000})}{d - 1}=\ln(10000^{\frac{1}{d-1}}) emb=d1ln(10000)=ln(10000d11)
emb = emb ∗ i = i ∗ ln ⁡ ( 1000 0 1 d − 1 ) = ln ⁡ ( 1000 0 i d − 1 ) \text{emb}=\text{emb}*i =i*\ln(10000^{\frac{1}{d-1}})=\ln(10000^{\frac{i}{d-1}}) emb=embi=iln(10000d11)=ln(10000d1i)
− emb = − ln ⁡ ( 1000 0 i d − 1 ) = ln ⁡ ( 1 1000 0 i d − 1 ) -\text{emb} = -\ln(10000^{\frac{i}{d-1}})=\ln(\frac{1}{10000^{\frac{i}{d-1}}}) emb=ln(10000d1i)=ln(10000d1i1)
e − emb = 1 1000 0 i d − 1 e^{-\text{emb}} =\frac{1}{10000^{\frac{i}{d-1}}} eemb=10000d1i1
emb = e − emb ∗ t = t 1000 0 i d − 1 \text{emb}=e^{-\text{emb}}*t =\frac{t}{10000^{\frac{i}{d-1}}} emb=eembt=10000d1it
sin ⁡ ( emb ) \sin\left(\text{emb}\right) sin(emb)
cos ⁡ ( emb ) \cos\left(\text{emb}\right) cos(emb)

正弦位置嵌入(SinusoidalPositionEmbeddings)模块将形状张量(batch_size, 1)作为输入(即一批图像中若干噪声图像的噪声水平),并将其转化为形状张量(batch_size, dim),其中 dim 是位置嵌入的维度。

定义Unet的核心构建块

class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) / (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift): #如果提供了 scale_shift,即尺度和偏移值,则对规范化后的输出进行尺度和偏移变换。
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) #将时间嵌入映射到 dim_out * 2 维度
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb): 
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1) #使用 chunk 将时间嵌入拆分为两个 (batch_size, dim_out, 1, 1) 的向量,分别用于缩放和平移。

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)

注意力模块

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5  # 缩放因子,用于缩放查询向量
        self.heads = heads  # 注意力头的数量
        hidden_dim = dim_head * heads  # 隐藏维度
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)  # 生成查询、键和值的卷积层
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)  # 输出卷积层

    def forward(self, x):
        b, c, h, w = x.shape  # 获取输入的形状
        qkv = self.to_qkv(x).chunk(3, dim=1)  # 将输入通过卷积层并拆分成查询、键和值
        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)  # 重排张量维度以适应多头注意力
        q = q * self.scale  # 缩放查询向量

        sim = einsum("b h d i, b h d j -> b h i j", q, k)  # 计算查询和键的点积相似度
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()  # 减去最大值以稳定数值
        attn = sim.softmax(dim=-1)  # 对相似度进行softmax归一化

        out = einsum("b h i j, b h d j -> b h i d", attn, v)  # 计算加权和值
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)  # 重排输出张量维度
        return self.to_out(out)  # 通过输出卷积层并返回

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5  # 缩放因子,用于缩放查询向量
        self.heads = heads  # 注意力头的数量
        hidden_dim = dim_head * heads  # 隐藏维度
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)  # 生成查询、键和值的卷积层

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))  # 输出卷积层和组归一化层

    def forward(self, x):
        b, c, h, w = x.shape  # 获取输入的形状
        qkv = self.to_qkv(x).chunk(3, dim=1)  # 将输入通过卷积层并拆分成查询、键和值
        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)  # 重排张量维度以适应多头注意力

        q = q.softmax(dim=-2)  # 对查询向量进行softmax归一化
        k = k.softmax(dim=-1)  # 对键向量进行softmax归一化

        q = q * self.scale  # 缩放查询向量
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)  # 计算键和值的加权和

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)  # 计算加权和值
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)  # 重排输出张量维度
        return self.to_out(out)  # 通过输出卷积层和组归一化层并返回

组归一化,定义一个类,将用于在注意层之前应用组归一化

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn  # 需要应用的函数
        self.norm = nn.GroupNorm(1, dim)  # 定义GroupNorm层,对输入进行规范化

    def forward(self, x):
        x = self.norm(x)  # 对输入进行规范化
        return self.fn(x)  # 对规范化后的输入应用指定的函数

Unet架构

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # 定义输入的通道数和是否使用自条件
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        # 初始化维度
        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # 时间嵌入
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # 定义网络层
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

神经网络Unet的工作是接收一批噪声图像和各自的噪声级别,并输出添加到输入的噪声
更直白的说,网络以一批形状为(batch_size, num_channels, height, width)的噪声图像和一批形状为(batch_size, 1)的>噪声级别作为输入,并返回形状为(batch_size, num_channels, height, width)的张量。

定义正向过程
正向扩散过程逐渐向真实图像添加噪声,从真实分布开始,在若干时间步T内进行。这个过程是根据方差计划进行的。原始的DDPM作者采用了线性计划:

“We set the forward process variances to constants increasing linearly from β 1 = 1 0 − 4 \beta_1=10^{-4} β1=104 to β T = 0.02 \beta_T=0.02 βT=0.02.”

# 函数的主要目的是从输入张量 a 中根据索引张量 t 提取适当的 t 索引对应的值,并将其调整为适合后续处理的形状
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t) ## 从张量 a 中按索引 t 提取值
    return out.reshape(b, *((1,) * (len(x_shape) - 1))) ## 将提取的值重塑为指定的形状

def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper
    """
    scale = 1000 / timesteps #加入 scale 确保 beta 的范围在不同的时间步数下保持一致
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

定义时间步 T T T

timesteps = 1000

Create β 1 , … , β T \beta_1, \ldots, \beta_T β1,,βT linearly increasing variance schedule

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

α t = 1 − β t \alpha_t=1-\beta_t αt=1βt

alphas = 1. - betas

α ˉ t = ∏ s = 1 t α s \bar{\alpha}_t=\prod_{s=1}^t \alpha_s αˉt=s=1tαs

alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)#在计算扩散过程的逆过程时,需要访问前一步的累积乘积值。通过这种方式,可以方便地在同一个序列中访问当前步长和前一步长的累积乘积值。在扩散模型的第一步(时间步 t=0)时,没有前一步的累积乘积值,因此需要一个初始值 1.0。这个 1.0 的含义是,在 t=0 时刻,原始数据保持其自身(因为没有任何扩散),所以乘积是 1。

1 α t \frac{1}{\sqrt{\alpha_t}} αt 1

sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

q ( x t ∣ x 0 ) = N ( x t ; α t ‾ x 0 , ( 1 − α t ‾ ) I ) q\left(x_t \mid x_0\right)=\mathcal{N}\left(x_t ; \sqrt{\overline{\alpha_t}} x_0,\left(1-\overline{\alpha_t}\right) \mathbf{I}\right) q(xtx0)=N(xt;αt x0,(1αt)I)

sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

β ~ t : = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\beta}_t:=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t β~t:=1αˉt1αˉt1βt

posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

前向过程代码:

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

定义损失函数:

L simple  ( θ ) : = E t , x 0 , ϵ [ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] L_{\text {simple }}(\theta):=\mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)\right\|^2\right] Lsimple (θ):=Et,x0,ϵ[ ϵϵθ(αˉt x0+1αˉt ϵ,t) 2]

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

定义数据集:

from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128

from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

# define image transformations (e.g. using torchvision)
transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

# define function
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

定义反向过程:

μ θ ( x t , t ) = μ ~ t ( x t , 1 α ˉ t ( x t − 1 − α ˉ t ϵ θ ( x t ) ) ) = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)=\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t\right)\right)\right)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right) μθ(xt,t)=μ~t(xt,αˉt 1(xt1αˉt ϵθ(xt)))=αt 1(xt1αˉt βtϵθ(xt,t))
σ t = β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \sigma_t=\sqrt{\tilde{\beta}_t}=\sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t} σt=β~t =1αˉt1αˉt1βt

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):#时间步迭代采样
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

定义训练的一些参数

from pathlib import Path

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

训练的脚本:

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

from torchvision.utils import save_image

epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: sample t uniformally for every example in the batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # save generated images
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)


采样示例:

# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show a random one
random_index = 63
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

在这里插入图片描述

;