Diffusion GAN

Tackling the Generative Learning Trilemma with Denoising Diffusion GANs

背景

生成式学习三困境:高质量采样,多样性,快速采样。

传统扩散每一步的高斯假设只在去噪步长非常小的时候成立,因此反向过程需要很多步,而反向过程中步长较大(更少的去噪步数)时,需要使用非高斯多峰分布来建模去噪分布。

贡献

  1. 将扩散模型缓慢采样的原因归结为每一步去噪分布的高斯假设(pθ(xt1xt)p_\theta(x_{t-1}|x_t) 被建模为高斯分布),并提出使用复杂、多峰的去噪分布。
  2. 提出 Denoising Diffusion GANs,使用 conditional GANs参数化扩散模型的反向过程。
  3. 与现存的扩散模型相比,diffusion GAN实现了几个数量级的加速。

方法

Multimodel Denoising Distributions for Large Denoising Steps

高斯假设在何时正确?

  1. 当步长 βt\beta_t 无穷小时,q(xt1xt)q(xtxt1)q(xt1)q(x_{t-1}|x_t)\propto q(x_t|x_{t-1})q(x_{t-1})q(xtxt1)q(x_t|x_{t-1}) 主导,于是反向过程与前向过程拥有相同的函数形式。由于 q(xtxt1)q(x_t|x_{t-1}) 为高斯,则 q(xt1xt)q(x_{t-1}|x_t) 也为高斯。因此普通的 Diffusion Models 往往需要上千步来保证 βt\beta_t 足够小。
  2. 如果 q(x0)q(x_0) 为高斯,则 q(xt1xt)q(x_{t-1}|x_t) 也为高斯分布:LSGM提出使用VAE encoder 使得 q(x0)q(x_0)q(xt)q(x_t) 逼近于高斯分布,但是将数据转换为高斯分布本身就很难,并且VAE encoder 也无法完美地解决它,因此LSGM在复杂数据集上也需要10~100步。

当两个条件都不满足时(去噪步长很大,且数据分布不是高斯分布),此时无法保证去噪分布为高斯分布。

当去噪步长越大时,真实的去噪分布变得越复杂和多峰。

Modeling Denoising Distributions With Conditional GANs

目标:减少反向过程中的去噪步数 T

方法:使用 conditional GANs 近似真实的去噪分布 q(xt1xt)q(x_{t-1}|x_t)

q(xtxt1)=N(xt;1βtxt1,βtI)(1)q(x_t|x_{t-1})=\mathcal N(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t I) \tag{1}

前向过程与普通扩散过程 (1)(1) 类似,而主要的假设为:T 被假设的非常小(T8T\le 8),且每个扩散步有更大的 βt\beta_t .

训练过程: 每一步去噪使用对抗损失最小化 DadvD_{adv} 来匹配 conditional GAN 生成器 pθ(xt1xt)p_\theta(x_{t-1}|x_t)q(xt1xt)q(x_{t-1}|x_t)

minθt1Eq(xt)[Dadv(q(xt1xt)pθ(xt1xt))](2)\underset{\theta}{min} \sum_{t\ge 1}\mathbb E_{q(x_t)}[D_{adv}(q(x_{t-1}|x_t)\parallel p_\theta(x_{t-1}|x_t))]\tag{2}

DadvD_{adv} 可以是 Wasserstein 距离、Jenson-Shannon 散度、f 散度。作者这里使用一个特殊的 f 散度—— softened reverse KL。记一个与时间相关的判别器 Dϕ(xt1,xt,t):RN×RN×N[0,1]D_\phi(x_{t-1},x_t,t):\mathbb R^N\times \mathbb R^N\times \mathbb N\to[0,1] ,参数为 ϕ\phi 。输入为 N 维的 xt1x_{t-1}xtx_t ,判断 xt1x_{t-1} 是否为 xtx_t 的去噪版本。

判别器训练

