IR-SDE

Image Restoration with Mean-Reverting Stochastic Differential Equations

背景

以往的方法的反向过程会初始化一个较高方差的噪声,可能会导致ground-truth高质量图像的恢复效果不好,虽然可以获得较高的感知分数,但是往往会在基于像素或结构的退化标准方面表现不太满意。

贡献

为一般的IR问题提出一种SDE方法。mean-recerting SDE关键思想:将高质量的图像转换为退化图像,作为具有固定高斯噪声的均值状态,然后去估计reverse-time SDE,不需要有任何task-specific的先验知识。提出了最大似然目标学习优化的reverse轨迹。

方法

前向过程

dx=θt(μx)dt+σtdwdx=\theta_t(\mu-x)dt+\sigma_tdw

μ\mu 为状态均值,θt,σt\theta_t,\sigma_t 是与时间相关的正参数,分别表征了均值恢复和随机波动的速度。μ\mux(0)x(0) 分别为退化的LQ图像和ground-truth的HQ图像对。由于μ\mu 依赖于x(0)x(0)x(0)x(0) 与布朗运动无关,因此该SDE在伊藤形式上依然有效。
σt2/θt=2λ2\sigma_t^2/\theta_t=2\lambda^2λ2\lambda^2 为固定的方差。给定任意的 x(s),s<tx(s),s<t ,SDE的解为:

x(t)=μ+(x(s)μ)eθˉs:t+stσzeθˉz:tdw(z)x(t)=\mu+(x(s)-\mu)e^{-\bar\theta_{s:t}}+\int_s^t\sigma_ze^{-\bar\theta_{z:t}}dw(z)

θˉs:t:=stθzdz\bar\theta_{s:t}:=\int_s^t\theta_zdz 已知,则中间的转移步 p(x(t)x(s))=N(x(t)ms:t(x(s)),vs:t)p(x(t)|x(s))=\mathcal N(x(t)|m_{s:t}(x(s)),v_{s:t})

ms:t(x(s)):=μ+(x(s)μ)eθˉs:t,vs:t:=stσz2e2θˉz:tdz=λ2(1e2θˉs:t)m_{s:t}(x(s)):=\mu+(x(s)-\mu)e^{-\bar\theta_{s:t}},\\ v_{s:t}:=\int_s^t\sigma_z^2e^{-2\bar\theta_{z:t}}dz\\ =\lambda^2(1-e^{-2\bar\theta_{s:t}})

证明:

需求解的SDE:

dx=θt(μx)dt+σtdw\mathrm dx=\theta_t(\mu-x)\mathrm dt+\sigma_t \mathrm dw

为简便,使用 θˉt\bar\theta_t 替换 θˉ0:t\bar\theta_{0:t} ,定义一个代理可微函数(方便求解):

ψ(x,t)=xeθˉt\psi(x,t)=x\mathrm e^{\bar\theta_t}

根据伊藤公式:

dψ(x,t)=ψt(x,t)dt+ψx(x,t)f(x,t)dt+122ψx2(x,t)g(t)2dt+ψx(x,t)g(t)dwt\mathrm d\psi(x,t)=\frac{\partial\psi}{\partial t}(x,t)\mathrm dt+\frac{\partial\psi}{\partial x}(x,t)\mathbf f(x,t)\mathrm dt\\ +\frac12\frac{\partial^2\psi}{\partial x^2}(x,t)g(t)^2\mathrm dt\\ +\frac{\partial\psi}{\partial x}(x,t)g(t)\mathrm {dw_t}

将式 (4)(4) 代入 (6)(6) 得:

dψ(x,t)=μθteθˉtdt+σteθˉtdwt\mathrm d\psi(x,t)=\mu\theta_t\mathrm e^{\bar\theta_t}\mathrm dt+\sigma_t\mathrm e^{\bar\theta_t}\mathrm{dw}_t

积分得:

ψ(x,t)ψ(x,s)=stμθzeθˉzdz+stσzeθˉzdw(z)\psi(x,t)-\psi(x,s)=\int_s^t\mu\theta_z\mathrm e^{\bar\theta_z}\mathrm dz+\int_s^t\sigma_z\mathrm e^{\bar\theta_z}\mathrm{dw}(z)

在这里,stσzeθˉzdw(z)N(0,stσz2e2θˉzdz)\int_s^t\sigma_z\mathrm e^{\bar\theta_z}\mathrm{dw}(z)\sim\mathcal N(0,\int_s^t\sigma^2_z\mathrm e^{2\bar\theta_z}dz)

证明:

