Diffusion Transformers (DiTs) - 用Transformer革新Diffusion模型
Diffusion模型近年来在生成式图像任务中大放异彩,尤其是在图像生成质量上超越了GAN等传统方法。通常,Diffusion模型依赖U-Net作为核心架构,这种基于卷积的网络在图像生成任务中表现出色。然而,William Peebles和Saining Xie在他们的论文《Scalable Diffusion Models with Transformers》中提出了一种全新的思路:将U-Net替换为Transformer架构,推出了Diffusion Transformers (DiTs)。这篇博客将为熟悉Diffusion模型和U-Net的读者详细解读DiTs的核心贡献、组成模块、数据流方向,以及它解决的问题和生成图像的过程。
下文中图片来自于原论文:https://arxiv.org/pdf/2212.09748
核心贡献:从U-Net到Transformer的跨越
传统Diffusion模型依赖U-Net处理图像的去噪过程,U-Net通过其编码-解码结构和跳跃连接,很好地捕捉了空间信息。然而,U-Net的卷积本质限制了它的扩展性,尤其是在模型复杂度增加时,计算成本和性能提升的回报逐渐递减。DiTs的提出挑战了这一现状,核心贡献包括:
- Transformer替代U-Net:将Diffusion模型的主干网络从卷积U-Net替换为基于Vision Transformer (ViT) 的架构,利用Transformer的注意力机制和良好的扩展性。
- 可扩展性验证:通过Gflops(每秒十亿次浮点运算)衡量模型复杂度,证明DiTs在增加深度、宽度或输入token数量时,生成质量(用FID衡量)持续提升,且计算效率优于U-Net。
- 性能突破:DiT-XL/2模型在ImageNet类条件生成任务中达到了2.27的FID(256×256分辨率),超越了所有基于U-Net的Diffusion模型,成为新的SOTA。
这意味着Diffusion模型不再局限于卷积架构,可以借鉴Transformer在语言和视觉领域的成功经验,迈向更统一、更强大的生成框架。
组成模块:DiTs的架构设计
DiTs的设计遵循Vision Transformer的基本原则,但针对Diffusion模型的需求进行了调整。以下是其关键组成模块:
-
Patchify(分块):
- 作用:将输入的潜在表示(latent representation)分割成patch序列。
- 细节:输入是一个空间表示(如VAE编码后的32×32×4张量),通过“patchify”层将其转化为一系列token。patch大小(p×p,例如2×2、4×4或8×8)决定了token数量(T = (I/p)²),直接影响Gflops。
- 数据流:从空间表示到token序列,增加了位置编码(sin-cos形式)。
-
DiT Block(Transformer块):
- 作用:处理patch token序列并融合条件信息(如时间步t和类标签c)。
- 变体:论文探索了四种条件融合方式,最终选择了adaLN-Zero:
- In-context Conditioning:将条件嵌入作为额外token加入序列。
- Cross-attention:通过独立的跨注意力层融合条件。
- adaLN:用条件回归自适应层归一化的scale和shift参数。具体可以参考笔者的另一篇博客:《FiLM: Visual Reasoning with a General Conditioning Layer》一种通用的视觉推理条件层方法(代码实现)
- adaLN-Zero:在adaLN基础上增加残差前的缩放参数α,并初始化为零,使块初始为恒等函数。
- 数据流:token序列经过多个DiT块处理,每个块包括多头自注意力、条件融合和前馈网络。
- Transformer Decoder(解码器):
- 作用:将token序列转换回空间表示,预测噪声和协方差。
- 细节:通过线性层将每个token解码为p×p×2C的张量(C为输入通道数),然后重排为原始空间形状。
- 数据流:从token序列到空间噪声预测。
数据流方向
DiTs的数据流可以概括为以下步骤:
- 输入:从VAE编码器获取的潜在表示z(例如32×32×4),这是图像x压缩后的低分辨率版本。
- Patchify:z被分割成T个patch token,加上位置编码,形成初始序列。
- DiT Blocks:序列通过N个DiT块处理,每个块融合时间步t和类标签c的信息,逐步去噪。
- 解码:最终序列通过线性解码器转换为空间表示,输出预测噪声ε和协方差Σ。
- 输出:基于Diffusion的反向过程,从噪声采样生成潜在表示z,再通过VAE解码器D(z)生成图像x。
整个过程与传统Diffusion模型一致,但空间操作被Transformer的序列操作取代。
解决的问题与框架输入输出
解决的问题:
- U-Net的局限性:U-Net的卷积架构在高分辨率和大规模模型下计算效率低,且难以适应跨领域任务。
- 架构统一性:Diffusion模型与其他领域(如NLP和视觉分类)的Transformer趋势脱节,限制了经验共享。
- 扩展性:如何设计一个随着计算资源增加而持续提升性能的Diffusion模型。
提出的框架:
- 输入:潜在表示z(来自VAE编码器)、时间步t、条件信息c(如类标签)。
- 输出:预测的噪声ε和协方差Σ,用于反向扩散过程。
- 生成过程:
- 从纯噪声 x T x_T xT开始(服从标准正态分布)。
- 通过DiT预测每一步的ε和Σ,迭代去噪生成 x t − 1 x_{t-1} xt−1。
- 最终得到潜在表示 z 0 z_0 z0,用VAE解码器生成图像 x 0 x_0 x0。
生成图像的过程
DiTs的图像生成基于Latent Diffusion Models (LDMs)框架:
- 训练阶段:
- 用VAE将图像x编码为z,DiT在潜在空间学习去噪。
- 损失函数结合噪声预测的MSE和协方差的KL散度。
- 采样阶段:
- 从随机噪声 z T z_T zT开始。
- 迭代运行DiT,预测每一步的噪声 ε θ ( z t , t , c ) ε_θ(z_t, t, c) εθ(zt,t,c)。
- 使用classifier-free guidance增强条件一致性(如类标签)。
- 最终 z 0 z_0 z0通过VAE解码器生成高质量图像。
相比U-Net,DiTs通过token化处理和注意力机制更高效地捕捉全局依赖,提升了生成质量和计算效率。
总结
DiTs将Transformer引入Diffusion模型,开辟了一条新的研究路径。它不仅挑战了U-Net的统治地位,还通过可扩展性和性能突破展示了Transformer在生成任务中的潜力。对于熟悉Diffusion模型的读者来说,DiTs提供了一个值得关注的替代方案,未来可能在更大规模模型和跨领域应用(如文本到图像生成)中进一步发光发热。
简单代码实现
以下是基于论文《Scalable Diffusion Models with Transformers》的Diffusion Transformers (DiTs) 的一个简化 PyTorch 实现。我们将重点实现核心组件,包括 Patchify、DiT Block (adaLN-Zero 变体)、Transformer Decoder,并展示如何在潜在空间中进行扩散过程的训练和采样。为了便于理解,会逐步拆解代码并详细解释。
代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# 1. Patchify 层:将潜在表示分割成 patch token 序列
class Patchify(nn.Module):
def __init__(self, patch_size, in_channels, embed_dim):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
# 将 patch 映射到嵌入维度
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
# x: (B, C, H, W) -> (B, embed_dim, H/p, W/p)
x = self.proj(x)
# 转换为序列形式: (B, embed_dim, T) -> (B, T, embed_dim)
B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
return x
# 2. 位置编码(Sin-Cos)
class PositionalEmbedding(nn.Module):
def __init__(self, embed_dim, max_len=1024):
super().__init__()
pe = torch.zeros(max_len, embed_dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, embed_dim)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (B, T, embed_dim)
return x + self.pe[:, :x.size(1), :]
# 3. DiT Block(adaLN-Zero 变体)
class DiTBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.norm2 = nn.LayerNorm(embed_dim)
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, embed_dim)
)
# adaLN-Zero 参数回归网络
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 6 * embed_dim) # gamma, beta, alpha (各 2x embed_dim)
)
def forward(self, x, cond):
# x: (B, T, embed_dim), cond: (B, embed_dim)
# adaLN-Zero 条件融合
mod = self.adaLN_modulation(cond) # (B, 6 * embed_dim)
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = mod.chunk(6, dim=-1)
# 自注意力分支
residual = x
x = self.norm1(x)
x = x + gamma1.unsqueeze(1) * x + beta1.unsqueeze(1) # adaLN
x, _ = self.attn(x, x, x)
x = residual + alpha1.unsqueeze(1) * x # 残差缩放
# 前馈网络分支
residual = x
x = self.norm2(x)
x = x + gamma2.unsqueeze(1) * x + beta2.unsqueeze(1) # adaLN
x = self.mlp(x)
x = residual + alpha2.unsqueeze(1) * x # 残差缩放
return x
# 4. Diffusion Transformer (DiT)
class DiT(nn.Module):
def __init__(self, in_channels=4, patch_size=2, embed_dim=1152, depth=28, num_heads=16):
super().__init__()
self.patchify = Patchify(patch_size, in_channels, embed_dim)
self.pos_embed = PositionalEmbedding(embed_dim)
self.blocks = nn.ModuleList([DiTBlock(embed_dim, num_heads) for _ in range(depth)])
self.norm = nn.LayerNorm(embed_dim)
# 输出噪声和协方差的解码器
self.decoder = nn.Linear(embed_dim, patch_size * patch_size * 2 * in_channels)
self.in_channels = in_channels
self.patch_size = patch_size
def forward(self, x, t, cond):
# x: (B, C, H, W), t: (B,), cond: (B, embed_dim)
B, C, H, W = x.shape
x = self.patchify(x) # (B, T, embed_dim)
x = self.pos_embed(x)
# 时间步嵌入(这里简化为线性映射)
t_embed = nn.Linear(1, embed_dim)(t.unsqueeze(-1).float())
cond = t_embed + cond # 融合时间步和条件
# 通过 DiT Blocks
for block in self.blocks:
x = block(x, cond)
# 解码到空间表示
x = self.norm(x)
x = self.decoder(x) # (B, T, p*p*2*C)
x = x.view(B, H // self.patch_size, W // self.patch_size, self.patch_size, self.patch_size, 2 * C)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, 2 * C, H, W)
noise_pred, cov_pred = x.chunk(2, dim=1) # 分离噪声和协方差
return noise_pred
# 5. 简单扩散过程
class DiffusionProcess:
def __init__(self, timesteps=1000):
self.timesteps = timesteps
self.betas = torch.linspace(1e-4, 0.02, timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
def add_noise(self, x, t):
# 前向加噪
noise = torch.randn_like(x)
alpha_t = self.alphas_cumprod[t].view(-1, 1, 1, 1)
return torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise, noise
def sample(self, model, shape, cond, device):
# 反向采样
x = torch.randn(shape, device=device)
for t in reversed(range(self.timesteps)):
t_tensor = torch.full((shape[0],), t, device=device)
noise_pred = model(x, t_tensor, cond)
alpha_t = self.alphas[t]
alpha_cumprod_t = self.alphas_cumprod[t]
x = (x - (1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t) * noise_pred) / torch.sqrt(alpha_t)
if t > 0:
x += torch.randn_like(x) * torch.sqrt(self.betas[t])
return x
# 示例使用
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型和扩散过程
model = DiT(in_channels=4, patch_size=2, embed_dim=1152, depth=28, num_heads=16).to(device)
diffusion = DiffusionProcess(timesteps=1000)
# 模拟输入(潜在表示和条件)
B, C, H, W = 1, 4, 32, 32
x = torch.randn(B, C, H, W, device=device) # 潜在表示
t = torch.randint(0, 1000, (B,), device=device) # 时间步
cond = torch.randn(B, 1152, device=device) # 类条件嵌入
# 前向加噪和预测
noisy_x, true_noise = diffusion.add_noise(x, t)
pred_noise = model(noisy_x, t, cond)
loss = F.mse_loss(pred_noise, true_noise)
print(f"Loss: {loss.item()}")
# 采样生成
sampled_latent = diffusion.sample(model, (B, C, H, W), cond, device)
print(f"Sampled latent shape: {sampled_latent.shape}")
代码详细解释
1. Patchify 层
- 功能:将输入的潜在表示(如32×32×4的张量)分割成patch token序列。
- 实现:使用卷积层(stride等于patch_size)将空间输入投影到嵌入维度,然后重塑为序列形式。
- 参数:
patch_size
:patch的大小(如2×2),决定token数量。in_channels
:输入通道数(如4,来自VAE)。embed_dim
:每个token的维度(如1152)。
- 数据流:(B, C, H, W) → (B, T, embed_dim),T = (H/p)×(W/p)。
- 解释:这是DiTs将空间数据转化为Transformer可处理的序列的关键步骤,类似于ViT。
2. PositionalEmbedding 层
- 功能:为每个token添加位置信息。
- 实现:使用sin-cos函数生成固定位置编码,类似于原始Transformer。
- 参数:
embed_dim
:嵌入维度。max_len
:最大序列长度。
- 数据流:(B, T, embed_dim) → (B, T, embed_dim)。
- 解释:由于Transformer不具备空间感知能力,位置编码确保token之间的相对位置关系。
3. DiTBlock(adaLN-Zero 变体)
- 功能:Transformer块,处理token序列并融合条件信息。
- 实现:
- 自注意力(MultiheadAttention)和前馈网络(MLP)是标准Transformer组件。
- adaLN-Zero通过条件嵌入(如时间步t和类标签c)回归6个参数:
gamma1, beta1
:自注意力前的缩放和偏移。alpha1
:自注意力残差缩放。gamma2, beta2
:前馈网络前的缩放和偏移。alpha2
:前馈网络残差缩放。
- 参数:
embed_dim
:隐藏维度。num_heads
:注意力头数。mlp_ratio
:MLP隐藏层扩展比例。
- 数据流:(B, T, embed_dim) → (B, T, embed_dim)。
- 解释:adaLN-Zero是论文中最佳的条件融合方式,初始化为恒等函数有助于训练稳定性。
4. DiT 主模型
- 功能:完整DiT架构,预测噪声。
- 实现:
- Patchify将输入转化为序列。
- 多层DiTBlock处理序列。
- 线性解码器将序列转换回空间表示,输出噪声和协方差。
- 参数:
in_channels, patch_size, embed_dim, depth, num_heads
:定义模型规模。
- 数据流:(B, C, H, W) → (B, C, H, W)。
- 解释:DiT替换了U-Net,输出噪声预测用于扩散反向过程。
5. DiffusionProcess 类
- 功能:实现扩散的前向加噪和反向采样。
- 实现:
add_noise
:根据时间步t添加高斯噪声。sample
:从噪声开始迭代去噪生成图像。
- 参数:
timesteps
:扩散步数。
- 数据流:
- 加噪:(B, C, H, W) → (B, C, H, W)。
- 采样:(B, C, H, W) → (B, C, H, W)。
- 解释:这是标准的DDPM过程,DiT负责预测每一步的噪声。
使用说明
- 训练:
- 输入潜在表示(从VAE编码器获得)、时间步t和条件c。
- 计算预测噪声与真实噪声的MSE损失。
- 采样:
- 从随机噪声开始,逐步通过DiT去噪。
- 最终输出潜在表示,可用VAE解码器生成图像。
注意事项
- 简化之处:代码省略了VAE部分、复杂的协方差预测和classifier-free guidance,实际应用需补充。
- 硬件需求:DiT-XL/2需要高性能GPU或TPU,示例中参数规模已适配普通硬件。
- 扩展性:可通过调整
depth
、embed_dim
和patch_size
研究模型的扩展性。
这实现了论文的核心思想,将Transformer引入Diffusion模型,展示了从U-Net到DiTs的转变。希望这个实现能帮助你深入理解DiTs的工作原理!
补充更复杂的实现
以下是对前面代码中提到的“简化之处”进行补充的实现,完善了VAE部分、复杂的协方差预测和classifier-free guidance功能。我会逐一添加这些模块,并详细解释其作用和实现方式。最终代码将更贴近论文《Scalable Diffusion Models with Transformers》的完整框架。
补充后的完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# 1. VAE 模块(简化的卷积自编码器)
class VAE(nn.Module):
def __init__(self, in_channels=3, latent_channels=4, hidden_dim=128):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(hidden_dim, hidden_dim * 2, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(hidden_dim * 2, latent_channels * 2, 4, 2, 1) # 输出均值和方差
)
# 解码器
self.decoder = nn.Sequential(
nn.ConvTranspose2d(latent_channels, hidden_dim * 2, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(hidden_dim, in_channels, 4, 2, 1),
nn.Sigmoid() # 输出范围 [0, 1]
)
self.latent_channels = latent_channels
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
# x: (B, 3, H, W) -> (B, 8, H/8, W/8)
h = self.encoder(x)
mu, logvar = h.chunk(2, dim=1)
z = self.reparameterize(mu, logvar) # (B, 4, H/8, W/8)
x_recon = self.decoder(z)
return x_recon, mu, logvar, z
def decode(self, z):
return self.decoder(z)
# 2. Patchify 层(保持不变)
class Patchify(nn.Module):
def __init__(self, patch_size, in_channels, embed_dim):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
return x
# 3. 位置编码(保持不变)
class PositionalEmbedding(nn.Module):
def __init__(self, embed_dim, max_len=1024):
super().__init__()
pe = torch.zeros(max_len, embed_dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1), :]
# 4. DiT Block(添加协方差预测支持)
class DiTBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.norm2 = nn.LayerNorm(embed_dim)
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, embed_dim)
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 6 * embed_dim) # gamma, beta, alpha
)
def forward(self, x, cond):
mod = self.adaLN_modulation(cond)
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = mod.chunk(6, dim=-1)
residual = x
x = self.norm1(x)
x = x + gamma1.unsqueeze(1) * x + beta1.unsqueeze(1)
x, _ = self.attn(x, x, x)
x = residual + alpha1.unsqueeze(1) * x
residual = x
x = self.norm2(x)
x = x + gamma2.unsqueeze(1) * x + beta2.unsqueeze(1)
x = self.mlp(x)
x = residual + alpha2.unsqueeze(1) * x
return x
# 5. Diffusion Transformer(支持协方差预测)
class DiT(nn.Module):
def __init__(self, in_channels=4, patch_size=2, embed_dim=1152, depth=28, num_heads=16):
super().__init__()
self.patchify = Patchify(patch_size, in_channels, embed_dim)
self.pos_embed = PositionalEmbedding(embed_dim)
self.blocks = nn.ModuleList([DiTBlock(embed_dim, num_heads) for _ in range(depth)])
self.norm = nn.LayerNorm(embed_dim)
self.decoder = nn.Linear(embed_dim, patch_size * patch_size * 2 * in_channels)
self.in_channels = in_channels
self.patch_size = patch_size
# 时间步嵌入
self.time_embed = nn.Sequential(
nn.Linear(1, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim)
)
def forward(self, x, t, cond, return_cov=False):
B, C, H, W = x.shape
x = self.patchify(x)
x = self.pos_embed(x)
# 时间步嵌入
t = t.unsqueeze(-1).float()
t_embed = self.time_embed(t) # (B, embed_dim)
cond = t_embed + cond if cond is not None else t_embed
for block in self.blocks:
x = block(x, cond)
x = self.norm(x)
x = self.decoder(x)
x = x.view(B, H // self.patch_size, W // self.patch_size, self.patch_size, self.patch_size, 2 * C)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, 2 * C, H, W)
noise_pred, cov_pred = x.chunk(2, dim=1)
if return_cov:
return noise_pred, cov_pred
return noise_pred
# 6. 扩散过程(添加协方差和classifier-free guidance)
class DiffusionProcess:
def __init__(self, timesteps=1000):
self.timesteps = timesteps
self.betas = torch.linspace(1e-4, 0.02, timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
def add_noise(self, x, t):
noise = torch.randn_like(x)
alpha_t = self.alphas_cumprod[t].view(-1, 1, 1, 1)
noisy_x = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
return noisy_x, noise
def sample(self, model, shape, cond=None, device="cpu", guidance_scale=1.5):
x = torch.randn(shape, device=device)
model.eval()
with torch.no_grad():
for t in reversed(range(self.timesteps)):
t_tensor = torch.full((shape[0],), t, device=device)
# Classifier-free guidance
if cond is not None and guidance_scale > 1.0:
uncond_pred = model(x, t_tensor, None) # 无条件预测
cond_pred = model(x, t_tensor, cond) # 有条件预测
noise_pred = uncond_pred + guidance_scale * (cond_pred - uncond_pred)
else:
noise_pred = model(x, t_tensor, cond)
alpha_t = self.alphas[t]
alpha_cumprod_t = self.alphas_cumprod[t]
sigma_t = self.betas[t]
x = (x - (1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t) * noise_pred) / torch.sqrt(alpha_t)
if t > 0:
x += torch.randn_like(x) * torch.sqrt(sigma_t)
model.train()
return x
# 示例使用
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化 VAE 和 DiT
vae = VAE(in_channels=3, latent_channels=4).to(device)
model = DiT(in_channels=4, patch_size=2, embed_dim=384, depth=12, num_heads=6).to(device) # 缩小模型便于测试
diffusion = DiffusionProcess(timesteps=1000)
# 模拟图像输入
B, C, H, W = 2, 3, 256, 256
x = torch.rand(B, C, H, W, device=device) # 假设输入图像
t = torch.randint(0, 1000, (B,), device=device)
cond = torch.randn(B, 384, device=device) # 类条件嵌入
# VAE 编码
_, _, _, z = vae(x) # (B, 4, 32, 32)
# 前向加噪和训练
noisy_z, true_noise = diffusion.add_noise(z, t)
pred_noise, pred_cov = model(noisy_z, t, cond, return_cov=True)
loss_noise = F.mse_loss(pred_noise, true_noise)
loss_cov = F.mse_loss(pred_cov, torch.zeros_like(pred_cov)) # 简化协方差损失
loss = loss_noise + 0.1 * loss_cov # 结合两种损失
print(f"Loss: {loss.item()}")
# 采样生成
sampled_latent = diffusion.sample(model, (B, 4, 32, 32), cond, device, guidance_scale=1.5)
sampled_image = vae.decode(sampled_latent)
print(f"Sampled image shape: {sampled_image.shape}")
补充部分的详细解释
1. VAE 模块
- 功能:将图像压缩到潜在空间并解码回图像空间,用于Latent Diffusion。
- 实现:
- 编码器:通过卷积层将图像(如256×256×3)压缩为潜在表示(如32×32×4),输出均值和方差。
- 重参数化:从均值和方差采样潜在变量z。
- 解码器:通过转置卷积将潜在表示解码为图像。
- 参数:
in_channels
:输入图像通道数(如3)。latent_channels
:潜在空间通道数(如4)。hidden_dim
:中间层维度。
- 数据流:
- 编码:(B, 3, 256, 256) → (B, 4, 32, 32)。
- 解码:(B, 4, 32, 32) → (B, 3, 256, 256)。
- 解释:VAE使DiT能在低分辨率的潜在空间操作,降低计算成本。论文中使用预训练的VAE,这里简化为一个小型卷积VAE。
2. 复杂的协方差预测
- 功能:DiT不仅预测噪声ε,还预测逆过程的协方差Σ,用于更精确的去噪。
- 实现:
- 在
DiT
的解码器中,输出维度为patch_size * patch_size * 2 * in_channels
,其中一半通道为噪声预测,一半为协方差预测。 forward
方法添加return_cov
参数,返回噪声和协方差。- 训练时,添加协方差损失(这里简化为与零的MSE,实际应使用KL散度)。
- 在
- 数据流:
- 输入:(B, 4, 32, 32)。
- 输出:(B, 4, 32, 32)(噪声)+ (B, 4, 32, 32)(协方差)。
- 解释:论文中协方差预测提高了采样质量,这里简化实现仅作为示例,实际应用需遵循Nichol和Dhariwal的方法(如[36])优化完整KL散度。
3. Classifier-Free Guidance
具体可以参考笔者的另一篇博客:《Classifier-Free Diffusion Guidance》的核心观点与方法
- 功能:通过结合有条件和无条件预测增强生成样本的质量和条件一致性。
- 实现:
- 在
DiffusionProcess.sample
中,分别计算无条件(cond=None)和有条件噪声预测。 - 使用公式:
noise_pred = uncond_pred + guidance_scale * (cond_pred - uncond_pred)
。 guidance_scale
控制指导强度(如1.5)。
- 在
- 训练支持:
- 训练时需随机丢弃条件(cond=None),模拟无条件情况(未在代码中显式实现,需在数据加载时处理)。
- 数据流:
- 输入:(B, 4, 32, 32) 和条件嵌入。
- 输出:增强后的噪声预测。
- 解释:Classifier-free guidance是DiT-XL/2达到SOTA FID(如2.27)的重要技术,通过放大条件信号提升生成质量。
使用说明
- 训练:
- 先训练VAE(用重构损失和KL散度),然后冻结其权重。
- 用潜在表示z训练DiT,损失结合噪声MSE和协方差损失。
- 随机丢弃cond以支持classifier-free guidance。
- 采样:
- 从噪声开始,使用DiT和guidance生成潜在表示。
- 通过VAE解码器生成最终图像。
注意事项
- VAE预训练:实际应用中应使用Stable Diffusion的预训练VAE(如ft-EMA),而非这里的小型VAE。
- 协方差优化:完整实现需计算真实的KL散度损失,而非简单MSE。
- 计算资源:DiT-XL/2(118.6 Gflops)需要TPU或高性能GPU,示例中模型已缩小。
通过这些补充,代码更接近论文的完整框架,涵盖了VAE、协方差预测和classifier-free guidance的核心功能。希望这能帮你全面理解DiTs的实现细节!
后记
2025年3月18日21点29分于上海,在Grok 3大模型辅助下完成。