Bootstrap

Consistency Models

Consistency Models

本文介绍 Song Yang 大佬提出的 Consistency Models,实现扩散模型一步采样生成,同时支持多步采样以权衡采样质量和生成速度。可通过蒸馏或原生训练两种模式进行训练。

在这里插入图片描述

背景知识:扩散模型

Consistency Models 的设计是基于连续时间扩散模型(SDE/ODE、EDM)。扩散模型通过高斯噪声扰动,逐步将数据转变为纯噪声,并通过从纯噪声图逐步去噪来生成新的样本。具体来说,记数据分布为 p data ( x ) p_\text{data}(\mathbf{x}) pdata(x),扩散模型首先使用 Ito SDE 对其进行扰动:
d x t = μ ( x t , t ) d t + σ ( t ) d w t d\mathbf{x}_t=\mathbf{\mu}(\mathbf{x}_t,t)dt+\sigma(t)d\mathbf{w}_t dxt=μ(xt,t)dt+σ(t)dwt
其中 t ∈ [ 0 , T ] t\in[0,T] t[0,T] T > 0 T>0 T>0 是一个常数, μ ( ⋅ , ⋅ ) \mu(\cdot,\cdot) μ(,) σ ( t ) \sigma(t) σ(t) 分别是漂移系数和扩散系数, { w t } t ∈ [ 0 , T ] \{\mathbf{w}_t\}_{t\in[0,T]} {wt}t[0,T] 表示标准布朗运动。我们记 x t \mathbf{x}_t xt 的分布为 p t ( x ) p_t(\mathbf{x}) pt(x),从而有 p 0 ( x ) ≡ p data ( x ) p_0(\mathbf{x})\equiv p_\text{data}(\mathbf{x}) p0(x)pdata(x)。这个 SDE 有一个很关键的性质,即:存在一个 ODE(称为概率流 ODE,PF ODE),其解轨迹在 t t t 时刻的采样也服从同样的 p t ( x ) p_t(\mathbf{x}) pt(x)。该 ODE 的形式为:
d x t = [ μ ( x t , t ) − 1 2 σ ( t ) 2 ∇ log ⁡ p t ( x t ) ] d t d\mathbf{x}_t=[\mu(\mathbf{x}_t,t)-\frac{1}{2}\sigma(t)^2\nabla\log p_t(\mathbf{x}_t)]dt dxt=[μ(xt,t)21σ(t)2logpt(xt)]dt
其中 log ⁡ p t ( x ) \log p_t(\mathbf{x}) logpt(x) 是分布 p t ( x ) p_t(\mathbf{x}) pt(x) 的得分函数。因此扩散模型同样也可以看作是基于得分的生成模型。可以看到,在 PF ODE 中,只有得分函数 ∇ log ⁡ p t ( x ) \nabla\log p_t(\mathbf{x}) logpt(x) 是未知的,因此我们只需要训练一个得分模型 s ϕ ( x , t ) ≈ ∇ log ⁡ p t ( x ) \mathbf{s_\phi}(\mathbf{x},t)\approx\nabla\log p_t(\mathbf{x}) sϕ(x,t)logpt(x) 来估计各时刻分布的得分,就可以根据 PF ODE 进行采样生成。

