Score-based Generative Modeling through Stochastic Differenctial Equations
本文是扩散模型/基于分数的生成模型领域最重要的研究工作之一,提出了连续 SDE 形式的生成模型,将之前的都是以噪声扰动为核心思想的 SMLD 和 DDPM 都统一在 SDE 形式下。并指出了与 SDE 对应的 ODE 形式,以及其在精确似然计算、图像编辑、加速采样等方面的优良性质。极大地启发了后来的工作。
背景
本节首先回顾之前的两个生成模型:SMLD 和 DDPM,他们的核心思想都是:噪声扰动。
SMLD
记
p
σ
(
x
~
∣
x
)
:
=
N
(
x
~
;
x
,
σ
2
I
)
p_\sigma(\tilde{\mathbf{x}}|\mathbf{x}):=\mathcal{N}(\tilde{\mathbf{x}};\mathbf{x},\sigma^2\mathbf{I})
pσ(x~∣x):=N(x~;x,σ2I) 为扰动核,而
p
σ
(
x
~
)
:
=
∫
p
data
(
x
)
p
σ
(
x
~
∣
x
)
d
x
p_\sigma(\tilde{\mathbf{x}}):=\int p_\text{data}(\mathbf{x})p_\sigma(\tilde{\mathbf{x}}|\mathbf{x})d\mathbf{x}
pσ(x~):=∫pdata(x)pσ(x~∣x)dx,其中
p
data
(
x
)
p_\text{data}(\mathbf{x})
pdata(x) 表示数据分布。另外有一组正值噪声尺度:
σ
min
=
σ
1
<
σ
2
<
⋯
<
σ
N
=
σ
max
\sigma_\text{min}=\sigma_1<\sigma_2<\dots<\sigma_N=\sigma_\text{max}
σmin=σ1<σ2<⋯<σN=σmax,一般来说
σ
min
\sigma_\text{min}
σmin 需要足够小,使得噪声最小的分布近似于数据分布
p
σ
min
(
x
)
≈
p
data
(
x
)
p_{\sigma_\text{min}}(\mathbf{x})\approx p_\text{data}(\mathbf{x})
pσmin(x)≈pdata(x),而
σ
max
\sigma_\text{max}
σmax 需要足够大,使得最终加最大噪声的分布近似于纯噪声
p
σ
max
(
x
)
≈
N
(
x
,
0.
σ
max
2
I
)
p_{\sigma_\text{max}}(\mathbf{x})\approx\mathcal{N}(\mathbf{x},0.\sigma_\text{max}^2\mathbf{I})
pσmax(x)≈N(x,0.σmax2I)。Song 等人提出训练一个噪声条件得分网络(Noise Condition Score Network,NCSN)
s
θ
(
x
,
σ
)
\mathbf{s_\theta}(\mathbf{x},\sigma)
sθ(x,σ) 来进行图像生成,其训练的目标函数为加权的去噪得分匹配(weighted sum of denoising score matching):
θ
∗
=
arg
min
θ
∑
i
=
1
N
σ
i
2
E
p
data
(
x
)
E
p
σ
i
(
x
~
∣
x
)
[
∣
∣
s
θ
(
x
~
,
σ
i
)
−
∇
x
~
log
p
σ
i
(
x
~
∣
x
)
∣
∣
2
2
]
\theta^*=\arg\min_\theta\sum_{i=1}^N\sigma_i^2\mathbb{E}_{p_\text{data}(\mathbf{x})}\mathbb{E}_{p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})}[||\mathbf{s_\theta}(\tilde{\mathbf{x}},\sigma_i)-\nabla_{\tilde{\mathbf{x}}}\log p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})||_2^2]
θ∗=argθmini=1∑Nσi2Epdata(x)Epσi(x~∣x)[∣∣sθ(x~,σi)−∇x~logpσi(x~∣x)∣∣22]
如果有足够多的数据和足够强的模型,我们能得到最优的 score based model
s
θ
∗
(
x
,
σ
)
\mathbf{s_{\theta^*}}(\mathbf{x},\sigma)
sθ∗(x,σ) ,来准确估计所有的加噪等级
σ
∈
{
σ
i
}
i
=
1
N
\sigma\in\{\sigma_i\}^N_{i=1}
σ∈{σi}i=1N 下加噪分布
p
σ
i
(
x
)
p_{\sigma_i}(\mathbf{x})
pσi(x) 的得分
∇
x
log
p
σ
(
x
)
\nabla_\mathbf{x}\log p_\sigma(\mathbf{x})
∇xlogpσ(x)。
采样时,我们对每个噪声分布
p
σ
i
(
x
)
p_{\sigma_i}(\mathbf{x})
pσi(x) 执行
M
M
M 步朗之万 MCMC 采样:
x
i
m
=
x
i
m
−
1
+
ϵ
i
s
θ
∗
(
x
i
m
−
1
,
σ
i
)
+
2
ϵ
i
z
i
m
,
m
=
1
,
2
,
…
,
M
\mathbf{x}^m_i=\mathbf{x}_i^{m-1}+\epsilon_i\mathbf{s_{\theta^*}}(\mathbf{x}_i^{m-1},\sigma_i)+\sqrt{2\epsilon_i}\mathbf{z}_i^m,\ \ \ m=1,2,\dots,M
xim=xim−1+ϵisθ∗(xim−1,σi)+2ϵizim, m=1,2,…,M
其中
ϵ
i
>
0
\epsilon_i>0
ϵi>0 是步长,
z
i
m
\mathbf{z}_i^m
zim 服从标准高斯分布。上述过程从
i
=
N
i=N
i=N 到
i
=
1
i=1
i=1 迭代执行,第一次朗之万采样的初始值从纯噪声分布中
p
data
(
x
)
p_\text{data}(\mathbf{x})
pdata(x) 中采样得到
x
N
0
∼
N
(
x
;
0
,
σ
max
2
I
)
\mathbf{x}_N^0\sim\mathcal{N}(\mathbf{x};0,\sigma_\text{max}^2\mathbf{I})
xN0∼N(x;0,σmax2I),后续每次采样的取初始值则取前一次朗之万采样的最终步结果,即
x
i
0
=
x
i
+
1
M
,
i
<
N
\mathbf{x}_i^0=\mathbf{x}_{i+1}^{M},\ i<N
xi0=xi+1M, i<N。当对所有
i
i
i 有
M
→
∞
,
ϵ
i
→
0
M\rightarrow\infty,\epsilon_i\rightarrow0
M→∞,ϵi→0 时,最终的采样结果
x
1
M
\mathbf{x}_1^M
x1M (在一定条件下)即大致服从数据分布
p
σ
min
(
x
)
≈
p
data
(
x
)
p_{\sigma_\text{min}}(\mathbf{x})\approx p_\text{data}(\mathbf{x})
pσmin(x)≈pdata(x)。
DDPM
DDPM 考虑了一个正值噪声尺度序列
0
<
β
1
,
β
2
,
…
,
β
N
<
1
0<\beta_1,\beta_2,\dots,\beta_N<1
0<β1,β2,…,βN<1。对任意的数据点
x
0
∼
p
data
(
x
)
\mathbf{x}_0\sim p_\text{data}(\mathbf{x})
x0∼pdata(x),可以构建一个离散的马尔可夫链
{
x
0
,
x
1
,
…
,
x
N
}
\{\mathbf{x}_0,\mathbf{x}_1,\dots,\mathbf{x}_N\}
{x0,x1,…,xN},其中每一步是在上一步的基础上添加一定的高斯噪声
p
(
x
i
∣
x
i
−
1
)
=
N
(
x
i
;
1
−
β
i
x
i
−
1
,
β
i
I
)
p(\mathbf{x}_i|\mathbf{x}_{i-1})=\mathcal{N}(\mathbf{x}_i;\sqrt{1-\beta_i}\mathbf{x}_{i-1},\beta_i\mathbf{I})
p(xi∣xi−1)=N(xi;1−βixi−1,βiI)。根据每一步所加高斯分布的可叠加性,可以直接从
x
0
\mathbf{x}_0
x0 推得
x
i
\mathbf{x}_i
xi:
p
α
i
(
x
i
∣
x
0
)
=
N
(
x
i
;
α
i
x
0
,
(
1
−
α
i
)
I
)
p_{\alpha_i}(\mathbf{x}_i|\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_i;\sqrt{\alpha_i}\mathbf{x}_0,(1-\alpha_i)\mathbf{I})
pαi(xi∣x0)=N(xi;αix0,(1−αi)I),其中
α
i
:
=
∏
j
=
1
i
(
1
−
β
j
)
\alpha_i:=\prod_{j=1}^i(1-\beta_j)
αi:=∏j=1i(1−βj)。与 SMLD 类似,我们将加噪扰动后的数据分布表示为
p
α
i
(
x
~
)
:
=
∫
p
data
(
x
)
p
α
i
(
x
~
∣
x
)
d
x
p_{\alpha_i}(\mathbf{\tilde{x}}):=\int p_\text{data}(\mathbf{x})p_{\alpha_i}(\tilde{\mathbf{x}}|\mathbf{x})d\mathbf{x}
pαi(x~):=∫pdata(x)pαi(x~∣x)dx。预设的噪声尺度序列会保证最后一步的数据
x
N
\mathbf{x}_N
xN 分布大致服从标准高斯分布
N
(
0
,
I
)
\mathcal{N}(0,\mathbf{I})
N(0,I)。DDPM 训练一个参数化网络来构建反向过程的马尔可夫链
p
θ
(
x
i
−
1
∣
x
i
)
=
N
(
x
i
−
1
;
1
1
−
β
i
(
x
i
+
β
i
s
θ
(
x
i
,
i
)
)
,
β
i
I
)
p_\theta(\mathbf{x}_{i-1}|\mathbf{x}_i)=\mathcal{N}(\mathbf{x}_{i-1};\frac{1}{1-\beta_i}(\mathbf{x}_i+\beta_i\mathbf{s_\theta}(\mathbf{x}_i,i)),\beta_i\mathbf{I})
pθ(xi−1∣xi)=N(xi−1;1−βi1(xi+βisθ(xi,i)),βiI)。训练目标是 ELBO 的一个重加权版本:
θ
∗
=
arg
min
θ
∑
i
=
1
N
(
1
−
α
i
)
E
p
data
(
x
)
E
p
α
i
(
x
~
∣
x
)
[
∣
∣
s
θ
(
x
~
,
i
)
−
∇
x
~
log
p
α
i
(
x
~
∣
x
)
∣
∣
2
2
]
\theta^*=\arg\text{min}_\theta\sum_{i=1}^N(1-\alpha_i)\mathbb{E}_{p_\text{data}(\mathbf{x})}\mathbb{E}_{p_{\alpha_i}(\tilde{\mathbf{x}}|\mathbf{x})}[||\mathbf{s_\theta}(\tilde{\mathbf{x}},i)-\nabla_{\tilde{\mathbf{x}}}\log p_{\alpha_i}(\tilde{\mathbf{x}}|\mathbf{x})||_2^2]
θ∗=argminθi=1∑N(1−αi)Epdata(x)Epαi(x~∣x)[∣∣sθ(x~,i)−∇x~logpαi(x~∣x)∣∣22]
在上述目标函数经优化得到最优模型
s
θ
∗
(
x
,
i
)
\mathbf{s_{\theta^*}}(\mathbf{x},i)
sθ∗(x,i),我们可以进行采样。首先随机采样一个标准高斯噪声
x
N
∼
N
(
0
,
I
)
\mathbf{x}_N\sim\mathcal{N}(0,\mathbf{I})
xN∼N(0,I),然后通过以下的反向马尔可夫链来进行生图:
x
i
−
1
=
1
1
−
β
i
(
x
i
+
β
i
s
θ
∗
(
x
i
,
i
)
)
+
β
i
z
i
,
i
=
N
,
N
−
1
,
…
,
1
\mathbf{x}_{i-1}=\frac{1}{1-\beta_i}(\mathbf{x}_i+\beta_i\mathbf{s_{\theta^*}}(\mathbf{x}_i,i))+\sqrt{\beta_i}\mathbf{z}_i,\ \ i=N,N-1,\dots,1
xi−1=1−βi1(xi+βisθ∗(xi,i))+βizi, i=N,N−1,…,1
我们将这种方法称为祖先采样(ancestral sampling),因为这类似于是从 graphical model
∏
i
=
1
N
p
θ
(
x
i
−
1
∣
x
i
)
\prod_{i=1}^Np_\theta(\mathbf{x}_{i-1}|\mathbf{x}_i)
∏i=1Npθ(xi−1∣xi) 中进行祖先采样。这里的公式 3 实际就是 DDPM 原文中的
L
simple
L_\text{simple}
Lsimple,写成这种形式是为了形式上与公式 1 看起来更接近。与公式 1 类似,公式 3 也是一种 denoising score matching 的加权和,这意味着,最优模型
s
θ
∗
(
x
,
i
)
\mathbf{s_{\theta^*}}(\mathbf{x},i)
sθ∗(x,i) 也可以估计噪声扰动数据分布的得分
∇
x
log
p
α
i
(
x
)
\nabla_\mathbf{x}\log p_{\alpha_i}(\mathbf{x})
∇xlogpαi(x)。另外要注意,公式 1、3 中第
i
i
i 项的权重,也就是
σ
i
2
\sigma_i^2
σi2 和
(
1
−
α
i
)
(1-\alpha_i)
(1−αi),都以相同的形式对应于其扰动核:
σ
i
2
∝
1
/
E
[
∣
∣
∇
x
log
p
σ
i
(
x
~
∣
x
)
∣
∣
2
2
]
\sigma^2_i\propto1/\mathbb{E}[||\nabla_\mathbf{x}\log p_{\sigma_i}(\tilde{\mathbf{x}}|\mathbf{x})||_2^2]
σi2∝1/E[∣∣∇xlogpσi(x~∣x)∣∣22] 和
(
1
−
α
i
)
∝
1
/
E
[
∣
∣
∇
x
log
p
α
i
(
x
~
∣
x
)
∣
∣
2
2
]
(1-\alpha_i)\propto1/\mathbb{E}[||\nabla_\mathbf{x}\log p_{\alpha_i}(\tilde{\mathbf{x}}|\mathbf{x})||_2^2]
(1−αi)∝1/E[∣∣∇xlogpαi(x~∣x)∣∣22]。
连续 SDE 形式的 score based model
上述这些方法成功的关键都是使用多种不同的等级的噪声对数据分布进行扰动。本文提出将这个思想推广到无穷多级的噪声等级,随着噪声的增强,扰动数据的分布按照一个 SDE 进行变化。本文思路的整体框架如下图所示。
使用 SDE 对数据进行扰动
我们的目标是构建一个关于连续时间变量
t
∈
[
0
,
T
]
t\in[0,T]
t∈[0,T] 的扩散过程
{
x
(
t
)
}
t
=
0
T
\{\mathbf{x}(t)\}_{t=0}^T
{x(t)}t=0T ,一要满足
x
(
0
)
∼
p
0
\mathbf{x}(0)\sim p_0
x(0)∼p0,这是我们能从数据分布中采样出的样本;二要满足
x
(
T
)
∼
p
T
\mathbf{x}(T)\sim p_T
x(T)∼pT,使得我们有一个便于计算的形式来高效地生成样本。也就是说,需要使得
p
0
p_0
p0 是数据分布,
p
T
p_T
pT 是先验分布。这个扩散过程可以建模为伊藤 SDE 的解:
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
d\mathbf{x}=\mathbf{f}(\mathbf{x},t)dt+g(t)d\mathbf{w}
dx=f(x,t)dt+g(t)dw
其中
w
\mathbf{w}
w 是标准维纳过程(也称为布朗运动),
f
(
⋅
,
t
)
:
R
d
→
R
d
\mathbf{f}(\cdot,t):\mathbb{R}^d\rightarrow\mathbb{R}^d
f(⋅,t):Rd→Rd 是一个向量函数,称为
x
(
t
)
\mathbf{x}(t)
x(t) 的漂移系数(drift coefficient),而
g
(
⋅
)
=
R
→
R
g(\cdot)=\mathbb{R}\rightarrow\mathbb{R}
g(⋅)=R→R 是一个标量函数,称为
x
(
t
)
\mathbf{x}(t)
x(t) 的扩散系数(diffusion coefficient)。这里为了便于表示假设扩散系数为标量函数,并且与
x
\mathbf{x}
x 无关,原文附录 A 中有更一般化的讨论。只要两个系数在状态和时间上全局满足利普希茨条件,这个 SDE 就有唯一解。以下我们记
x
(
t
)
\mathbf{x}(t)
x(t) 的概率密度为
p
t
(
x
)
p_t(\mathbf{x})
pt(x),记从
x
(
s
)
\mathbf{x}(s)
x(s) 到
x
(
t
)
\mathbf{x}(t)
x(t) (
0
≤
s
<
t
≤
T
0\le s<t\le T
0≤s<t≤T)的转移核为
p
s
t
(
x
(
t
)
∣
x
(
s
)
)
p_{st}(\mathbf{x}(t)|\mathbf{x}(s))
pst(x(t)∣x(s))。
一般来说, p T p_T pT 是一个不含任何数据分布 p 0 p_0 p0 相关信息的噪声先验分布,比如一个固定均值和方差的高斯分布。有多种方式(选择不同的漂移系数和扩散系数)来设计式 5 中的 SDE,使其能够将数据分布转换为一个固定的先验分布。后面会基于 SMLD 和 DDPM 介绍他们的连续版本,作为两个 SDE 的例子。
通过反向 SDE 生成样本
从样本
x
(
T
)
∼
p
T
\mathbf{x}(T)\sim p_T
x(T)∼pT 开始,进行反向过程,我们就可以得到生成的数据样本
x
(
0
)
∼
p
0
\mathbf{x}(0)\sim p_0
x(0)∼p0。前人的研究结果已经证明,扩散过程的反向过程同样是一个扩散过程,由以下反向时间的 SDE 给出:
d
x
=
[
f
(
x
,
t
)
−
g
2
(
t
)
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
d\mathbf{x}=[\mathbf{f}(\mathbf{x},t)-g^2(t)\nabla_\mathbf{x}\log p_t(\mathbf{x})]dt+g(t)d\bar{\mathbf{w}}
dx=[f(x,t)−g2(t)∇xlogpt(x)]dt+g(t)dwˉ
其中
w
ˉ
\bar{\mathbf{w}}
wˉ 是时间反向从
T
T
T 到 0 流动的标准维纳过程,
d
t
dt
dt 是一个无穷小负值。可以看到,在反向 SDE 中,我们唯一未知的就是噪声扰动分布的得分
∇
x
log
p
t
(
x
)
\nabla_\mathbf{x}\log p_t(\mathbf{x})
∇xlogpt(x)。只要对于所有时间
t
t
t,我们都能估计其对应分布的得分,我们就能通过上述反向 SDE 得到
p
0
p_0
p0 中的样本。
估计 SDE 中分布的得分
我们可以通过 score matching 训练一个 score based model 来估计分布的得分。为了估计 SDE 中的
∇
x
log
p
t
(
x
)
\nabla_\mathbf{x}\log p_t(\mathbf{x})
∇xlogpt(x),我们可以将公式 1、3 推广到连续时间,训练一个与时间有关的 score based model
s
θ
(
x
,
t
)
\mathbf{s_\theta}(\mathbf{x},t)
sθ(x,t):
θ
∗
=
arg
min
θ
{
λ
(
t
)
E
x
(
0
)
E
x
(
t
)
∣
x
(
0
)
[
∣
∣
s
θ
(
x
(
t
)
,
t
)
−
∇
x
(
t
)
log
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
∣
∣
2
2
]
}
\theta^*=\arg\min_\theta\{\lambda(t)\mathbb{E}_{\mathbf{x}(0)}\mathbb{E}_{\mathbf{x}(t)|\mathbf{x}(0)}[||\mathbf{s_\theta}(\mathbf{x}(t),t)-\nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t)|\mathbf{x}(0))||_2^2]\}
θ∗=argθmin{λ(t)Ex(0)Ex(t)∣x(0)[∣∣sθ(x(t),t)−∇x(t)logp0t(x(t)∣x(0))∣∣22]}
其中
λ
:
[
0
,
T
]
→
R
>
0
\lambda:[0,T]\rightarrow\mathbb{R}_{>0}
λ:[0,T]→R>0 是一个正值加权函数,
t
t
t 从
[
0
,
T
]
[0,T]
[0,T] 中均匀采样,
x
(
0
)
∼
p
0
(
x
)
\mathbf{x}(0)\sim p_0(\mathbf{x})
x(0)∼p0(x),
x
(
t
)
∼
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
\mathbf{x}(t)\sim p_{0t}(\mathbf{x}(t)|\mathbf{x}(0))
x(t)∼p0t(x(t)∣x(0))。假设有足够多的数据和表达能力足够强的模型,上式可得到最优解,记为
s
θ
∗
(
x
,
t
)
\mathbf{s_{\theta^*}}(\mathbf{x},t)
sθ∗(x,t),这是对得分
∇
x
log
p
t
(
x
)
\nabla_\mathbf{x}\log p_t(\mathbf{x})
∇xlogpt(x) 的准确估计。与 SMLD 和 DDPM 一样,我们一般选择
λ
∝
1
/
E
[
∣
∣
∇
x
(
t
)
log
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
∣
∣
2
2
]
\lambda\propto 1/\mathbb{E}[||\nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t)|\mathbf{x}(0))||_2^2]
λ∝1/E[∣∣∇x(t)logp0t(x(t)∣x(0))∣∣22]。这里公式 7 用的是 denoising score matching,实际上用 siliced score matching 和 finite-difference score matching 也都可行。
我们一般需要已知转移核 p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(\mathbf{x}(t)|\mathbf{x}(0)) p0t(x(t)∣x(0)) 来高效地求解公式 7。当 f ( ⋅ , t ) \mathbf{f}(\cdot,t) f(⋅,t) 是仿射的,其转移核总是一个高斯分布,并且其均值和方差都是有闭式解的。对于更一般的 SDEs,我们可以解 Kolmogorov 前向方程,来得到 p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(\mathbf{x}(t)|\mathbf{x}(0)) p0t(x(t)∣x(0))。或者,我们可以对 SDE 进行模拟,来从 p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(\mathbf{x}(t)|\mathbf{x}(0)) p0t(x(t)∣x(0)) 中进行采样,并将 denoising score matching 替换为 sliced score matching,因为后者可以绕过对 ∇ x ( t ) log p 0 t ( x ( t ) ∣ x ( 0 ) ) \nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t)|\mathbf{x}(0)) ∇x(t)logp0t(x(t)∣x(0)) 的计算。
几个例子:VP、VE 和 sub-VP SDE
之前介绍的 SMLD 和 DDPM 两种方法中使用的噪声扰动,可以看做是两种不同 SDE 的离散化形式。以下进行简要介绍,原文附录 B 介绍了更多细节。
当一共使用
N
N
N 个噪声等级,SMLD 的每个扰动核
p
σ
i
(
x
∣
x
0
)
p_{\sigma_i}(\mathbf{x}|\mathbf{x}_0)
pσi(x∣x0) 对应于以下马尔可夫链中
x
i
\mathbf{x}_i
xi 的分布:
x
i
=
x
i
−
1
+
σ
i
2
−
σ
i
−
1
2
z
i
−
1
,
i
=
1
,
…
,
N
\mathbf{x}_i=\mathbf{x}_{i-1}+\sqrt{\sigma_i^2-\sigma_{i-1}^2}\mathbf{z}_{i-1},\ \ \ \ i=1,\dots,N
xi=xi−1+σi2−σi−12zi−1, i=1,…,N
其中
z
i
−
1
∼
N
(
0
,
I
)
\mathbf{z}_{i-1}\sim\mathcal{N}(0,\mathbf{I})
zi−1∼N(0,I),记
σ
0
=
0
\sigma_0=0
σ0=0 来简化符号。当
N
→
∞
N\rightarrow\infty
N→∞ 时,
{
σ
i
}
i
=
1
N
\{\sigma_i\}_{i=1}^N
{σi}i=1N 就成了一个函数
σ
(
t
)
\sigma(t)
σ(t),
z
i
−
1
\mathbf{z}_{i-1}
zi−1 成了
z
(
t
)
\mathbf{z}(t)
z(t),马尔可夫链
{
x
i
}
i
=
1
N
\{\mathbf{x}_i\}_{i=1}^N
{xi}i=1N 成了一个连续的随机过程
{
x
(
t
)
}
t
=
0
1
\{\mathbf{x}(t)\}_{t=0}^1
{x(t)}t=01,它是关于一个连续的时间变量
t
∈
[
0
,
1
]
t\in[0,1]
t∈[0,1] 的函数,而不再是关于离散的整数值
i
i
i。随机过程
{
x
(
t
)
}
t
=
0
1
\{\mathbf{x}(t)\}_{t=0}^1
{x(t)}t=01 由以下 SDE 给出:
d
x
=
d
[
σ
2
(
t
)
]
d
t
d
w
d\mathbf{x}=\sqrt{\frac{d[\sigma^2(t)]}{dt}}d\mathbf{w}
dx=dtd[σ2(t)]dw
类似地,对于 DDPM 中的扰动核
{
p
α
i
(
x
∣
x
0
)
}
i
=
1
N
\{p_{\alpha_i}(\mathbf{x}|\mathbf{x}_0)\}_{i=1}^N
{pαi(x∣x0)}i=1N,其马尔可夫链为:
x
i
=
1
−
β
i
x
i
−
1
+
β
i
z
i
−
1
,
i
=
1
,
…
,
N
\mathbf{x}_i=\sqrt{1-\beta_i}\mathbf{x}_{i-1}+\sqrt{\beta_i}\mathbf{z}_{i-1},\ \ \ \ i=1,\dots,N
xi=1−βixi−1+βizi−1, i=1,…,N
当
N
→
∞
N\rightarrow\infty
N→∞ 时,上式收敛为一个 SDE:
d
x
=
−
1
2
β
(
t
)
x
d
t
+
β
t
d
w
d\mathbf{x}=-\frac{1}{2}\beta(t)\mathbf{x}dt+\sqrt{\beta_t}d\mathbf{w}
dx=−21β(t)xdt+βtdw
推导过程详见原文附录。SMLD 和 DDPM 中使用的噪声扰动分别对应于 SDE 方程(9)和(11)的离散化。有趣的是,SDE 方程(9)在
t
→
∞
t\rightarrow\infty
t→∞ 时是一个方差爆炸的过程,而 SDE 方程(11)在初始分布具有单位方差时,是一个方差固定的 1 的过程。由于这个差异,此后将方程(9)称为方差爆炸(VE)SDE,方程(11)称为方差保持(VP)SDE。
基于 VP SDE,本文提出了一种新的 SDE 形式,其 SDE 方程为:
d
x
=
−
1
2
β
(
t
)
x
d
t
+
β
(
t
)
(
1
−
e
−
2
∫
0
t
β
(
s
)
d
s
)
d
w
d\mathbf{x}=-\frac{1}{2}\beta(t)\mathbf{x}dt+\sqrt{\beta(t)(1-e^{-2\int_0^t\beta(s)ds})}d\mathbf{w}
dx=−21β(t)xdt+β(t)(1−e−2∫0tβ(s)ds)dw
该种形式在似然计算上表现得尤其的好。当使用相同的
β
(
t
)
\beta(t)
β(t) 且初始分布相同时,在中间的每一个时间步,该方程的方差总是被 VP SDE bound 住,因此该种形式称为 sub VP SDE。
由于 VE、VP 和 sub-VP SDE 的漂移系数都是仿射的,所以它们的扰动核 p 0 t ( x ( t ) ∣ x ( 0 ) ) p_{0t}(\mathbf{x}(t)|\mathbf{x}(0)) p0t(x(t)∣x(0)) 都是高斯分布,有封闭解,因此可以使用 denoising score matching 进行高效地求解。
求解反向 SDE
在训练好关于时间 t t t 的 score based model s θ \mathbf{s_\theta} sθ 之后,我们就可以构建出反向 SDE,然后使用数值方法对其进行模拟,从而生成符合数据分布 p 0 p_0 p0 的新的样本。
通用的 SDE 数值求解器
数值求解器可以提供 SDE 的近似轨迹,有许多通用的数值方法来解 SDE,例如 Euler-Maruyama 和随机 Runge-Kutta 方法,这些对应于随机动力学的不同离散化。我们可以将它们应用于逆向时间 SDE 以进行样本生成。
DDPM 的采样方法,祖先采样,实际上对应于逆向时间 VP SDE 的一种特殊离散化(详见原文附录 E)。为新的 SDE 推导祖先采样规则可能并不容易。为了解决这个问题,我们提出了逆向扩散采样器(reverse diffusion samplers),它们对逆向时间 SDE 进行离散化的方式与前向过程相同,因此可以在给定前向离散化的情况下直接推导。作者的实验表明,逆向扩散采样器在 CIFAR-10 上对 SMLD 和 DDPM 模型的性能略优于祖先采样。另外注意,DDPM 这种祖先采样同样可用于 SMLD。
predictor-corrector 采样器
不同于常规的 SDE,我们有一些额外的信息可以用来改进数值求解的结果。由于我们有一个 score-based model s θ ( x , t ) ≈ ∇ x log p t ( x ) \mathbf{s_\theta}(\mathbf{x},t)\approx\nabla_\mathbf{x}\log p_t(\mathbf{x}) sθ(x,t)≈∇xlogpt(x),所以我们可以使用一些 score-based MCMC 方法,比如 Langevin MCMC 和 HMC,来直接从 p t p_t pt 中进行采样,从而对 SDE 数值求解器来对解进行修正。
具体来说,在每一个时间步,SDE 数值求解器给出对下一个时间步样本的估计,扮演预测器(predictor)的角色,然后,score-based MCMC 修正估计样本的边缘分布,扮演修正器(corrector)的角色。这种方法称为 PC 采样器(Predictor-Corrector (PC) Sampler)。PC 采样器可以看作是 SMLD 和 DDPM 采样方法更一般的推广:SMLD 中,可以看作 Predictor 是恒等函数,Corrector 是退火的朗之万动力学采样;DDPM 中,可以看作 Predictor 是祖先采样,Corrector 是恒等函数。
作者测试了 PC 采样器在原离散 SMLD 和 DDPM 目标函数(即公式1、3)上训练出的模型上的采样结果。结果表明 PC 采样器与用固定数量的噪声尺度训练的基于分数的模型能够兼容。下表中对比了不同采样器的性能,其中概率流将在下节中讨论。可以看到,逆向扩散采样器总是优于祖先采样,而仅有校正器的方法(C2000)在相同的计算量下表现不如其他方法(P2000,PC1000)。实际上,为了对应匹配其他采样器的性能,我们需要每个噪声尺度更多的校正器步,因此需要更多的计算。
概率流及其与 Neural ODE 的联系
score based model 使得我们有另一种数值方法来解反向时间 SDE。对于所有的扩散过程,存在一个对应的确定性过程(deterministic process),其轨迹与 SDE 有着相同的边缘概率密度
{
p
t
(
x
)
}
t
=
0
T
\{p_t(\mathbf{x})\}_{{t=0}}^T
{pt(x)}t=0T。这个确定性过程满足如下 ODE:
d
x
=
[
f
(
x
,
t
)
−
1
2
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
d\mathbf{x}=[\mathbf{f}(\mathbf{x},t)-\frac{1}{2}g(t)^2\nabla_\mathbf{x}\log p_t(\mathbf{x})]dt
dx=[f(x,t)−21g(t)2∇xlogpt(x)]dt
只要得分已知,这就可以从 SDE 中确定出来。作者将这种 ODE 称为概率流 ODE(Probablity Flow ODE)。当得分函数是通过一个 score-based model,一般来说是一个神经网络,来估计的,这其实就是一个 Neural ODE。
精确似然计算
利用与 Neural ODE 的联系,我们可以通过即时变量变换公式计算方程(13)中的密度。这使得我们能够对任何输入数据计算精确的似然。作者在 CIFAR-10 数据集上测试了的负对数似然(NLLs)的计算,在下表中。我们计算的是对均匀量化数据的对数似然度,并且只与以相同方式评估的模型进行比较(不包括使用变分量化或离散数据评估的模型),除了 DDPM( L / L simple L/L_\text{simple} L/Lsimple)的 ELBO 值(带 * 标的)是在离散数据上得出的。主要结论是:(i)对于原 DDPM 模型,我们获得了比ELBO 更好的 bits/dim,因为我们的似然度是精确的;(ii)使用相同的架构,我们用方程(7)中的连续目标函数训练了另一个 DDPM 模型(即DDPM cont.),进一步提高了似然;(iii)sub-VP SDE 的似然比 VP SDE 更高;(iv)使用本文改进的网络架构(即DDPM++ cont.)和sub-VP SDE,甚至可以在没有最大似然训练的情况下,在均匀量化的 CIFAR-10 上达到新的记录 bits/dim 为 2.99。
操纵隐层表示
通过方程(13),我们可以将任何数据点 x ( 0 ) \mathbf{x}(0) x(0) 编码到隐空间 x ( T ) \mathbf{x}(T) x(T),解码则可以通过积分相应的反向时间 SDE 的 ODE 来实现。与其他可逆模型如 Neural ODE和 normalizing flows 一样,我们可以操纵这个潜在表示来进行图像编辑,如插值和温度缩放。
唯一可识别的编码
与大多数当前的可逆模型不同,本文 ODE 的编码是唯一可识别的。也就是说,在有足够的训练数据、模型容量和优化精度的情况下,输入的编码由数据分布唯一确定。这是因为前向 SDE,方程(5),没有可训练的参数,其对应的的概率流 ODE,方程(13),在完美估计分数的情况下有相同的轨迹。
高效的采样
与 Neural ODE 一样,我们可以通过解方程(13)从不同的最终条件 x ( T ) ∼ p T \mathbf{x}(T) \sim p_T x(T)∼pT 来采样 x ( 0 ) ∼ p 0 \mathbf{x}(0)\sim p_0 x(0)∼p0。使用固定的离散化策略,我们可以生成有不错的样本,特别是当在加上修正器时。使用黑盒的 ODE 求解器不仅可以产生高质量的样本,而且还允许我们明确地在准确性和效率之间进行权衡。通过更大的容差,函数评估的次数可以减少90%以上,而不影响生成样本的视觉质量。
总结
本文首先回顾了两种以噪声扰动为核心思想的生成模型:SMLD 和 DDPM。然后提出连续 SDE 形式下的 score based model,并证明了 SMLD 和 DDPM 实际是 SDE 形式下的不同离散化。然后介绍了用于求解反向 SDE 的 Predictor-Corrector 的采样方法。最后,指出了每个 SDE 都有其对应的 ODE 形式,在 ODE 形式下,有精确似然计算、操纵隐层表示、高效采样等诸多好处。