Var(stHzdW(z))Var(i=1NHti(WtiWti1))=i=1NVar(Hti(WtiWti1))=i=1NHti2Var(WtiWti1)=i=1NHti2(titi1)Var(\int_s^tH_zdW(z))\approx Var(\sum_{i=1}^NH_{t_i}(W_{t_i}-W_{t_{i-1}}))\\ =\sum_{i=1}^NVar(H_{t_i}(W_{t_i}-W_{t_{i-1}}))\\ =\sum_{i=1}^NH_{t_i}^2Var(W_{t_i}-W_{t_{i-1}})\\ =\sum_{i=1}^NH_{t_i}^2(t_i-t_{i-1})

limmaxΔti0i=1NHti2(titi1)=stHz2dz\lim_{max \Delta t_i\to0}\sum_{i=1}^NH_{t_i}^2(t_i-t_{i-1})=\int_s^tH_z^2dz

由于 dθˉt=d0tθzdz=θt\mathrm d\bar\theta_t=\mathrm d\int_0^t\theta_z\mathrm dz=\theta_t ,于是:

x(t)eθˉtx(s)eθˉs=μ(eθˉteθˉs)+stσzeθˉzdw(z)x(t)\mathrm e^{\bar\theta_t}-x(s)\mathrm e^{\bar\theta_s}=\mu(\mathrm e^{\bar\theta_t}-\mathrm e^{\bar\theta_s})+\int_s^t\sigma_z\mathrm e^{\bar\theta_z}\mathrm{dw}(z)

两边同时除以 eθˉt\mathrm e^{\bar\theta_t} 得:

x(t)=μ+(x(s)μ)eθˉs:t+stσzeθˉz:tdwzx(t)=\mu+(x(s)-\mu)\mathrm e^{-\bar\theta_{s:t}}+\int_s^t\sigma_z\mathrm e^{-\bar\theta_{z:t}}\mathrm {dw}_z

stσz2e2θˉz:tdz=σt22θte0σs22θse2θˉs:t=λ2(1e2θˉs:t)\int_s^t\sigma_z^2\mathrm e^{-2\bar\theta_{z:t}}\mathrm dz=\frac{\sigma_t^2}{2\theta_t}e^0-\frac{\sigma^2_s}{2\theta_s}\mathrm e^{-2\bar\theta_{s:t}}=\lambda^2(1-\mathrm e^{-2\bar\theta_{s:t}})

证毕

tt\to\infty 时,mtm_t 趋于LQ图像 μ\muvtv_t 趋于 λ2\lambda^2 ,即SDE将HQ图像扩散到LQ图像加一张固定高斯噪声。

反向过程

dx=[θt(μx)σt2xlogpt(x)]dt+σtdw^\mathrm dx=[\theta_t(\mu-x)-\sigma^2_t\nabla_x\log p_t(x)]dt+\sigma_t\mathrm d\hat w

每一步的ground-truth score:

xlogpt(xx(0))=x(t)mt(x)vt\nabla_x\log p_t(x|x(0))=-\frac{x(t)-m_t(x)}{v_t}

重新参数化 x(t)=mt(x)+vtϵt,ϵn(0,I)x(t)=m_t(x)+\sqrt{v_t}\epsilon_t,\epsilon\sim\mathcal n(0,I) ,于是:

xlogpt(xx(0))=ϵtvt\nabla_x\log p_t(x|x(0))=-\frac{\epsilon_t}{\sqrt v_t}

训练噪声估计网络 ϵ~ϕ(x(t),μ,t)\tilde\epsilon_\phi(x(t),\mu,t) :

Lγ(ϕ):=i=1TγiE[ϵ~ϕ(x(t),μ,t)ϵi]L_\gamma(\phi):=\sum_{i=1}^{T}\gamma_i\mathbb E[\left\|\tilde\epsilon_\phi(x(t),\mu,t)-\epsilon_i\right\|]

上式虽然比较简单,但是往往应用到复杂退化场景时会不稳定。原因:试图学习给定时间的瞬时噪声。

解决方法: 提出最大似然目标尝试寻找给定 x0x_0x1:Tx_{1:T} 的最优轨迹(而不是更准确的分数函数)

最大似然目标:

p(x1:Tx0)=p(xTx0)i=2Tp(xi1xi,x0)p(x_{1:T}|x_0)=p(x_T|x_0)\prod_{i=2}^Tp(x_{i-1}|x_i,x_0)

于是最佳反向状态为最小化负的对数似然:

xi1=arg minxi1[logp(xi1xi,x0)]x_{i-1}^*=\underset{x_{i-1}}{\mathrm{arg\ min}}[-\log p(x_{i-1}|x_i,x_0)]

θi:=i1iθtdt\theta'_i:=\int_{i-1}^i\theta_t\mathrm dt ,通过求解可得最优轨迹:

xi1=1e2θˉi11e2θˉieθi(xiμ)+1e2θi1eθˉi(x0μ)+μx_{i-1}^*=\frac{1-\mathrm e^{-2\bar\theta_{i-1}}}{1-\mathrm e^{-2\bar\theta_i}}\mathrm e^{-\theta'_i}(x_i-\mu)\\ +\frac{1-\mathrm e^{-2\theta_i'}}{1-\mathrm e^{\bar\theta_i}}(x_0-\mu)+\mu

证明:

根据贝叶斯规则

logp(xi1xi,x0)=logp(xixi1,x0)p(xi1x0)p(xix0)logp(xixi1,x0)logp(xi1x0)-\log p(x_{i-1}|x_i,x_0)=-\log\frac{p(x_i|x_{i-1},x_0)p(x_{i-1}|x_0)}{p(x_i|x_0)}\\ \propto -\log p(x_i|x_{i-1},x_0)-\log p(x_{i-1}|x_0)

令负的对数似然的梯度为0:

证毕,二阶导数为正,因此 xi1x_{i-1}^* 确实为最优点

训练噪声估计网络 ϵ~ϕ(xi,μ,i)\tilde\epsilon_\phi(x_i,\mu,i) :

Jγ(ϕ):=i=1TγiE[xi(dxi)ϵ~ϕxi1]J_\gamma(\phi):=\sum_{i=1}^T\gamma_i\mathbb E[\left\|x_i-(dx_i)_{\tilde\epsilon_\phi}-x_{i-1}^*\right\|]

(dxi)ϵ~ϕ(dx_i)_{\tilde\epsilon_\phi} 为式 (14)(14) 的反向SDE,其中分数由 ϵ~ϕ\tilde\epsilon_\phi 估计,由于扩散项的期望为0,于是只需要考虑漂移项。

实验

Deraining

Debluring

Gaussian Image Denoising

提出Denoising-SDE:将干净图像设为均值 μ=x0\mu=x_0 ,可以以更少的时间步进行去噪。因此可以将任意噪声图作为中间状态,将其直接反转为干净图像。

提出Denoising-ODE:由于只有高斯噪声,因此可以直接使用ODE,没有额外的噪声直接去噪:

dx=[θt(μx)σt2xlogpt(x)]dt\mathrm dx=[\theta_t(\mu-x)-\sigma^2_t\nabla_x\log p_t(x)]dt

pnoise(xix0)=N(mi(x0),vi)p_{noise}(x_i|x_0)=\mathcal N(m_i(x_0),v_i)

这里的 mi(x0)=x0,vi=λ2(1e2θˉi)m_i(x_0)=x_0,v_i=\lambda^2(1-e^{-2\bar\theta_i}) ,最优轨迹变为:

xi1=1e2θˉi11e2θˉieθi(xix0)+x0x_{i-1}^*=\frac{1-\mathrm e^{-2\bar\theta_{i-1}}}{1-\mathrm e^{-2\bar\theta_i}}\mathrm e^{-\theta'_i}(x_i-x_0)+x_0

xi=mi+viϵt,ϵN(0,I)x_i=m_i+\sqrt{v_i}\epsilon_t,\epsilon\sim\mathcal N(0,I)

将式 (26)(26) 代入式 (14)(14) 得Denoising-SDE:

对应的ODE:

dx=12σt2e2θˉtxtlogpt(x)dtdx=-\frac12\sigma_t^2\mathrm e^{-2\bar\theta_t}\nabla_{x_t}\log p_t(x)dt

一旦知道真实噪声水平 σreal\sigma_{real} 则可以直接推出相应的 tt^* 使得 pnoise(xix0)p_{noise}(x_i|x_0) 的方差刚好为噪声水平:

σreal2=vt=λ2(1e2θˉt)\sigma_{real}^2=v_t=\lambda^2(1-\mathrm e^{-2\bar\theta_t})

于是:

t=arg mintθˉt12Δtlog(1σreal2λ2)t^*=\underset{t}{arg\ min}\left\|\bar\theta_t-\frac1{2\Delta t}\log (1-\frac{\sigma_{real}^2}{\lambda^2 })\right\|

局限与未来方向

vtv_t 中的指数项会使得在最后几步的方差变得过于平滑,使得相邻状态很相似,使得学习变得困难。解决方案:改进 θ\theta 时间表。

采样步骤的优化,以减少计算成本

参考

[1] Z. Luo, F. K. Gustafsson, Z. Zhao, J. Sjölund, and T. B. Schön, “Image Restoration with Mean-Reverting Stochastic Differential Equations.” arXiv, May 31, 2023. Accessed: Nov. 13, 2023. [Online]. Available: http://arxiv.org/abs/2301.11699


IR-SDE
https://summerwrain.github.io/2023/11/13/IR-SDE/
作者
SummerRain
发布于
2023年11月13日
许可协议