一般来说,公式 1 中我们选定的的 SDE 形式需要满足在 t = T t=T t=T 时刻的分布 p T ( x ) p_T(\mathbf{x}) pT(x) 接近于一个便于计算的高斯分布 π ( x ) \pi(\mathbf{x}) π(x)。满足这个条件的 SDE 形式有很多,本文中采用的是 EDM 中的 SDE 形式,即 μ ( x , t ) = 0 \mu(\mathbf{x},t)=0 μ(x,t)=0 σ ( t ) = 2 t \sigma(t)=\sqrt{2t} σ(t)=2t 。在此形式下,各时刻的分布可表示为数据分布与一个高斯分布的卷积的形式: p t ( x ) = p data ( x ) ⨂ N ( 0 , t 2 I ) p_t(\mathbf{x})=p_\text{data}(\mathbf{x})\bigotimes \mathcal{N}(0,t^2\mathbf{I}) pt(x)=pdata(x)N(0,t2I),其中 ⨂ \bigotimes 表示卷积操作, π ( x ) \pi(\mathbf{x}) π(x) 服从高斯分布 π ( x ) ∼ N ( 0 , T 2 I ) \pi(\mathbf{x})\sim\mathcal{N}(0,T^2\mathbf{I}) π(x)N(0,T2I)。在 EDM 的形式下,PF ODE 为:
d x t d t = − t s ϕ ( x t , t ) \frac{d\mathbf{x}_t}{dt}=-t\mathbf{s_\phi}(\mathbf{x}_t,t) dtdxt=tsϕ(xt,t)
我们将上式称为 empirical PF ODE。在采样时,我们首先采样最后一个时刻的 x ^ T ∼ π = N ( 0 , T 2 I ) \hat{\mathbf{x}}_T\sim\pi=\mathcal{N}(0,T^2\mathbf{I}) x^Tπ=N(0,T2I) 来对 empirical PF ODE 进行初始化,然后使用数值 ODE 求解器(如 Euler、Heun 等)按照反向时间来对其进行求解。会得到各时刻解组成的解轨迹 { x ^ t } t ∈ [ 0 , T ] \{\hat{\mathbf{x}}_t\}_{t\in[0,T]} {x^t}t[0,T]。解轨迹的最后一步 x ^ 0 \hat{\mathbf{x}}_0 x^0 就可以认为是大致服从数据分布 p data ( x ) p_\text{data}(\mathbf{x}) pdata(x) 的样本。为了提高数值稳定性,一个常用的技巧是在时间到达一个很小的值 t = ϵ t=\epsilon t=ϵ 处就停止(而不是达到 0),将 x ^ ϵ \hat{\mathbf{x}}_\epsilon x^ϵ 作为最终的生成结果。参考 EDM 中的经验,我们将像素值 rescale 到 [ − 1 , 1 ] [-1,1] [1,1] 之间,并设 T = 80 , ϵ = 0.002 T=80,\epsilon=0.002 T=80,ϵ=0.002​。

目前,扩散模型应用的主要瓶颈在于采样生成速度。可以看到,使用 ODE 求解器进行采样生成需要多次对得分模型进行计算,这样计算开销很大。现在已经有一些方法,从数值 ODE 求解器加速和蒸馏等方面加速扩散模型采样。然而,目前 ODE 求解器再怎么也需要至少 10 步的模型计算,而大多数蒸馏的方法收集扩散模型生图的大规模数据集。目前只有 Progressive Distillation 没有这个限制,这也是本文对比的主要方法。

Consistency Models

本文提出的 Consistency Model,从设计上就支持单步生成,同时也支持多步采样,使得我们能够在采样速度和生成质量上进行灵活权衡。CM 有两种训练模式:蒸馏模式和原生训练模式。在蒸馏模式中,CM 可以对一个预训练的扩散模型进行知识蒸馏,得到一个单步采样器,对比其他蒸馏加速采样的方法,生图质量大大提升,并且还可以进行 zeroshot 图像编辑。在原生训练模式中,不需要先有一个预训练的扩散模型,直接从头训练 CM,这样 CM 相当于是一类全新的生成模型。

以下,我们将从定义、参数化、采样、图像编辑三个方面介绍 Consistency Model,下一章介绍 Consistency Model 的两种训练模式。

定义

给定一个 PF ODE 的解轨迹 { x t } t ∈ [ ϵ , T ] \{\mathbf{x}_t\}_{t\in[\epsilon,T]} {xt}t[ϵ,T],我们定义一个一致性函数 f : ( x t , t ) → x ϵ f:(\mathbf{x}_t,t)\rightarrow \mathbf{x}_\epsilon f:(xt,t)xϵ。一致性函数有一个重要的性质:自一致性,即对于同一个 PF ODE 解轨迹中的任意的输入参数对 ( x t , t ) (\mathbf{x}_t,t) (xt,t),其输出是一致的。自一致性可表示为:对于所有的 t , t ′ ∈ [ ϵ , T ] t,t'\in[\epsilon,T] t,t[ϵ,T],都有 f ( x t , t ) = f ( x t ′ , t ′ ) f(\mathbf{x}_t,t)=f(\mathbf{x}_{t'},t') f(xt,t)=f(xt,t)。我们的 Consistency Model f θ f_\theta fθ 就是要通过学习出自一致性,来拟合这个一致性函数 f f f​。

在这里插入图片描述

参数化