minϕt1Eq(xt)[Eq(xt1xt)[log(Dϕ(xt1,xt,t))]+Epθ(xt1xt)[log(1Dϕ(xt1,xt,t))]](3)\underset{\phi}{min}\sum_{t\ge1}\mathbb E_{q(x_t)}\big[\mathbb E_{q(x_{t-1}|x_t)}[-\log(D_\phi(x_{t-1},x_t,t))]+\mathbb E_{p_\theta(x_{t-1}|x_t)}[-\log(1-D_\phi(x_{t-1},x_t,t))]\big]\tag{3}

由于第一项期望中的 q(xt1xt)q(x_{t-1}|x_t) 未知,但是

q(xt,xt1)=dx0q(x0)q(xt1,xtx0)=dx0q(x0)q(xt1x0)q(xtxt1,x0)=dx0q(x0)q(xt1x0)q(xtxt1)(4)q(x_t,x_{t-1})=\int \mathrm dx_0q(x_0)q(x_{t-1},x_t|x_0)\\ =\int\mathrm dx_0q(x_0)q(x_{t-1}|x_0)q(x_t|x_{t-1},x_0)\\ =\int\mathrm dx_0q(x_0)q(x_{t-1}|x_0)q(x_t|x_{t-1})\tag{4}

于是 (3)(3) 中第一项的期望可写为:

Eq(xt)q(xt1xt)[log(Dϕ(xt1,xt,t))]=Eq(x0)q(xt1x0)q(xtxt1)[log(Dϕ(xt1,xt,t))](5)\mathbb E_{q(x_t)q(x_{t-1}|x_t)}[-\log(D_\phi(x_{t-1},x_t,t))]=\mathbb E_{q(x_0)q(x_{t-1}|x_0)q(x_t|x_{t-1})}[-\log(D_\phi(x_{t-1},x_t,t))]\tag{5}

生成器训练

maxθt1Eq(xt)Epθ(xt1xt)[log(Dϕ(xt1,xt,t))](6)\underset{\theta}{max}\sum_{t\ge1}\mathbb E_{q(x_t)}\mathbb E_{p_\theta(x_{t-1}|x_t)}[\log(D_\phi(x_{t-1},x_t,t))]\tag{6}

参数化隐式去噪模型

pθ(xt1xt):=pθ(x0xt)q(xt1xt,x0)dx0=p(z)q(xt1xt,x0=Gθ(xt,z,t))dz(7)p_\theta(x_{t-1}|x_t):=\int p_\theta(x_0|x_t)q(x_{t-1}|x_t,x_0)\mathrm dx_0\\ =\int p(z)q(x_{t-1}|x_t,x_0=G_\theta(x_t,z,t))\mathrm dz\tag{7}

这里 zp(z):=N(z;0,I)z\sim p(z):=\mathcal N(z;0,I)pθ(x0xt)p_\theta(x_0|x_t) 是一个隐式分布,通过GAN的生成器Gθ(xt,z,t):RN×RL×RRNG_\theta(x_t,z,t):\mathbb R^N\times \mathbb R^L\times\mathbb R\to \mathbb R^N (给定 xtx_tzz 输出 x0x_0)施加。

DDPM中的 x0x_0 通过与 xtx_t 的确定性映射来预测,这里的 x0x_0 通过生成器来预测。由于 xtx_t 有不同水平的扰动,因此直接预测 xt1x_{t-1} 较为困难,但是作者这里是直接预测没有扰动的 x0x_0 ,因此可以避免上述问题。

相对于一步生成器的优势:众所周知,GAN的训练不稳定,且容易模式崩溃,原因可能是直接一步从复杂分布生成样本较为困难,以及判别器只看到正确样本时的过拟合问题。但是本文中的模型将生成过程分解为一系列条件去噪扩散过程,由于 xtx_t 的强条件作用,因此每一步的建模比较简单,此外,扩散过程使得数据分布更平滑,判别器不容易过拟合。

实验

参考

[1] Z. Xiao, K. Kreis, and A. Vahdat, “Tackling the Generative Learning Trilemma with Denoising Diffusion GANs.” arXiv, Apr. 04, 2022. doi: 10.48550/arXiv.2112.07804.


Diffusion GAN
https://summerwrain.github.io/2023/12/12/Diffusion GAN/
作者
SummerRain
发布于
2023年12月12日
许可协议