SDE in diffusion models
参考:https://www.bilibili.com/video/BV19M411z7hS/
论文:Score-Based Generative Modeling through Stochastic Differential Equations
本文被认为是 diffusion models 方向中最重要的一篇的论文。作者通过将之前的两类 diffusion models(DDPM、NCSN)通过 SDE 在理论上统一在了一起。注意虽然统一了两种形式,但是本质上 SDE 还是通过估计 score。
为什么要用 SDE 来描述扩散模型?
对于扩散模型,我们的研究对象就是噪声图 x t x_t xt 。对于 x t x_t xt ,实际就是一个关于 x x x 和 t t t 的函数:
- 当 t t t 固定时, x t x_t xt 可看做是一个随机变量, x t ∼ N ( α ˉ t x 0 , 1 − α ˉ t I ) x_t\sim\mathcal{N}(\sqrt{\bar\alpha_t}x_0,\sqrt{1-\bar\alpha_t}I) xt∼N(αˉtx0,1−αˉtI)
- 当 x x x 固定时,我们就得到一条采样轨迹: x T , x T − 1 , … , x 0 x_T,x_{T-1},\dots,x_0 xT,xT−1,…,x0
基于上述观察,可以发现 x t x_t xt 实际上就是一个随机过程,而描述随机过程的数学工具,自然就是 SDE。
连续的扩散过程
之前的工作,无论是 DDPM 还是 NCSN(score-based model),都是将 diffusion 过程看作是一个离散的过程:
离散 diffusion 过程: x 0 , x 1 , . . . x t , . . . , x T x_0,x_1,...x_t,...,x_T x0,x1,...xt,...,xT,reverse 过程: x T , x T − 1 , . . . x t , . . . , x 0 x_T,x_{T-1},...x_t,...,x_0 xT,xT−1,...xt,...,x0 。
实际当然我们也可以将它考虑为一个连续的过程:
连续 diffusion 过程: t ∈ [ 0 , 1 ] t\in[0,1] t∈[0,1] ,扩散过程是 x t → x t + Δ t , Δ t → 0 x_t\rightarrow x_{t+{\Delta t}},\ \ \Delta t\rightarrow 0 xt→xt+Δt, Δt→0 ,reverse 过程则是 x t + Δ t → x t x_{t+{\Delta t}}\rightarrow x_t xt+Δt→xt
在 SDE 的框架内,我们使用更一般的连续形式来研究扩散过程。
SDE 框架描述扩散模型的公式
论文中给出了扩散过程的 SDE 公式:
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
dx=f(x,t)dt+g(t)dw
dx=f(x,t)dt+g(t)dw
式中,
x
x
x 就是我们的研究对象噪声图,
f
(
x
,
t
)
f(x,t)
f(x,t) 称为漂移系数(drift coefficient),
g
(
t
)
g(t)
g(t) 称为扩散系数(diffusion coefficient),
w
w
w 是布朗运动( brownian motion)。
SDE 公式共有两项,其中前一项 f ( x , t ) d t f(x,t)dt f(x,t)dt 是一个确定性的变化过程,后一项 g ( t ) d w g(t)dw g(t)dw 则是一个不确定性的过程。
下面,我们基于 SDE 形式的扩散公式,来推到其重建公式。
我们先把
d
t
dt
dt 写成离散形式
Δ
t
\Delta t
Δt ,有:
x
t
+
Δ
t
−
x
t
=
f
(
x
t
,
t
)
Δ
t
+
g
(
t
)
Δ
t
ϵ
,
ϵ
∼
N
(
0
,
I
)
x
t
+
Δ
t
=
x
t
+
f
(
x
t
,
t
)
Δ
t
+
g
(
t
)
Δ
t
ϵ
x_{t+\Delta t}-x_t=f(x_t,t)\Delta t+g(t)\sqrt{\Delta t}\epsilon,\ \ \ \epsilon\sim\mathcal{N}(0,I) \\ x_{t+\Delta t}=x_t+f(x_t,t)\Delta t+g(t)\sqrt{\Delta t}\epsilon
xt+Δt−xt=f(xt,t)Δt+g(t)Δtϵ, ϵ∼N(0,I)xt+Δt=xt+f(xt,t)Δt+g(t)Δtϵ
注意这里
d
w
dw
dw 的系数从
g
(
t
)
g(t)
g(t) 变成
Δ
t
\sqrt{\Delta t}
Δt ,从一阶的两变成了一个
1
/
2
1/2
1/2 阶的量,这是一个结论性的东西,先记住就行了。
然后写出
P
(
x
t
+
Δ
t
∣
x
t
)
P(x_{t+\Delta t}|x_t)
P(xt+Δt∣xt) :
P
(
x
t
+
Δ
t
∣
x
t
)
∼
N
(
x
t
+
f
(
x
t
,
t
)
Δ
t
,
g
2
(
t
)
Δ
t
I
)
P(x_{t+\Delta t}|x_t)\sim\mathcal{N}(x_t+f(x_t,t)\Delta t,\ \ g^2(t)\Delta t\ I)
P(xt+Δt∣xt)∼N(xt+f(xt,t)Δt, g2(t)Δt I)
上述是对扩散公式的变换,是由
x
t
x_t
xt 推导
x
t
+
Δ
t
x_{t+\Delta t}
xt+Δt ,而我们推导重建公式的目标是由
x
t
+
Δ
t
x_{t+\Delta t}
xt+Δt 推导
x
t
x_t
xt ,即
P
(
x
t
∣
x
t
+
Δ
t
)
P(x_t|x_{t+\Delta t})
P(xt∣xt+Δt) ,有:
P
(
x
t
∣
x
t
+
Δ
t
)
=
P
(
x
t
+
Δ
t
∣
x
t
)
P
(
x
t
)
P
(
x
t
+
Δ
t
)
P(x_t|x_{t+\Delta t})=\frac{P(x_{t+\Delta t}|x_t)P(x_t)}{P(x_{t+\Delta t})}
P(xt∣xt+Δt)=P(xt+Δt)P(xt+Δt∣xt)P(xt)
上面扩散过程中
P
(
x
t
+
Δ
t
∣
x
t
)
P(x_{t+\Delta t}|x_t)
P(xt+Δt∣xt) 已经写出来了,现在要做的就是处理
P
(
x
t
)
P(x_t)
P(xt) 和
P
(
x
t
+
Δ
t
)
P(x_{t+\Delta t})
P(xt+Δt) ,我们将他们构造成 log 的形式:
P
(
x
t
∣
x
t
+
Δ
t
)
=
P
(
x
t
+
Δ
t
∣
x
t
)
exp
{
log
P
(
x
t
)
−
log
P
(
x
t
+
Δ
t
)
}
P(x_t|x_{t+\Delta t})=P(x_{t+\Delta t}|x_t)\exp \{\log P(x_t)-\log P(x_{t+\Delta t})\}
P(xt∣xt+Δt)=P(xt+Δt∣xt)exp{logP(xt)−logP(xt+Δt)}
对式中的
log
P
(
x
t
+
Δ
t
)
\log P(x_{t+\Delta t})
logP(xt+Δt) ,对其进行一阶泰勒展开的近似:
log
(
x
t
+
Δ
t
)
≈
log
P
(
x
t
)
+
(
x
t
+
Δ
t
−
x
t
)
∇
x
t
log
P
(
x
t
)
+
Δ
t
∂
∂
t
log
P
(
x
t
)
\log (x_{t+\Delta t})\approx \log P(x_t)+(x_{t+\Delta t}-x_t)\nabla_{x_t}\log P(x_t)+\Delta t\frac{\partial}{\partial t}\log P(x_t)
log(xt+Δt)≈logP(xt)+(xt+Δt−xt)∇xtlogP(xt)+Δt∂t∂logP(xt)
代回去,有:
P
(
x
t
∣
x
t
+
Δ
t
)
=
P
(
x
t
+
Δ
t
∣
x
t
)
exp
{
−
(
x
t
+
Δ
t
−
x
t
)
∇
x
t
log
P
(
x
t
)
−
Δ
t
∂
∂
t
log
P
(
x
t
)
}
∝
exp
{
∣
∣
x
t
+
Δ
t
−
x
t
−
f
(
x
t
,
t
)
Δ
t
∣
∣
2
2
2
g
2
(
t
)
Δ
t
−
(
x
t
+
Δ
t
−
x
t
)
∇
x
t
log
p
(
x
t
)
−
Δ
t
∂
∂
t
log
P
(
x
t
)
}
\begin{align} P(x_t|x_{t+\Delta t})&=P(x_{t+\Delta t}|x_t)\exp \{-(x_{t+\Delta t}-x_t)\nabla_{x_t}\log P(x_t)-\Delta t\frac{\partial}{\partial t}\log P(x_t)\} \\ &\propto \exp\{\frac{||x_{t+\Delta t}-x_t-f(x_t,t)\Delta t||_2^2}{2g^2(t)\Delta t}-(x_{t+\Delta t}-x_t)\nabla_{x_t}\log p(x_t)-\Delta t\frac{\partial}{\partial t}\log P(x_t)\} \end{align}
P(xt∣xt+Δt)=P(xt+Δt∣xt)exp{−(xt+Δt−xt)∇xtlogP(xt)−Δt∂t∂logP(xt)}∝exp{2g2(t)Δt∣∣xt+Δt−xt−f(xt,t)Δt∣∣22−(xt+Δt−xt)∇xtlogp(xt)−Δt∂t∂logP(xt)}
接下来,我们对
x
t
+
Δ
t
x_{t+\Delta t}
xt+Δt 进行配方:
P
(
x
t
∣
x
t
+
Δ
t
)
∝
exp
{
−
1
2
g
2
(
t
)
Δ
t
[
(
x
t
+
Δ
t
−
x
t
)
2
−
(
2
f
(
x
t
,
t
)
Δ
t
−
2
g
2
(
t
)
Δ
t
∇
x
t
log
p
(
x
t
)
)
(
x
t
+
Δ
t
−
x
t
)
]
}
−
Δ
t
∂
∂
t
log
P
(
x
t
)
−
f
2
(
x
t
,
t
)
Δ
t
2
g
2
(
t
)
=
exp
{
−
1
2
g
2
(
t
)
Δ
t
∣
∣
(
x
t
+
Δ
t
)
−
x
t
−
(
f
(
x
t
,
t
)
−
g
2
(
t
)
∇
x
t
log
P
(
x
t
)
)
Δ
t
∣
∣
2
2
−
Δ
∂
∂
t
log
P
(
x
t
)
−
f
2
(
x
t
,
t
)
Δ
t
2
g
2
(
t
)
+
f
(
x
t
,
t
)
−
g
2
(
t
)
∇
x
t
log
P
(
x
t
)
2
g
2
(
t
)
Δ
t
}
\begin{align} P(x_t|x_{t+\Delta t})&\propto\exp\{-\frac{1}{2g^2(t)\Delta t}[(x_{t+\Delta t}-x_t)^2-(2f(x_t,t)\Delta t-2g^2(t)\Delta t\nabla_{x_t}\log p(x_t))(x_{t+\Delta t}-x_t)]\}-\Delta t\frac{\partial}{\partial t}\log P(x_t)-\frac{f^2(x_t,t)\Delta t}{2g^2(t)} \\ &=\exp \{-\frac{1}{2g^2(t)\Delta t}||(x_{t+\Delta t)}-x_t-(f(x_t,t)-g^2(t)\nabla_{x_t}\log P(x_t))\Delta t||_2^2-\Delta \frac{\partial}{\partial t}\log P(x_t)-\frac{f^2(x_t,t)\Delta t}{2g^2(t)}+\frac{f(x_t,t)-g^2(t)\nabla_{x_t}\log P(x_t)}{2g^2(t)\Delta t}\} \end{align}
P(xt∣xt+Δt)∝exp{−2g2(t)Δt1[(xt+Δt−xt)2−(2f(xt,t)Δt−2g2(t)Δt∇xtlogp(xt))(xt+Δt−xt)]}−Δt∂t∂logP(xt)−2g2(t)f2(xt,t)Δt=exp{−2g2(t)Δt1∣∣(xt+Δt)−xt−(f(xt,t)−g2(t)∇xtlogP(xt))Δt∣∣22−Δ∂t∂logP(xt)−2g2(t)f2(xt,t)Δt+2g2(t)Δtf(xt,t)−g2(t)∇xtlogP(xt)}
接下来做一下近似,我们的
Δ
t
→
0
\Delta t\rightarrow 0
Δt→0,可以将
Δ
t
\Delta t
Δt 的一次项丢掉,而
t
=
t
+
Δ
t
t=t+\Delta t
t=t+Δt ,则:
P
(
x
t
∣
x
t
+
Δ
t
)
∝
exp
{
−
1
2
g
2
(
t
+
Δ
t
)
∣
∣
(
x
t
+
Δ
t
−
x
t
)
−
f
(
x
t
+
Δ
t
,
t
+
Δ
t
)
−
g
2
(
t
+
Δ
t
)
∇
x
t
+
Δ
t
log
P
(
x
t
+
Δ
t
)
Δ
t
∣
∣
2
2
}
P(x_t|x_{t+\Delta t})\propto\exp \{-\frac{1}{2g^2(t+\Delta t)}||(x_{t+\Delta t}-x_t)-f(x_{t+\Delta t},t+\Delta t)-g^2(t+\Delta t)\nabla_{x_{t+\Delta t}}\log P(x_{t+\Delta t})\Delta t||^2_2\}
P(xt∣xt+Δt)∝exp{−2g2(t+Δt)1∣∣(xt+Δt−xt)−f(xt+Δt,t+Δt)−g2(t+Δt)∇xt+ΔtlogP(xt+Δt)Δt∣∣22}
从而我们就可以写出
P
(
x
t
∣
x
t
+
Δ
t
)
P(x_t|x_{t+\Delta t})
P(xt∣xt+Δt) 的均值:
x
t
+
Δ
t
−
(
f
(
x
t
+
Δ
t
,
t
+
Δ
t
)
−
g
2
(
t
+
Δ
t
)
∇
x
t
+
Δ
t
log
P
(
x
t
+
Δ
t
)
)
Δ
t
x_{t+\Delta t}-(f(x_{t+\Delta t},t+\Delta t)-g^2(t+\Delta t)\nabla_{x_{t+\Delta t}}\log P(x_{t+\Delta t}))\Delta t
xt+Δt−(f(xt+Δt,t+Δt)−g2(t+Δt)∇xt+ΔtlogP(xt+Δt))Δt
方差:
g
2
(
t
+
Δ
t
)
Δ
t
g^2(t+\Delta t)\Delta t
g2(t+Δt)Δt
从而:
d
x
=
[
f
(
x
,
t
)
−
g
2
(
t
)
∇
x
t
log
P
(
x
t
)
]
d
t
+
g
(
t
)
d
w
dx=[f(x,t)-g^2(t)\nabla_{x_t}\log P(x_t)]dt+g(t)dw
dx=[f(x,t)−g2(t)∇xtlogP(xt)]dt+g(t)dw
即得到了连续的扩散模型采样过程的 SDE 公式。
当然,代码实现的时候我们还是按照离散的形式来实现,写出其离散形式为:
x
t
+
Δ
t
−
x
t
=
[
f
(
x
t
+
Δ
t
,
t
+
Δ
t
)
−
g
2
(
t
+
Δ
t
)
∇
x
t
+
Δ
t
log
P
(
x
t
+
Δ
t
)
]
Δ
t
+
g
(
t
+
Δ
t
)
Δ
t
ϵ
x_{t+\Delta t}-x_t=[f(x_{t+\Delta t},t+\Delta t)-g^2(t+\Delta t)\nabla_{x_{t+\Delta t}}\log P(x_{t+\Delta t})]\Delta t+g(t+\Delta t)\sqrt{\Delta t}\ \epsilon
xt+Δt−xt=[f(xt+Δt,t+Δt)−g2(t+Δt)∇xt+ΔtlogP(xt+Δt)]Δt+g(t+Δt)Δt ϵ
结果中,只有 score 的部分
∇
x
t
+
Δ
t
log
P
(
x
t
+
Δ
t
)
\nabla_{x_{t+\Delta t}}\log P(x_{t+\Delta t})
∇xt+ΔtlogP(xt+Δt) 是我们不知道的,需要模型来预测,别的部分都是已知的。
如果写成 DDPM 那样的离散形式,即
Δ
t
=
1
\Delta t=1
Δt=1,有:
x
t
−
1
=
x
t
−
[
f
(
x
t
,
t
)
−
g
2
(
t
)
∇
x
t
log
P
(
x
t
)
]
×
1
+
g
(
t
)
ϵ
x_{t-1}=x_t-[f(x_t,t)-g^2(t)\nabla_{x_t}\log P(x_t)]\times 1+g(t)\epsilon
xt−1=xt−[f(xt,t)−g2(t)∇xtlogP(xt)]×1+g(t)ϵ
至此,我们已经有了 SDE 框架下扩散模型的扩散公式和采样公式:
扩散公式:
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
dx=f(x,t)dt+g(t)dw
dx=f(x,t)dt+g(t)dw
采样公式:
d
x
=
[
f
(
x
,
t
)
−
g
2
(
t
)
∇
x
log
P
(
x
)
]
d
t
+
g
(
t
)
d
w
dx=[f(x,t)-g^2(t)\nabla_x\log P(x)]dt+g(t)dw
dx=[f(x,t)−g2(t)∇xlogP(x)]dt+g(t)dw
SDE框架下的NCSN(VE)和DDPM(VP)
我们最开始的时候提到,SDE 这篇工作统一了扩散模型中的 DDPM 和 NCSN 两类模型。本节将具体介绍 SDE 框架是如何统一这两类模型的。
在文章中,NCSN 和 DDPM 分别对应于 VE-SDE 和 VP-SDE,其意义分别为 Variance Exploding 和 Variance Preseving。这里我们先写出两类模型的扩散公式:
方法 | NCSN | DDPM |
---|---|---|
SDE | VE | VP |
一步到位扩散公式 | x t = x 0 + σ t ϵ x_t=x_0+\sigma_t\epsilon xt=x0+σtϵ | x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon xt=αˉtx0+1−αˉtϵ |
单步扩散公式 | x t + 1 = σ t + 1 2 − σ t 2 ϵ x_{t+1}=\sqrt{\sigma_{t+1}^2-\sigma_t^2}\epsilon xt+1=σt+12−σt2ϵ | x t + 1 = 1 − β t + 1 x t + β t + 1 ϵ x_{t+1}=\sqrt{1-\beta_{t+1}}x_t+\sqrt{\beta_{t+1}}\epsilon xt+1=1−βt+1xt+βt+1ϵ |
这两个名字是从何得来呢?我们知道,在扩散模型中,加到最大的噪声强度时,噪声图 x T x_T xT 需要几乎完全是一个高斯分布。
-
在 NCSN 中, x T = x 0 + σ T ϵ x_T=x_0+\sigma_T\epsilon xT=x0+σTϵ 要是一个完全的噪声,其中方差 σ T \sigma_T σT 就要非常大,故称为 Variance Exploding,噪声爆炸;
-
在 DDPM 中, x T = α ˉ T x 0 + 1 − α ˉ T ϵ x_T=\sqrt{\bar\alpha_T}x_0+\sqrt{1-\bar\alpha_T}\epsilon xT=αˉTx0+1−αˉTϵ 要是一个完全的噪声,其中 α ˉ T \bar\alpha_T αˉT 就要非常小,所以噪声 1 − α ˉ t \sqrt{1-\bar\alpha_t} 1−αˉt 最大也只有 1,故称为 Variance Preserving,方差收紧。
接下来我们就正式介绍如何使用 SDE 框架统一这两种扩散模型。我们先再写出 SDE 的扩散公式:
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
x
t
+
Δ
t
=
x
t
+
f
(
x
t
,
t
)
Δ
t
+
g
(
t
)
Δ
t
ϵ
dx=f(x,t)dt+g(t)dw \\ x_{t+\Delta t}=x_t+f(x_t,t)\Delta t+g(t)\sqrt{\Delta t}\ \epsilon
dx=f(x,t)dt+g(t)dwxt+Δt=xt+f(xt,t)Δt+g(t)Δt ϵ
我们先看 DDPM(VE)的情况,有:
x
t
+
Δ
t
=
x
t
+
σ
t
+
Δ
t
2
−
σ
t
2
ϵ
=
x
t
+
σ
t
+
Δ
t
2
−
σ
t
2
Δ
t
Δ
t
ϵ
=
x
t
+
Δ
σ
t
2
Δ
t
Δ
t
ϵ
\begin{align} x_{t+\Delta t}&=x_t+\sqrt{\sigma^2_{t+\Delta t}-\sigma^2_t}\ \epsilon \\ &=x_t+\sqrt{\frac{\sigma^2_{t+\Delta t}-\sigma^2_t}{\Delta t}}\sqrt{\Delta t}\ \epsilon \\ &=x_t+\sqrt{\frac{\Delta \sigma_t^2}{\Delta t}}\sqrt{\Delta t}\ \epsilon \end{align}
xt+Δt=xt+σt+Δt2−σt2 ϵ=xt+Δtσt+Δt2−σt2Δt ϵ=xt+ΔtΔσt2Δt ϵ
与上式比对,则有:
f
(
x
,
t
)
=
0
g
(
t
)
=
d
d
t
σ
t
2
f(x,t)=0\\ g(t)=\frac{d}{dt}\sigma_t^2
f(x,t)=0g(t)=dtdσt2
VE 的情况还是比较简单的。
再来看 DDPM(VP) 的情况,其单步扩散公式为:
x
t
+
1
=
1
−
β
t
+
1
x
t
+
β
t
ϵ
x_{t+1}=\sqrt{1-\beta_{t+1}}x_t+\sqrt{\beta_t}\ \epsilon
xt+1=1−βt+1xt+βt ϵ
其中
{
β
i
}
i
=
1
T
\{\beta_i\}_{i=1}^T
{βi}i=1T 是我们人为定义的一组超参数。为了将离散函数连续化,现在我们再构造
{
β
ˉ
i
=
T
β
i
}
i
=
1
T
\{\bar\beta_i=T\beta_i\}_{i=1}^T
{βˉi=Tβi}i=1T 。当
T
→
∞
T\rightarrow \infty
T→∞ 时,
β
\beta
β 就趋近于一个函数:
{
β
ˉ
i
}
i
=
1
T
→
β
(
t
)
,
t
∈
[
0
,
1
]
\{\bar\beta_i\}_{i=1}^T\rightarrow \beta(t),\ \ t\in [0,1]
{βˉi}i=1T→β(t), t∈[0,1]
则有:
β
(
i
T
)
=
β
ˉ
i
\beta(\frac{i}{T})=\bar\beta_i
β(Ti)=βˉi
再看扩散公式:
x
t
+
1
=
1
−
β
ˉ
t
+
1
T
x
t
+
β
ˉ
t
+
1
T
ϵ
x_{t+1}=\sqrt{1-\frac{\bar\beta_{t+1}}{T}}x_t+\sqrt{\frac{\bar\beta_{t+1}}{T}}\ \epsilon
xt+1=1−Tβˉt+1xt+Tβˉt+1 ϵ
当
Δ
t
=
1
T
\Delta t=\frac{1}{T}
Δt=T1 时:
x
t
+
Δ
t
=
1
−
β
(
t
+
Δ
t
)
x
t
+
β
(
t
+
Δ
t
)
Δ
t
ϵ
x_{t+\Delta t}=\sqrt{1-\beta(t+\Delta t)}x_t+\sqrt{\beta(t+\Delta t)\Delta t}\ \epsilon
xt+Δt=1−β(t+Δt)xt+β(t+Δt)Δt ϵ
这里做一下近似。我们知道,当
x
→
0
x\rightarrow 0
x→0 时,有
(
1
−
x
)
a
≈
1
−
a
x
(1-x)^a\approx 1-ax
(1−x)a≈1−ax,则有:
x
t
+
Δ
t
≈
(
1
−
1
2
β
(
t
+
Δ
t
)
Δ
t
)
x
t
+
β
(
t
+
Δ
t
)
Δ
t
ϵ
x_{t+\Delta t}\approx (1-\frac{1}{2}\beta(t+\Delta t)\Delta t)x_t+\sqrt{\beta(t+\Delta t)}\sqrt{\Delta t}\ \epsilon
xt+Δt≈(1−21β(t+Δt)Δt)xt+β(t+Δt)Δt ϵ
又由于
Δ
t
→
0
\Delta t\rightarrow 0
Δt→0,我们再做一下近似,把不想要的
Δ
t
\Delta t
Δt 拿掉,有:
x
t
+
Δ
t
≈
(
1
−
1
2
β
(
t
)
Δ
t
)
x
t
+
β
(
t
)
Δ
t
ϵ
x_{t+\Delta t}\approx (1-\frac{1}{2}\beta(t)\Delta t)x_t+\sqrt{\beta(t)}\sqrt{\Delta t}\ \epsilon
xt+Δt≈(1−21β(t)Δt)xt+β(t)Δt ϵ
再对比 SDE 的扩散公式,有:
f
(
x
,
t
)
=
−
1
2
β
(
t
)
x
t
g
(
t
)
=
β
(
t
)
f(x,t)=-\frac{1}{2}\beta(t)x_t\\ g(t)=\sqrt{\beta(t)}
f(x,t)=−21β(t)xtg(t)=β(t)
到这里,我们已经把 NCSN 和 DDPM 在 SDE 框架中的形式推导出来了。通过指定不同的
f
(
x
,
t
)
f(x,t)
f(x,t) 和
g
(
t
)
g(t)
g(t) ,扩散模型的 SDE 框架就分别变成了 NCSN 和 DDPM 两类模型。
关于NCSN中的score estimator和DDPM中的denoiser之间关系的分析
在 NCSN 中,我们使用一个模型 s θ ( x t , t ) s_\theta(x_t,t) sθ(xt,t) 来预测模型的分数,而在 DDPM 中,则是使用一个模型 ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) ϵθ(xt,t) 来估计模型的噪声。既然 SDE 框架下,NCSN 能与 DDPM 统一起来,那么我们实际训练的这两个模型,是否也存在一定的关系呢?
s
θ
(
x
t
,
t
)
s_\theta(x_t,t)
sθ(xt,t) 是对分数的估计,即:
s
θ
(
x
t
,
t
)
≈
∇
x
t
log
P
(
x
t
)
s_\theta(x_t,t)\approx\nabla_{x_t}\log P(x_t)
sθ(xt,t)≈∇xtlogP(xt)
而我们知道,
x
t
x_t
xt 是服从一个正态分布的,即:
x
t
∼
N
(
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
x_t\sim\mathcal{N}(\sqrt{\bar\alpha_t}x_0,(1-\bar\alpha_t)I)
xt∼N(αˉtx0,(1−αˉt)I)
则有:
P
(
x
t
)
∝
exp
{
−
∣
∣
x
t
−
α
ˉ
t
x
0
∣
∣
2
2
2
(
1
−
α
ˉ
t
)
}
P(x_t)\propto\exp \{-\frac{||x_t-\sqrt{\bar\alpha_t}x_0||_2^2}{2(1-\bar\alpha_t)}\}
P(xt)∝exp{−2(1−αˉt)∣∣xt−αˉtx0∣∣22}
然后我们可以直接把 score 求出来:
s
θ
(
x
t
,
t
)
≈
∇
x
t
log
P
(
x
t
)
=
−
x
t
−
α
ˉ
t
x
0
1
−
α
ˉ
t
s_\theta(x_t,t)\approx\nabla_{x_t}\log P(x_t)=-\frac{x_t-\sqrt{\bar\alpha_t}x_0}{1-\bar\alpha_t}
sθ(xt,t)≈∇xtlogP(xt)=−1−αˉtxt−αˉtx0
我们还知道噪声预测器:
ϵ
θ
(
x
t
,
t
)
=
x
t
−
α
ˉ
t
x
0
1
−
α
ˉ
t
\epsilon_\theta(x_t,t)=\frac{x_t-\sqrt{\bar\alpha_t}x_0}{\sqrt{1-\bar\alpha_t}}
ϵθ(xt,t)=1−αˉtxt−αˉtx0
比较以上两式,有:
s
θ
(
x
t
,
t
)
=
−
1
1
−
α
ˉ
t
ϵ
θ
(
x
t
,
t
)
s_\theta(x_t,t)=-\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t,t)
sθ(xt,t)=−1−αˉt1ϵθ(xt,t)
可以看到,NCSN 中的分数估计模型和 DDPM 中的噪声预测模型实际就只差了一个系数而已。如果训好了一个 NSCN 的分数估计模型,甚至可以直接乘个系数直接拿到 DDPM 中用。当然,在介绍 score-based model 的时候我们也提到了,估计分数,实际上就是估计噪声,在这里也再一次得到了验证。
关于VE和VP等价性的简单证明
本节对 VP 和 VE 的等价性做简单的证明。
我们从 VE 开始:
x
t
=
x
0
+
σ
t
ϵ
x_t=x_0+\sigma_t\epsilon
xt=x0+σtϵ
同除
1
+
σ
t
2
\sqrt{1+\sigma^2_t}
1+σt2,有:
x
t
1
+
σ
t
2
=
x
0
1
+
σ
t
2
+
σ
t
ϵ
1
+
σ
t
2
\frac{x_t}{\sqrt{1+\sigma^2_t}}=\frac{x_0}{\sqrt{1+\sigma^2_t}}+\frac{\sigma_t\epsilon}{\sqrt{1+\sigma^2_t}}
1+σt2xt=1+σt2x0+1+σt2σtϵ
这里我们记
x
t
′
=
x
t
1
+
σ
t
2
α
ˉ
t
=
1
1
+
σ
t
2
x'_t=\frac{x_t}{\sqrt{1+\sigma^2_t}}\\ \bar\alpha_t=\frac{1}{\sqrt{1+\sigma_t^2}}
xt′=1+σt2xtαˉt=1+σt21
从而
1
−
α
ˉ
t
=
σ
t
1
+
σ
t
2
\sqrt{1-\bar\alpha_t}=\frac{\sigma_t}{\sqrt{1+\sigma_t^2}}
1−αˉt=1+σt2σt
则有:
x
t
′
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
ϵ
x'_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\ \epsilon
xt′=αˉtx0+1−αˉt ϵ
到这里发现就已经推到了 VP 的形式,从而就完成了 VE 和 VP 等价性的简单证明。
如果在 VE 上做了些理论的分析,那是可以直接照搬到 VP 上去的,因为它们是等价的。