满足自一致性的一致性函数,一定有一个边界条件: f ( x ϵ , ϵ ) = x ϵ f(\mathbf{x}_\epsilon,\epsilon)=\mathbf{x}_\epsilon f(xϵ,ϵ)=xϵ,即不管输入是什么, f ( ⋅ , ϵ ) f(\mathbf{\cdot,\epsilon}) f(,ϵ) 是一个恒等映射。所有的一致性函数都需要满足这个边界条件,这是 CM 成功训练的关键,也是 CM 在结构上的一种限制。对于基于神经网络的 CM,我们讨论两种满足边界条件的形式。假设我们有一个输入输出的维度相同的、无任何限制的神经网络 F θ ( x , t ) F_\theta(\mathbf{x},t) Fθ(x,t)

第一种满足边界条件的形式是直接将 CM 参数化为:
f θ ( x , t ) = { x t = ϵ F θ ( x , t ) t ∈ ( ϵ , T ) f_\theta(\mathbf{x},t)= \begin{cases} \mathbf{x}\quad &t=\epsilon \\ F_\theta(\mathbf{x},t)&t\in(\epsilon,T) \end{cases} fθ(x,t)={xFθ(x,t)t=ϵt(ϵ,T)
第二种参数化的形式是使用跳跃连接(skip connections):
f θ ( x , t ) = c skip ( t ) x + c out F θ ( x , t ) f_\theta(\mathbf{x},t)=c_\text{skip}(t)\mathbf{x}+c_\text{out}F_\theta(\mathbf{x},t) fθ(x,t)=cskip(t)x+coutFθ(x,t)
其中 c skip ( t ) c_\text{skip}(t) cskip(t) c out ( t ) c_\text{out}(t) cout(t) 都是可微函数,且有 c skip ( ϵ ) = 1 , c out ( ϵ ) = 0 c_\text{skip}(\epsilon)=1,\quad c_\text{out}(\epsilon)=0 cskip(ϵ)=1,cout(ϵ)=0。这样,也能满足边界条件。我们注意到,第二种 CM 的参数化形式与 EDM、eDiff-i 等扩散模型的形式非常像,有许多可以借鉴的模型结构。因此我们选择第二种参数化形式。

采样

当一致性模型 f θ ( ⋅ , ⋅ ) f_\theta(\cdot,\cdot) fθ(,) 训练完成之后,我们就可以进行采样生成样本。只需先从初始化分布中采样一个 x ^ T ∼ N ( 0 , T 2 I ) \hat{\mathbf{x}}_T\sim\mathcal{N}(0,T^2\mathbf{I}) x^TN(0,T2I) ,然后可以使用一致性模型直接得到生成结果 x ^ ϵ = f θ ( x ^ T , T ) \hat{\mathbf{x}}_\epsilon=f_\theta(\hat{\mathbf{x}}_T,T) x^ϵ=fθ(x^T,T)​。这仅需要一次 Consistency Model 的前向推理过程。同时,我们也可以选择执行多次推理,交替进行去噪和噪声注入,来实现多步生成。

具体来说,多步采样的过程如以下算法 1 所示。首先使用 CM 根据 x ^ T \hat{\mathbf{x}}_T x^T 一步生成,初始化一个预测的数据样本 x \mathbf{x} x,然后将这个 x \mathbf{x} x 加噪成中间步 τ n \tau_n τn 的噪声图,然后再用 CM 对这一步的噪声图进行一步生成,得到新的预测数据样本 x \mathbf{x} x。然后不断重复交替进行加噪、去噪的过程共 N N N 次,得到最终 N N N​​ 步的生成结果。

在这里插入图片描述

通过多步生成,我们可以提高提升生成结果的质量。这样,我们就能灵活地在采样速度和样本质量之间进行权衡。

这里有一个示意图,将 CM 多步采样的过程清楚地展示了出来。图片来源:Consistency is All You Need

在这里插入图片描述

cr: wrong.wang/blog/20231111-consistency-is-all-you-need/

蒸馏训练 CM

我们首先介绍第一种训练 CM 的方法:对一个预训练的得分模型 s ϕ ( x , t ) \mathbf{s_\phi}(\mathbf{x},t) sϕ(x,t) 进行蒸馏。我们还是围绕公式 3 中的 empirical PF ODE 展开讨论,将时间区间 [ ϵ , T ] [\epsilon,T] [ϵ,T] 离散化为 N − 1 N-1 N1 个子区间,边界分别为 t 1 = ϵ < t 2 < ⋯ < t N = T t_1=\epsilon<t_2<\dots<t_N=T t1=ϵ<t2<<tN=T。在实现中,我们参考 EDM 中的公式来确定区间边界 t i = ( ϵ 1 / ρ + i − 1 N − 1 ( T 1 / ρ − ϵ 1 − ρ ) ) ρ t_i=(\epsilon^{1/\rho}+\frac{i-1}{N-1}(T^{1/\rho}-\epsilon^{1-\rho}))^\rho ti=(ϵ1/ρ+N1i1(T1/ρϵ1ρ))ρ,其中 ρ = 7 \rho=7 ρ=7。当 N N N 足够大时,我们就可以通过执行数值 ODE 求解器的一个离散步,来根据 x t n + 1 \mathbf{x}_{t_{n+1}} xtn+1 得到 x t n \mathbf{x}_{t_n} xtn 的精确估计值。我们记该估计值为 x ^ t n ϕ \hat{\mathbf{x}}_{t_n}^\phi x^tnϕ,具体定义为:
x ^ t n ϕ : = x t n + 1 + ( t n − t n + 1 ) Φ ( x t n + 1 , t n + 1 ; ϕ ) \hat{\mathbf{x}}_{t_n}^\phi:=\mathbf{x}_{t_{n+1}}+(t_n-t_{n+1})\Phi(\mathbf{x}_{t_{n+1}},t_{n+1};\phi) x^tnϕ:=xtn+1+(tntn+1)Φ(xtn+1,tn+1;ϕ)
其中 Φ ( …   ; ϕ ) \Phi(\dots;\phi) Φ(;ϕ) 表示对 empirical PF ODE 执行一步 ODE 求解器的更新方程。比如说当使用 Euler 求解器时,有 Φ ( x , t ; ϕ ) = − t s ϕ ( x , t ) \Phi(\mathbf{x},t;\phi)=-t\mathbf{s_\phi}(\mathbf{x},t) Φ(x,t;ϕ)=tsϕ(x,t),对应于以下更新公式:
x ^ t n ϕ : = x t n + 1 − ( t n − t n + 1 ) t n + 1 s ϕ ( x t n + 1 , t n + 1 ) \hat{\mathbf{x}}_{t_n}^\phi:=\mathbf{x}_{t_{n+1}}-(t_n-t_{n+1})t_{n+1}\mathbf{s_\phi}(\mathbf{x}_{t_{n+1}},t_{n+1}) x^tnϕ:=xtn+1(tntn+1)tn+1sϕ(xtn+1,tn+1)
简单起见,我们这里先只考虑单步的 ODE 求解器。这里其实就相当于是离散扩散模型根据 x t + 1 \mathbf{x}_{t+1} xt+1 进行单步去噪得到上一时间步 x t \mathbf{x}_{t} xt 的过程。

根据公式 2 PF ODE 与公式 1 SDE 之间的联系,我们可以通过先采样 x ∼ p data \mathbf{x}\sim p_\text{data} xpdata,然后向 x \mathbf{x} x 中添加高斯噪声的方式,来采样出 ODE 轨迹。具体来说,给定一个数据样本 x \mathbf{x} x,根据 SDE 的转移密度 N ( x , t n + 1 2 I ) \mathcal{N}(\mathbf{x},t^2_{n+1}\mathbf{I}) N(x,tn+12I) 采样出 x t n + 1 \mathbf{x}_{t_{n+1}} xtn+1,然后使用 ODE 求解器的一步更新公式,计算出前一时刻的 x ^ t n ϕ \hat{\mathbf{x}}_{t_n}^\phi x^tnϕ,从而我们就能得到相邻时刻的数据点的 pair 对 ( x ^ t n ϕ , x t n + 1 ) (\hat{\mathbf{x}}^\phi_{t_n},\mathbf{x}_{t_{n+1}}) (x^tnϕ,xtn+1)。然后,我们就可以以最小化这个相邻时间步 pair 对一致性输出的差异为训练目标,训练 CM。具体来说,通过蒸馏的方式训练 CM 的 consistency distillation 损失定义为:
L C D N ( θ , θ − ; ϕ ) : = E [ λ ( t n ) d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x ^ t n ϕ , t n ) ) ] \mathcal{L}_{CD}^N(\theta,\theta^-;\phi):=\mathbb{E}[\lambda(t_n)d(f_\theta(\mathbf{x}_{t_{n+1}},t_{n+1}),f_{\theta^-}(\hat{\mathbf{x}}^\phi_{t_n},t_n))] LCDN(θ,θ;ϕ):=E[λ(tn)d(fθ(xtn+1,tn+1),fθ(x^tnϕ,tn))]
这里的期望是关于 x ∼ p data , n ∼ U [ 1 , N − 1 ] , x t n + 1 ∼ N ( x ; t n + 1 2 I ) \mathbf{x}\sim p_\text{data},n\sim\mathcal{U}[1,N-1],\mathbf{x}_{t_{n+1}}\sim\mathcal{N}(\mathbf{x};t^2_{n+1}\mathbf{I}) xpdata,nU[1,N1],xtn+1N(x;tn+12I)。其中 λ ( ⋅ ) ∈ R + \lambda(\cdot)\in\mathbb{R}^+ λ()R+ 是正值加权函数,一般取恒等加权函数 λ ( t n ) = 1 \lambda(t_n)=1 λ(tn)=1 即可。, d ( ⋅ , ⋅ ) d(\cdot,\cdot) d(,) 是某种距离度量,文中实验了 L1、L2 和 LPIPS 三种距离。使用随机梯度下降优化 CM 的参数 θ \theta θ,并使用 EMA 来更新参数 θ − \theta^- θ,EMA 更新公式为:

