Diffusion Model Summary (with Code)

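While recently revisiting my earlier post Understanding Diffusion Models, I noticed a problem: it goes into too much detail, and every step is a mathematical derivation. This post condenses it into a summary so the model can be understood at a glance.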

Forward Process

$$q(x_t|x_{t-1})=\mathcal N(x_t;\sqrt{\alpha_t}x_{t-1},(1-\alpha_t)I)$$

where $x_0\sim q(x_0)$, $\beta_1,\dots,\beta_T$ is the variance schedule, and $\alpha_t=1-\beta_t$.

$$q(x_t|x_0)=\mathcal N(x_t;\sqrt{\bar\alpha_t}x_0,(1-\bar\alpha_t)I)$$

where $\bar\alpha_t=\prod_{s=1}^t\alpha_s$.
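As a sanity check of the closed form, the sketch below iterates the one-step transition and confirms that after $T$ steps the coefficient on $x_0$ is $\sqrt{\bar\alpha_T}$ and the accumulated noise variance is $1-\bar\alpha_T$. It tracks only these two scalars instead of sampling, using $m_t=\sqrt{\alpha_t}m_{t-1}$ and $v_t=\alpha_t v_{t-1}+(1-\alpha_t)$. The linear schedule from $10^{-4}$ to $0.02$ over $T=1000$ steps is an assumption taken from the DDPM paper's setup, and the function names are hypothetical.

```python
import math

# Illustrative sketch; all function names are hypothetical.

def linear_betas(T: int, beta_1: float = 1e-4, beta_T: float = 0.02) -> list[float]:
    """Assumed linear variance schedule beta_1..beta_T."""
    return [beta_1 + (beta_T - beta_1) * i / (T - 1) for i in range(T)]

def alpha_bars(betas: list[float]) -> list[float]:
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def iterate_forward(betas: list[float]) -> tuple[float, float]:
    """Apply q(x_t | x_{t-1}) step by step, tracking x_t = m * x_0 + noise with Var(noise) = v.

    Each step maps x -> sqrt(alpha_t) * x + sqrt(1 - alpha_t) * fresh_noise, so
    m <- sqrt(alpha_t) * m and v <- alpha_t * v + (1 - alpha_t).
    """
    m, v = 1.0, 0.0
    for b in betas:
        a = 1.0 - b
        m = math.sqrt(a) * m
        v = a * v + (1.0 - a)
    return m, v

betas = linear_betas(1000)
abar = alpha_bars(betas)
m, v = iterate_forward(betas)
print(m, math.sqrt(abar[-1]))  # coefficient on x_0 after T steps vs. closed form
print(v, 1 - abar[-1])         # noise variance after T steps vs. closed form
```

The recursion for $v_t$ reproduces the closed form by induction: $\alpha_t(1-\bar\alpha_{t-1})+(1-\alpha_t)=1-\bar\alpha_t$.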

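Code for sampling $x_t\sim q(x_t|x_0)$:

```python
from typing import Optional

import torch


def gather(consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Indexing helper assumed by the methods below: pick consts[t] for each
    # example and reshape to (batch, 1, 1, 1) so it broadcasts over image tensors
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)


# Method excerpts from the diffusion class; self.alpha_bar holds alpha_bar_t for every t
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor):
    # Mean and variance of q(x_t | x_0)
    mean = gather(self.alpha_bar, t) ** 0.5 * x0
    var = 1 - gather(self.alpha_bar, t)
    return mean, var

def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
    # Reparameterized sample: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    if eps is None:
        eps = torch.randn_like(x0)
    mean, var = self.q_xt_x0(x0, t)
    return mean + (var ** 0.5) * eps
```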
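The same sampling step can be checked without PyTorch. The hypothetical NumPy function `q_sample_np` below mirrors `q_sample` for a single scalar $\bar\alpha_t$; drawing many samples from one fixed $x_0$ shows the empirical mean and variance approaching $\sqrt{\bar\alpha_t}x_0$ and $1-\bar\alpha_t$. The value $\bar\alpha_t=0.3$, $x_0=2$, and the sample size are arbitrary choices for illustration.

```python
import numpy as np

# Illustrative sketch; q_sample_np is a hypothetical name, not from the referenced code.

def q_sample_np(x0: np.ndarray, alpha_bar_t: float, rng: np.random.Generator) -> np.ndarray:
    """Draw x_t ~ N(sqrt(alpha_bar_t) * x0, (1 - alpha_bar_t) I) via reparameterization."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps

rng = np.random.default_rng(0)
x0 = np.full(200_000, 2.0)   # many copies of the same scalar x_0 = 2
alpha_bar_t = 0.3            # an arbitrary alpha_bar_t in (0, 1)
xt = q_sample_np(x0, alpha_bar_t, rng)
print(xt.mean(), np.sqrt(alpha_bar_t) * 2.0)  # empirical vs. closed-form mean
print(xt.var(), 1.0 - alpha_bar_t)            # empirical vs. closed-form variance
```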

Reverse Process

$$q(x_{t-1}|x_t,x_0)=\mathcal N(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))$$

$$\mu_q(x_t,x_0)=\frac1{\sqrt{\alpha_t}}x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\epsilon_0$$

where $\epsilon_0\sim\mathcal N(0,I)$ is the noise that produced $x_t$ from $x_0$, i.e. $x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon_0$.
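The $\epsilon_0$ form of $\mu_q$ follows from substituting $x_0=(x_t-\sqrt{1-\bar\alpha_t}\epsilon_0)/\sqrt{\bar\alpha_t}$ into the posterior mean written in terms of $x_0$ and $x_t$, $\tilde\mu_t=\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0+\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$. The small numeric check below, with hypothetical helper names and arbitrary scalar values, shows that both forms give the same number.

```python
import math

# Illustrative sketch; helper names and the numeric values are hypothetical.

def posterior_mean_x0_form(xt, x0, alpha_t, abar_t, abar_prev):
    """mu_q written in terms of x_0 and x_t."""
    beta_t = 1.0 - alpha_t
    c0 = math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)
    ct = math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)
    return c0 * x0 + ct * xt

def posterior_mean_eps_form(xt, eps, alpha_t, abar_t):
    """mu_q written in terms of x_t and the noise eps_0."""
    coef = (1.0 - alpha_t) / (math.sqrt(1.0 - abar_t) * math.sqrt(alpha_t))
    return xt / math.sqrt(alpha_t) - coef * eps

alpha_t, abar_prev = 0.98, 0.6   # arbitrary values in (0, 1)
abar_t = alpha_t * abar_prev     # alpha_bar_t = alpha_t * alpha_bar_{t-1}
x0, eps = 1.5, -0.7              # a fixed data point and noise draw
xt = math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps
a = posterior_mean_x0_form(xt, x0, alpha_t, abar_t, abar_prev)
b = posterior_mean_eps_form(xt, eps, alpha_t, abar_t)
print(a, b)  # the two forms agree up to floating-point rounding
```

The coefficient on $x_t$ collapses to $\frac{\beta_t+\alpha_t(1-\bar\alpha_{t-1})}{(1-\bar\alpha_t)\sqrt{\alpha_t}}=\frac1{\sqrt{\alpha_t}}$, which is why only $x_t$ and $\epsilon_0$ remain.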

$$p_\theta(x_{t-1}|x_t)=\mathcal N(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t))$$

$$\mu_\theta(x_t,t)=\frac1{\sqrt{\alpha_t}}x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\hat\epsilon_\theta(x_t,t),\quad \Sigma_q(t)=\Sigma_\theta(t)=\sigma_t^2I$$
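The formulas above do not fix a value for $\sigma_t^2$. The DDPM paper considers two choices: $\sigma_t^2=\beta_t$, and the true posterior variance $\tilde\beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$. The sketch below computes both for an assumed linear schedule; `sigma2_options` is a hypothetical name, and either resulting list could serve as the per-step variance table that `p_sample` reads as `sigma2`. Note that $\tilde\beta_1=0$ because $\bar\alpha_0=1$.

```python
# Illustrative sketch; sigma2_options is a hypothetical name and the schedule is assumed.

def sigma2_options(betas):
    """Return (sigma_t^2 = beta_t, sigma_t^2 = beta_tilde_t) for t = 1..T."""
    abar_prev, large, small = 1.0, [], []   # alpha_bar_0 = 1
    for b in betas:
        abar_t = abar_prev * (1.0 - b)
        large.append(b)
        # True posterior variance of q(x_{t-1} | x_t, x_0)
        small.append((1.0 - abar_prev) / (1.0 - abar_t) * b)
        abar_prev = abar_t
    return large, small

betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]  # linear schedule, T = 1000
large, small = sigma2_options(betas)
print(small[0], large[0])     # beta_tilde_1 is 0, beta_1 is 1e-4
print(small[-1], large[-1])   # nearly equal at t = T, where alpha_bar is close to 0
```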

Loss

$$\begin{aligned}L_{simple}(\theta)&=\mathbb E_{t,x_0,\epsilon}\left[\left\|\epsilon-\hat\epsilon_\theta(x_t,t)\right\|^2\right]\\&=\mathbb E_{t,x_0,\epsilon}\left[\left\|\epsilon-\hat\epsilon_\theta(\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon,t)\right\|^2\right]\end{aligned}$$

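Code implementation:

```python
from typing import Optional

import torch
import torch.nn.functional as F


# Method excerpt from the diffusion class
def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
    batch_size = x0.shape[0]
    # Sample a random timestep t for each example in the batch
    t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
    if noise is None:
        noise = torch.randn_like(x0)
    # Jump straight to x_t using the closed-form q(x_t | x_0)
    xt = self.q_sample(x0, t, eps=noise)
    # Predict the noise and regress it onto the true noise (L_simple)
    eps_theta = self.eps_model(xt, t)
    return F.mse_loss(noise, eps_theta)
```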
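To see what $L_{simple}$ measures, the NumPy sketch below estimates it by Monte Carlo for two hypothetical predictors on random toy data. One always outputs zero, so its loss should be near $\mathbb E[\epsilon^2]=1$ per coordinate. The other is an "oracle" that cheats by reading the true $x_0$ to invert $x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon$, so its loss should be essentially zero. The linear schedule, the data shape, and all names are assumptions for illustration.

```python
import numpy as np

# Illustrative sketch; simple_loss, zero_model and oracle_model are hypothetical names.
rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule
alpha_bar = np.cumprod(1.0 - betas)

def simple_loss(eps_model, x0: np.ndarray) -> float:
    """Monte Carlo estimate of L_simple for a batch x0 of shape (batch, dim)."""
    t = rng.integers(0, T, size=x0.shape[0])   # one random step per example
    eps = rng.standard_normal(x0.shape)
    ab = alpha_bar[t][:, None]                 # broadcast over the feature axis
    xt = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps
    return float(np.mean((eps - eps_model(xt, t)) ** 2))

x0 = rng.standard_normal((4096, 8))

def zero_model(xt, t):
    # Ignores its input, so the loss is E[eps^2] = 1 per coordinate
    return np.zeros_like(xt)

def oracle_model(xt, t):
    # Cheats by reading the true x0 (only possible in this toy check)
    ab = alpha_bar[t][:, None]
    return (xt - np.sqrt(ab) * x0) / np.sqrt(1.0 - ab)

print(simple_loss(zero_model, x0))    # close to 1
print(simple_loss(oracle_model, x0))  # close to 0 (floating-point error only)
```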

Sampling

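Code for sampling $x_{t-1}\sim p_\theta(x_{t-1}|x_t)$:

```python
import torch


# Method excerpt from the diffusion class.
# gather(consts, t) is the helper that indexes consts at t and reshapes for broadcasting.
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
    # Predicted noise eps_theta(x_t, t)
    eps_theta = self.eps_model(xt, t)
    alpha_bar = gather(self.alpha_bar, t)
    alpha = gather(self.alpha, t)
    # (1 - alpha_t) / sqrt(1 - alpha_bar_t)
    eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
    # mu_theta(x_t, t) = (x_t - eps_coef * eps_theta) / sqrt(alpha_t)
    mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
    var = gather(self.sigma2, t)  # sigma_t^2
    eps = torch.randn(xt.shape, device=xt.device)
    return mean + (var ** 0.5) * eps
```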
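`p_sample` performs a single reverse step. A full sampler starts from $x_T\sim\mathcal N(0,I)$ and applies the step for $t=T,\dots,1$; DDPM's sampling algorithm adds no noise on the final step. The NumPy sketch below assumes $\sigma_t^2=\beta_t$ and a linear schedule, and uses 0-based indices ($i=t-1$). To check it without training, it uses a toy dataset made of a single point $c$: every $x_t$ then comes from $x_0=c$, so the exact noise is $(x_t-\sqrt{\bar\alpha_t}c)/\sqrt{1-\bar\alpha_t}$, and with that predictor the noise-free last step returns $c$ up to rounding. `ddpm_sample` and `point_mass_eps` are hypothetical names.

```python
import numpy as np

# Illustrative sketch; ddpm_sample and point_mass_eps are hypothetical names.

def ddpm_sample(eps_model, shape, betas: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Ancestral sampling from x_T ~ N(0, I) down to x_0 with sigma_t^2 = beta_t (assumed)."""
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    x = rng.standard_normal(shape)  # x_T
    for i in reversed(range(len(betas))):  # index i stands for step t = i + 1
        eps_theta = eps_model(x, i)
        eps_coef = (1.0 - alphas[i]) / np.sqrt(1.0 - alpha_bar[i])
        mean = (x - eps_coef * eps_theta) / np.sqrt(alphas[i])
        if i > 0:
            x = mean + np.sqrt(betas[i]) * rng.standard_normal(shape)
        else:
            x = mean  # final step: return the mean without adding noise
    return x

betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)
c = np.array([0.5, -1.0, 2.0])  # toy "dataset": a single point

def point_mass_eps(xt, i):
    # Exact noise prediction when every training example equals c
    return (xt - np.sqrt(alpha_bar[i]) * c) / np.sqrt(1.0 - alpha_bar[i])

sample = ddpm_sample(point_mass_eps, c.shape, betas, np.random.default_rng(0))
print(sample)  # recovers c up to rounding
```

The final step returns $c$ exactly because $\bar\alpha_1=\alpha_1$, so $\frac{1-\alpha_1}{\sqrt{1-\bar\alpha_1}}\hat\epsilon=x_1-\sqrt{\alpha_1}c$ and the mean becomes $c$ regardless of $x_1$.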

References

[1] J. Ho, A. Jain, and P. Abbeel, “Denoising Diffusion Probabilistic Models.” arXiv, Dec. 16, 2020. doi: 10.48550/arXiv.2006.11239.

Code source: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/diffusion/ddpm/


Diffusion Model Summary
https://summerwrain.github.io/2023/12/08/扩散模型总结/
Author: SummerRain
Published: December 8, 2023
License