Bootstrap

AF3 FourierEmbedding类源码解读

FourierEmbedding 是一个用于扩散条件的傅里叶嵌入类,其核心是将输入的时间步噪声强度或控制参数timestep)转换为高维的周期性特征。

源代码:

class FourierEmbedding(nn.Module):
    """Fourier embedding for diffusion conditioning."""
    def __init__(self, embed_dim):
        super(FourierEmbedding, self).__init__()
        self.embed_dim = embed_dim
        # Randomly generate weight/bias once before training
        self.weight = nn.Parameter(torch.randn((1, embed_dim)))
        self.bias = nn.Parameter(torch.randn((1, embed_dim)))

    def forward(self, t):
        """Compute embeddings"""
        two_pi = torch.tensor(2 * 3.1415, device=t.device, dtype=t.dtype)
        return torch.cos(two_pi * (t * self.weight + self.bias))

类代码解读:

1. 类的功能

该模块的主要目的是通过傅里叶变换,将输入的时间步嵌入到一个周期性的高维特征空间。这种处理方式在扩散模型中尤为重要,因为时间步本身是一个标量(单一数值),通过傅里叶嵌入,模型能够更好地捕获时间的周期性模式。

2. __init__ 方法
def __init__(self, embed_dim):
    super(FourierEmbedding, self).__init__()
    self.embed_dim = embed_dim
    # Randomly generate weight/bias once before training
    self.weight = nn.Parameter(torch.randn((1, embed_dim)))
    self.bias = nn.Parameter(torch.randn((1, embed_dim)))
功能
  • 初始化傅里叶嵌入模块。
  • 生成随机初始化的权重和偏置(weight 和 bias),用于控制傅里叶变换的频率和相位。
重要参数
  • embed_dim:
    • 表示嵌入的维度,即输出特征的大小。
    • 在扩散模型中,较大的 e
;