θ − ← stopgrad ( μ θ − + ( 1 − μ ) θ ) \theta^-\leftarrow\text{stopgrad}(\mu\theta^-+(1-\mu)\theta) θstopgrad(μθ+(1μ)θ)
其中 θ − \theta^- θ 是优化过程中滑动平均值的上一步, 0 ≤ μ < 1 0\le\mu<1 0μ<1

原生训练 CM

原生训练 CM 时,我们不需要预训练的扩散模型 s ϕ ( x , t ) \mathbf{s_\phi}(\mathbf{x},t) sϕ(x,t) 来估计分布得分 ∇ log ⁡ p t ( x t ) \nabla\log p_t(\mathbf{x}_t) logpt(xt) 了。此时,我们可以根据 x \mathbf{x} x x t \mathbf{x}_t xt 来估计得分:
∇ log ⁡ p t ( x t ) = − E [ x t − x t 2 ∣ x t ] \nabla\log p_t(\mathbf{x}_t)=-\mathbb{E}[\frac{\mathbf{x}_t-\mathbf{x}}{t^2}|\mathbf{x}_t] logpt(xt)=E[t2xtxxt]
其中 x ∼ p data \mathbf{x}\sim p_\text{data} xpdata x t ∼ N ( x ; t 2 I ) \mathbf{x}_t\sim\mathcal{N}(\mathbf{x};t^2\mathbf{I}) xtN(x;t2I)。当我们使用 Euler ODE 求解器,在 N → ∞ N\rightarrow\infty N 时,这个无偏估计足以替代上述蒸馏训练 CM 时的预训练扩散模型,原文定理 2 给出了证明。

consistency training 的目标函数定义为:
L ( θ , θ − ) : = E [ λ ( t n ) d ( f θ ( x + t n + 1 z , t n + 1 ) , f θ − ( x + t n z , t n ) ) ] \mathcal{L}(\theta,\theta^-):=\mathbb{E}[\lambda(t_n)d(f_\theta(\mathbf{x}+t_{n+1}\mathbf{z},t_{n+1}),f_{\theta^-}(\mathbf{x}+t_n\mathbf{z},t_n))] L(θ,θ):=E[λ(tn)d(fθ(x+tn+1z,tn+1),fθ(x+tnz,tn))]
注意这里的损失 L ( θ , θ − ) \mathcal{L}(\theta,\theta^-) L(θ,θ) 仅与 CM 参数 θ , θ − \theta,\theta^- θ,θ 有关,与预训练扩散模型参数 ϕ \phi ϕ 无关。

总结

本文首先回顾了连续时间 SDE/ODE 扩散模型相关的基础知识,并选定了 EDM 的形式化。然后介绍 Consistency Models,考虑一个一致性函数 f f f,其可以从任意时刻一步映射到一致的结果 x ϵ \mathbf{x}_\epsilon xϵ。训练一个神经网络 f θ f_\theta fθ 来拟合这个 f f f,即为 CM。CM 可以从纯噪声一步采样得到生成样本,也可以多步采样来均衡生成质量和采样速度。CM 的训练模式有蒸馏和原生训练两种,核心都是构建相邻时刻数据点 pair 对,然后最小化 CM f θ f_\theta fθ 对二者输出值的差异,生成一致性的结果。区别在于,蒸馏训练模式使用预训练的扩散模型来得到相邻时刻的 pair 对,而原生训练模式则直接通过扩散模型前向加噪公式来进行无偏估计。

;