
License: arXiv.org perpetual non-exclusive license
arXiv:2305.11650v6 [stat.ML] 19 Mar 2024

Moment Matching Denoising Gibbs Sampling

Mingtian Zhang, Alex Hawkins-Hooker, Brooks Paige, David Barber
Centre for Artificial Intelligence, University College London

This work was partially done during an internship at Huawei Noah's Ark Lab.
Abstract

Energy-Based Models (EBMs) offer a versatile framework for modeling complex data distributions. However, training and sampling from EBMs continue to pose significant challenges. The widely-used Denoising Score Matching (DSM) method Vincent (2011) for scalable EBM training suffers from inconsistency issues, causing the energy model to learn a ‘noisy’ data distribution. In this work, we propose an efficient sampling framework, (pseudo)-Gibbs sampling with moment matching, which enables effective sampling from the underlying clean model when given a ‘noisy’ model that has been well-trained via DSM. We explore the benefits of our approach compared to related methods and demonstrate how to scale the method to high-dimensional datasets.

1 Energy-Based Models

Energy-Based Models (EBMs) have attracted considerable attention in the generative modeling literature Ngiam et al. (2011); Xie et al. (2016); Du and Mordatch (2019); Song and Ermon (2019). EBMs are unnormalized probabilistic models: they specify a probability density only up to an unknown normalizing constant. For continuous data $x$, the density of an EBM is $q_\theta(x) = \exp(-f_\theta(x))/Z(\theta)$, where the energy $f_\theta(x)$ is a nonlinear function with parameters $\theta$, and $Z(\theta) = \int \exp(-f_\theta(x))\,\mathrm{d}x$ is the normalizing constant, which does not depend on $x$. The energy parameterization is highly flexible and can model a wide range of probability distributions. However, because the normalizing constant is unknown, training these models is challenging. We begin with a brief introduction to estimating $\theta$ in EBMs and refer the reader to Song and Kingma (2021) for a detailed overview of training techniques for continuous EBMs.
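To make the role of $Z(\theta)$ concrete, here is a minimal one-dimensional sketch that evaluates an unnormalized density and computes $\log Z(\theta)$ by numerical quadrature. The double-well energy and all function names are illustrative assumptions, not the paper's model. Quadrature works here only because $x$ is one-dimensional: a grid with $n$ points per axis needs $n^d$ energy evaluations in $d$ dimensions, which is why $Z(\theta)$ is intractable for neural-network energies on high-dimensional data.

```python
import math


def double_well_energy(x: float, theta: float = 1.0) -> float:
    """Hypothetical energy f_theta(x) = theta * (x^2 - 1)^2, with minima at x = -1 and x = 1."""
    return theta * (x * x - 1.0) ** 2


def log_partition(energy, lo: float = -10.0, hi: float = 10.0, n: int = 20_001) -> float:
    """log Z = log of the integral of exp(-energy(x)) over [lo, hi], via the composite trapezoidal rule.

    Feasible only in low dimension: a grid with n points per axis needs n**d evaluations.
    """
    h = (hi - lo) / (n - 1)
    total = 0.5 * (math.exp(-energy(lo)) + math.exp(-energy(hi)))
    for i in range(1, n - 1):
        total += math.exp(-energy(lo + i * h))
    return math.log(total * h)


def log_density(x: float, energy, log_z: float) -> float:
    """Normalized log density: log q_theta(x) = -f_theta(x) - log Z(theta)."""
    return -energy(x) - log_z


# Sanity check against a case with a known constant: f(x) = x^2 / 2 gives Z = sqrt(2 pi).
gauss_log_z = log_partition(lambda x: 0.5 * x * x)
log_z = log_partition(double_well_energy)
```

The energy minima at $x = \pm 1$ become the modes of $q_\theta$; only the difference of energies matters for relative probabilities, while $\log Z$ is needed for absolute ones.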

Likelihood-based training: A classic method to learn $\theta$ is to minimize the KL divergence between the data distribution $p_d$ and the model density $q_\theta$, which is defined as

$$\mathrm{KL}(p_d \,\|\, q_\theta) \doteq -\int p_d(x) \log q_\theta(x)\,\mathrm{d}x \doteq \int p_d(x) f_\theta(x)\,\mathrm{d}x + \log Z(\theta), \tag{1}$$

where $\doteq$ denotes equality up to a constant that is independent of $\theta$. The integral over $p_d(x)$ can be approximated by Monte Carlo using the training set $\mathcal{X}_{\mathrm{train}} = \{x_1, \dots, x_N\} \sim p_d(x)$; in this case, minimizing Equation 1 is equivalent to maximum likelihood estimation (MLE) Bishop and Nasrabadi (2006). However, for EBMs, minimizing the KL divergence requires estimating $Z(\theta)$, which is intractable when $f_\theta(x)$ is a nonlinear function defined by a neural network. Various methods have been proposed to alleviate this intractability, using techniques such as Markov chain Monte Carlo (MCMC) Hinton (2002); Nijkamp et al. (2019); Du et al. (2020); Gao et al. (2018) or adversarial training Kim and Bengio (2016); Zhai et al. (2016); Bose et al. (2018).
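The Monte Carlo form of Equation 1 is the negative average log-likelihood, $\frac{1}{N}\sum_i f_\theta(x_i) + \log Z(\theta)$. The sketch below evaluates it for a hypothetical model $f_\theta(x) = (x - \theta)^2/2$, chosen because its $\log Z = \tfrac{1}{2}\log 2\pi$ is known and independent of $\theta$; for a neural-network energy this term would be unavailable. A grid search stands in for gradient-based training, and the minimizer should be the sample mean, i.e. the MLE.

```python
import math
import random

# For the hypothetical energy f_theta(x) = (x - theta)^2 / 2, Z = sqrt(2 pi) for every theta.
LOG_Z = 0.5 * math.log(2.0 * math.pi)


def energy(x: float, theta: float) -> float:
    return 0.5 * (x - theta) ** 2


def kl_objective(data, theta: float) -> float:
    """Monte Carlo estimate of Eq. (1) up to a theta-independent constant:
    mean_i f_theta(x_i) + log Z(theta), i.e. the negative average log-likelihood."""
    return sum(energy(x, theta) for x in data) / len(data) + LOG_Z


rng = random.Random(0)
data = [rng.gauss(1.5, 1.0) for _ in range(2_000)]

grid = [i / 100 for i in range(301)]  # candidate theta values in [0, 3]
theta_hat = min(grid, key=lambda t: kl_objective(data, t))
sample_mean = sum(data) / len(data)
```

Since the objective is quadratic in $\theta$, `theta_hat` is the grid point closest to `sample_mean`.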

Score-based training: Alternatively, Hyvärinen (2005) proposes to learn $\theta$ by minimizing the Fisher divergence, which is defined as

$$\mathrm{FD}(p_d \,\|\, q_\theta) = \frac{1}{2}\int p_d(x) \left\| s_{p_d}(x) - s_{q_\theta}(x) \right\|_2^2 \mathrm{d}x, \tag{2}$$

where $s_p(x) \equiv \nabla_x \log p(x)$ denotes the score function of a distribution $p$. Under certain regularity conditions, the Fisher divergence is equivalent, up to a constant, to the score-matching (SM) objective Hyvärinen (2005),

$$\mathrm{FD}(p_d \,\|\, q_\theta) \doteq \frac{1}{2}\int p_d(x) \left( \left\| s_{q_\theta}(x) \right\|_2^2 + 2 \operatorname{Tr}\!\left( \nabla_x s_{q_\theta}(x) \right) \right) \mathrm{d}x, \tag{3}$$

which does not require estimating the intractable $Z(\theta)$. However, this objective requires the trace of the Hessian $\nabla_x s_{q_\theta}(x) = -\nabla_x^2 f_\theta(x)$ at every gradient step during training, which is computationally expensive and scales poorly to high-dimensional data unless it is approximated Song et al. (2020). In this paper, we focus on another training method, denoising score matching Vincent (2011), which overcomes the tractability and scalability issues above and is introduced in the next section.
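As an illustration of Equation 3, the sketch below fits the variance of a hypothetical one-dimensional Gaussian model $q(x) = \mathcal{N}(m, v)$ by score matching. Its score is $-(x - m)/v$ and the 1D "Hessian trace" is $-1/v$, so the objective never touches the normalizing constant. The model, data, and grid search are assumptions for illustration; with $m$ fixed to the sample mean, the minimizer should be the sample variance.

```python
import random


def sm_objective(data, m: float, v: float) -> float:
    """Eq. (3) up to a constant for q(x) = N(m, v).

    Score s(x) = -(x - m) / v; its derivative (the 1D Hessian trace) is -1 / v.
    The normalizing constant of q never appears.
    """
    total = 0.0
    for x in data:
        s = -(x - m) / v
        total += s * s + 2.0 * (-1.0 / v)
    return 0.5 * total / len(data)


rng = random.Random(0)
data = [rng.gauss(0.0, 2.0) for _ in range(2_000)]
m_hat = sum(data) / len(data)
sample_var = sum((x - m_hat) ** 2 for x in data) / len(data)

grid = [1.0 + 0.02 * i for i in range(351)]  # candidate variances v in [1.0, 8.0]
v_hat = min(grid, key=lambda v: sm_objective(data, m_hat, v))
```

In one dimension the derivative term is a scalar; in $d$ dimensions it is the trace of a $d \times d$ Hessian, which is the cost the paragraph above refers to.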

1.1 Denoising Score Matching

For the target data density $p_d(x)$, a noise distribution $p(\tilde{x}|x) = \mathcal{N}(x, \sigma^2 I)$ is introduced to construct the noised data distribution $\tilde{p}_d(\tilde{x}) = \int p_d(x)\, p(\tilde{x}|x)\,\mathrm{d}x$. Denoising score matching (DSM) Vincent (2011) minimizes the Fisher divergence between the noised data distribution $\tilde{p}_d(\tilde{x})$ and an energy-based model $\tilde{q}_\theta(\tilde{x}) = \exp(-f_\theta(\tilde{x}))/Z(\theta)$, with

$$\begin{aligned}
\mathrm{FD}(\tilde{p}_d \,\|\, \tilde{q}_\theta) &= \frac{1}{2}\int \tilde{p}_d(\tilde{x}) \left\| s_{\tilde{p}_d}(\tilde{x}) - s_{\tilde{q}_\theta}(\tilde{x}) \right\|_2^2 \mathrm{d}\tilde{x} \\
&\doteq \frac{1}{2}\iint p(\tilde{x}|x)\, p_d(x) \left\| \nabla_{\tilde{x}} \log p(\tilde{x}|x) - s_{\tilde{q}_\theta}(\tilde{x}) \right\|_2^2 \mathrm{d}\tilde{x}\,\mathrm{d}x \\
&\doteq \frac{1}{2}\iint p(\tilde{x}|x)\, p_d(x) \left\| \frac{\tilde{x} - x}{\sigma^2} + s_{\tilde{q}_\theta}(\tilde{x}) \right\|_2^2 \mathrm{d}\tilde{x}\,\mathrm{d}x,
\end{aligned} \tag{4}$$

where the last line uses the closed-form score of the Gaussian noise distribution, $\nabla_{\tilde{x}} \log p(\tilde{x}|x) = -(\tilde{x} - x)/\sigma^2$.
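The closed-form conditional score used in the last line of Equation 4 can be checked numerically. This sketch compares $-(\tilde{x} - x)/\sigma^2$ with a central-difference gradient of $\log \mathcal{N}(\tilde{x}; x, \sigma^2 I)$; the helper names and the test vectors are hypothetical.

```python
import math


def log_noise_density(x_tilde, x, sigma):
    """log N(x_tilde; x, sigma^2 I) for vectors stored as Python lists."""
    d = len(x)
    sq = sum((a - b) ** 2 for a, b in zip(x_tilde, x))
    return -0.5 * sq / sigma ** 2 - 0.5 * d * math.log(2.0 * math.pi * sigma ** 2)


def noise_score(x_tilde, x, sigma):
    """Closed form: grad_{x_tilde} log p(x_tilde | x) = -(x_tilde - x) / sigma^2."""
    return [-(a - b) / sigma ** 2 for a, b in zip(x_tilde, x)]


def central_difference_grad(fn, point, eps=1e-5):
    """Numerical gradient of a scalar function of a list, by central differences."""
    grad = []
    for i in range(len(point)):
        up, down = list(point), list(point)
        up[i] += eps
        down[i] -= eps
        grad.append((fn(up) - fn(down)) / (2.0 * eps))
    return grad


x = [0.3, -1.2, 2.0]
x_tilde = [0.5, -0.7, 1.1]
sigma = 0.4
numeric = central_difference_grad(lambda z: log_noise_density(z, x, sigma), x_tilde)
exact = noise_score(x_tilde, x, sigma)
```

Because the log density is quadratic in $\tilde{x}$, central differences agree with the closed form up to floating-point rounding.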

Compared to the KL or SM objectives, the DSM objective is scalable and remains well-defined when the data distribution is singular¹ Zhang et al. (2020), and it can alleviate the blindness problem of score matching Song and Ermon (2019); Wenliang and Kanagawa (2020); Zhang et al. (2022b). On the other hand, the DSM objective has a notable disadvantage: for a fixed $\sigma > 0$, it is not a consistent objective for learning the underlying data distribution $p_d$, since $\mathrm{FD}(\tilde{p}_d \,\|\, \tilde{q}_\theta) = 0 \implies \tilde{q}_\theta = \tilde{p}_d \neq p_d$. A common solution is to anneal $\sigma \rightarrow 0$ during training. However, Equation 4 is not defined at $\sigma = 0$: the term $(\tilde{x} - x)/\sigma^2$ becomes unbounded as $\sigma \rightarrow 0$, so the objective is ill-defined in this limit. Annealing $\sigma$ also increases the variance of the training gradients Song and Kingma (2021); Wang et al. (2020), which makes optimization challenging in practice.

¹ A singular distribution is not absolutely continuous (a.c.) with respect to the Lebesgue measure and thus does not admit a density function (Tao, 2011, p.172). A typical example is a data distribution supported on a lower-dimensional manifold. In this case, the KL divergence is ill-defined and cannot be used to train the model. With DSM, the distribution after Gaussian convolution is always a.c., so DSM remains a valid training objective; see Zhang et al. (2020) or Arjovsky et al. (2017) for a detailed introduction.
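The inconsistency of DSM at a fixed $\sigma$ can be seen in a toy experiment. The sketch below estimates the last line of Equation 4 by Monte Carlo for a hypothetical noisy model $\tilde{q} = \mathcal{N}(0, v)$, whose score is $-\tilde{x}/v$, with clean data $p_d = \mathcal{N}(0, 1)$. The data, model family, and grid search are assumptions for illustration. The DSM minimizer should land near the noisy variance $1 + \sigma^2$ rather than the clean variance $1$.

```python
import random


def dsm_loss(pairs, sigma: float, v: float) -> float:
    """Monte Carlo estimate of the last line of Eq. (4) for the hypothetical noisy model
    q_tilde = N(0, v), whose score is s(x_tilde) = -x_tilde / v."""
    total = 0.0
    for x, x_tilde in pairs:
        residual = (x_tilde - x) / sigma ** 2 + (-x_tilde / v)
        total += residual * residual
    return 0.5 * total / len(pairs)


rng = random.Random(0)
sigma = 1.0
clean = [rng.gauss(0.0, 1.0) for _ in range(5_000)]       # x ~ p_d = N(0, 1)
pairs = [(x, x + rng.gauss(0.0, sigma)) for x in clean]   # x_tilde ~ N(x, sigma^2)

grid = [0.5 + 0.02 * i for i in range(126)]               # candidate variances v in [0.5, 3.0]
v_hat = min(grid, key=lambda v: dsm_loss(pairs, sigma, v))
```

In expectation the loss is $\tfrac{1}{2}\left(1/\sigma^2 - 2/v + (1 + \sigma^2)/v^2\right)$, minimized at $v = 1 + \sigma^2$: a perfectly trained DSM model matches $\tilde{p}_d$, not $p_d$.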

To overcome these challenges, we propose an alternative data generation scheme: we use DSM with a fixed $\sigma > 0$ to train a 'noisy' energy model and then construct a sampler that targets the underlying 'clean' model. Specifically, our contributions are as follows:

  • We demonstrate that for an EBM that learns a noisy data distribution, there exists a unique underlying clean model which recovers the true data distribution.

  • We introduce a pseudo-Gibbs sampling scheme incorporating an analytical moment-matching approximation of the denoising distribution. This allows us to sample from the underlying clean model without requiring additional training.

  • We illustrate how to scale our method to high-dimensional data and demonstrate the generation of high-quality images using only a single level of fixed noise. Furthermore, we apply the proposed method in a multi-level noise setting that closely resembles a diffusion model.

2 Clean Model Identification

For a fixed $\sigma > 0$, DSM can only learn a 'noisy' data distribution $\tilde{q}_\theta(\tilde{x})$, even in the ideal case where the Fisher divergence is exactly minimized, since $\mathrm{FD}(\tilde{p}_d \,\|\, \tilde{q}_{\theta^*}) = 0 \implies \tilde{q}_{\theta^*} = \tilde{p}_d \neq p_d$. In this case, the following theorem shows that the noisy model implicitly defines a 'clean' model that recovers the true data distribution.

Theorem 2.1 (Existence of the underlying clean model for optimal $\tilde{q}_\theta(\tilde{x})$).

When the Fisher divergence is zero, i.e. $\mathrm{FD}(\tilde{p}_d \,\|\, \tilde{q}_\theta) = 0$ and hence $\tilde{q}_\theta = \tilde{p}_d$, there exists a unique underlying clean model $q(x)$ such that $\tilde{q}_\theta(\tilde{x}) = \int q(x)\, p(\tilde{x}|x)\,\mathrm{d}x$, and $q(x) = p_d(x)$.

See Appendix A.1 for the proof. This theorem shows that although the EBM is trained on noisy data, it implicitly contains a model that recovers the true data distribution. Therefore, instead of annealing the noise $\sigma \rightarrow 0$ to recover the true data distribution, in the next section we show how to sample directly from the implicitly defined clean model given the noisy energy-based model.

We emphasize that the 'perfect fit' assumption, i.e. achieving $\mathrm{FD}(\tilde{p}_d \,\|\, \tilde{q}_\theta) = 0$, may not hold for complex data distributions $p_d$ or underpowered EBMs. We therefore provide general sufficient conditions for the existence of the clean model for an imperfect EBM in Appendix A.2.

2.1 Gibbs Sampling with Gaussian Moment Matching

Given a well-trained noisy energy-based model $\tilde{q}_\theta(\tilde{x}) = \tilde{p}_d(\tilde{x})$, the clean model has the form

$$q(x) = \int p(x|\tilde{x})\, \tilde{q}_\theta(\tilde{x})\,\mathrm{d}\tilde{x}, \tag{5}$$

where the denoising distribution is $p(x|\tilde{x}) \propto p(\tilde{x}|x)\, q(x)$. Since the noise distribution $p(\tilde{x}|x) = \mathcal{N}(x, \sigma^2 I)$ is known, a Gibbs sampler targeting the underlying clean model can be constructed if we also know the denoising distribution $p(x|\tilde{x})$:

$$\tilde{x}_{k-1} \sim p(\tilde{x} \mid x = x_{k-1}), \quad x_k \sim p(x \mid \tilde{x} = \tilde{x}_{k-1}), \tag{6}$$

where the initial sample $x_0 \sim p_0(x)$ can be drawn from a standard Gaussian $p_0(x) = \mathcal{N}(0, I)$. However, since the denoising distribution $p(x|\tilde{x})$ is usually intractable for complex $p_d$, we propose an analytical Gaussian moment-matching approximation of $p(x|\tilde{x})$.
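To illustrate the exact Gibbs scheme of Equation 6 before any approximation, the sketch below uses a hypothetical one-dimensional $p_d$ for which $p(x|\tilde{x})$ is available in closed form: a two-component Gaussian mixture, whose posterior under Gaussian noise is again a Gaussian mixture. The mixture parameters and function names are assumptions. The chain alternates the noising and denoising steps, and its $x$-samples should match the mixture's mean $0$ and variance $0.25 + 1.5^2 = 2.5$.

```python
import math
import random

# Hypothetical clean data distribution p_d: 0.5 N(-1.5, 0.5^2) + 0.5 N(1.5, 0.5^2).
WEIGHTS = [0.5, 0.5]
MEANS = [-1.5, 1.5]
VARS = [0.25, 0.25]


def gauss_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def sample_denoising(x_tilde, sigma, rng):
    """Exact draw from p(x | x_tilde) ∝ p_d(x) N(x_tilde; x, sigma^2).
    With a Gaussian-mixture p_d this posterior is again a Gaussian mixture."""
    s2 = sigma ** 2
    # Component responsibilities: w_k ∝ pi_k N(x_tilde; mu_k, v_k + sigma^2).
    w = [pi * gauss_pdf(x_tilde, m, v + s2) for pi, m, v in zip(WEIGHTS, MEANS, VARS)]
    k = rng.choices(range(len(w)), weights=w)[0]
    m, v = MEANS[k], VARS[k]
    post_mean = (s2 * m + v * x_tilde) / (s2 + v)
    post_var = v * s2 / (s2 + v)
    return rng.gauss(post_mean, math.sqrt(post_var))


def gibbs_chain(n_steps, sigma, seed=0):
    """Eq. (6): alternate x_tilde ~ N(x, sigma^2) and x ~ p(x | x_tilde), from x_0 ~ N(0, 1)."""
    rng = random.Random(seed)
    x = rng.gauss(0.0, 1.0)
    samples = []
    for _ in range(n_steps):
        x_tilde = rng.gauss(x, sigma)
        x = sample_denoising(x_tilde, sigma, rng)
        samples.append(x)
    return samples
```

For example, `gibbs_chain(40_000, sigma=1.0)[1_000:]` discards a short burn-in. For a realistic $p_d$ the mixture posterior above is unavailable, which is what motivates the moment-matching approximation.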

Denote the mean and covariance of $p(x|\tilde{x})$ as

$$\mu(\tilde{x}) = \langle x \rangle_{p(x|\tilde{x})}, \quad \Sigma(\tilde{x}) = \langle x x^\top \rangle_{p(x|\tilde{x})} - \mu(\tilde{x})\,\mu(\tilde{x})^\top. \tag{7}$$

The classic Gaussian moment matching method Minka (2013) uses the Gaussian approximation $p(x|\tilde{x}) \approx \mathcal{N}(\mu(\tilde{x}), \Sigma(\tilde{x}))$, which matches the first and second moments of $p(x|\tilde{x})$. When $\tilde{q}_\theta = \tilde{p}_d$, the mean of the denoising distribution has a well-known analytical form Song et al. (2021); Bao et al. (2022); Efron (2011); Robbins (1992)

$$\mu(\tilde{x}) = \tilde{x} + \sigma^2 s_{\tilde{q}_\theta}(\tilde{x}); \tag{8}$$

we include the derivation in Appendix A.3. Using this identity, we can rewrite Equation 4 as

$$\mathrm{FD}(\tilde{p}_d \,\|\, \tilde{q}_\theta) \doteq \frac{1}{2\sigma^4} \iint p(\tilde{x}|x)\, p_d(x) \left\| x - \mu(\tilde{x}) \right\|_2^2 \mathrm{d}\tilde{x}\,\mathrm{d}x, \tag{9}$$

which shows that the Fisher divergence depends on the model only through $\mu(\tilde{x})$. Since $\mathrm{FD}(\tilde{p}_d \,\|\, \tilde{q}_\theta) = 0 \Leftrightarrow q = p_d$ (Theorem 2.1), the function $\mu(\tilde{x})$ fully characterizes the distribution $q$. Therefore, $\mu(\tilde{x})$ and $p(\tilde{x}|x) = \mathcal{N}(x, \sigma^2 I)$ together provide sufficient information to determine $p(x|\tilde{x}) \propto q(x)\, p(\tilde{x}|x)$. As a consequence, the covariance function can also be derived analytically, as the following theorem shows.

Theorem 2.2 (Analytical Covariance Identity).

Given a clean model $q(x)$ such that $\int q(x)\, p(\tilde{x}|x)\,\mathrm{d}x = \tilde{q}_\theta(\tilde{x}) = \tilde{p}_d(\tilde{x})$ with $p(\tilde{x}|x) = \mathcal{N}(x, \sigma^2 I)$, the mean $\mu(\tilde{x})$ and covariance $\Sigma(\tilde{x})$ of $p(x|\tilde{x}) \propto q(x)\, p(\tilde{x}|x)$ satisfy

$$\Sigma(\tilde{x}) = \sigma^2 \nabla_{\tilde{x}} \mu(\tilde{x}) = \sigma^4 \nabla_{\tilde{x}}^2 \log \tilde{q}_\theta(\tilde{x}) + \sigma^2 I. \tag{10}$$

See Appendix A.3 for the proof. This analytical covariance identity can be seen as a high-dimensional generalization of the second-order Tweedie's formula Efron (2011); Robbins (1992). The analytical full-covariance moment-matching approximation can therefore be written as

$$p(x|\tilde{x}) \approx \mathcal{N}\!\left( \tilde{x} + \sigma^2 \nabla_{\tilde{x}} \log \tilde{q}_\theta(\tilde{x}),\; \sigma^4 \nabla_{\tilde{x}}^2 \log \tilde{q}_\theta(\tilde{x}) + \sigma^2 I \right). \tag{11}$$
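As a numerical sanity check of Equations 8 and 10, the sketch below compares the moment-matched mean and variance, computed only from derivatives of $\log \tilde{q}$, with the exact moments of $p(x|\tilde{x})$ obtained by brute-force quadrature in one dimension. Assumptions: the clean distribution is a hypothetical two-component Gaussian mixture, and the exact noisy density (the mixture convolved with $\mathcal{N}(0, \sigma^2)$) stands in for a perfectly trained $\tilde{q}_\theta$. The identities hold even though this posterior is not Gaussian.

```python
import math

# Hypothetical clean distribution p_d: 0.5 N(-1.5, 0.25) + 0.5 N(1.5, 0.25); noise level SIGMA.
WEIGHTS = [0.5, 0.5]
MEANS = [-1.5, 1.5]
VARS = [0.25, 0.25]
SIGMA = 1.0


def gauss_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def noisy_log_derivs(xt):
    """First and second derivatives of log q_tilde at xt, where
    q_tilde(xt) = sum_k pi_k N(xt; mu_k, v_k + SIGMA^2) is p_d convolved with the noise."""
    q = d1 = d2 = 0.0
    for pi, m, v in zip(WEIGHTS, MEANS, VARS):
        big_v = v + SIGMA ** 2
        n = pi * gauss_pdf(xt, m, big_v)
        z = (xt - m) / big_v
        q += n
        d1 += -z * n                          # derivative of the component density
        d2 += (z * z - 1.0 / big_v) * n       # second derivative of the component density
    score = d1 / q
    return score, d2 / q - score ** 2         # (log q)' and (log q)''


def moment_matched(xt):
    """Eqs. (8) and (10) in 1D: mean and variance of p(x | xt) from the noisy density alone."""
    score, hess = noisy_log_derivs(xt)
    s2 = SIGMA ** 2
    return xt + s2 * score, s2 * s2 * hess + s2


def posterior_moments_quadrature(xt, lo=-10.0, hi=10.0, n=10_001):
    """Reference: moments of p(x | xt) ∝ p_d(x) N(xt; x, SIGMA^2) via a uniform-grid sum
    (the grid spacing cancels in the ratios)."""
    h = (hi - lo) / (n - 1)
    z = m1 = m2 = 0.0
    for i in range(n):
        x = lo + i * h
        prior = sum(pi * gauss_pdf(x, m, v) for pi, m, v in zip(WEIGHTS, MEANS, VARS))
        w = prior * gauss_pdf(xt, x, SIGMA ** 2)
        z += w
        m1 += x * w
        m2 += x * x * w
    mean = m1 / z
    return mean, m2 / z - mean ** 2
```

Only the noisy density enters `moment_matched`, mirroring the setting where only the trained noisy EBM is available.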

We want to highlight that, since Gaussian moment matching is only an approximation of $p(x|\tilde{x})$, the sampling scheme in Equation 6 is a 'pseudo' Gibbs sampler unless the true $p(x|\tilde{x})$ is itself Gaussian (footnote 2), which is not the case for a general non-Gaussian $p_d$. However, since $\mu(\tilde{x})$ and $p(\tilde{x}|x)=\mathcal{N}(x,\sigma^{2}I)$ are already sufficient to specify $p(x|\tilde{x})$, it should be possible to derive expressions for higher-order moments that involve only $\mu(\tilde{x})$ and $\sigma$; we leave this to future work. To our knowledge, the $\tilde{x}$-conditioned full-covariance Gaussian moment matching approximation to $p(x|\tilde{x})$ has not been derived previously. In the next section, we briefly discuss the connections between our method and related approaches.

Footnote 2: For example, when $p_d(x)=\mathcal{N}(\mu_d,\sigma_d^{2})$, the true posterior $p(x|\tilde{x})\propto p_d(x)p(\tilde{x}|x)$ is Gaussian with mean $\mu(\tilde{x})=(\sigma^{2}\mu_d+\sigma_d^{2}\tilde{x})/(\sigma^{2}+\sigma_d^{2})$ and variance $\Sigma(\tilde{x})=\sigma_d^{2}\sigma^{2}/(\sigma^{2}+\sigma_d^{2})$, and one can verify that $\Sigma(\tilde{x})=\sigma^{2}\nabla_{\tilde{x}}\mu(\tilde{x})$.
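The Gaussian example in footnote 2 can be checked numerically. The sketch below uses arbitrarily chosen values for $\mu_d$, $\sigma_d$ and $\sigma$ (assumptions for illustration), differentiates the closed-form posterior mean by central finite differences, and compares $\sigma^2\,\mathrm{d}\mu/\mathrm{d}\tilde{x}$ with the closed-form posterior variance.

```python
import numpy as np

# Illustrative values (assumed) for the 1D Gaussian example in footnote 2.
mu_d, sigma_d, sigma = 0.7, 1.3, 0.5


def posterior_mean(xt):
    """mu(xt) = (sigma^2 mu_d + sigma_d^2 xt) / (sigma^2 + sigma_d^2)."""
    return (sigma**2 * mu_d + sigma_d**2 * xt) / (sigma**2 + sigma_d**2)


# Closed-form posterior variance from footnote 2 (the same for every xt).
posterior_var = sigma_d**2 * sigma**2 / (sigma**2 + sigma_d**2)

h = 1e-4  # central finite-difference step
for xt in (-2.0, 0.0, 3.5):
    dmu = (posterior_mean(xt + h) - posterior_mean(xt - h)) / (2 * h)
    # Covariance identity: Sigma(xt) = sigma^2 * d mu / d xt.
    print(xt, sigma**2 * dmu, posterior_var)
```

Because $\mu(\tilde{x})$ is linear here, the derivative and hence the covariance do not depend on $\tilde{x}$, which is the special case where an $\tilde{x}$-independent covariance is exact.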

2.2 Connection to Covariance Learning Approaches

Bengio et al. (2013) propose to approximate the true posterior $p(x|\tilde{x})\propto p_d(x)p(\tilde{x}|x)$ with a variational distribution $q_{\theta}(x|\tilde{x})$. The parameter $\theta$ is then learned by minimizing the joint KL divergence

$$\mathrm{KL}\left(p(\tilde{x}|x)p_d(x)\,\Vert\,q_{\theta}(x|\tilde{x})\tilde{p}_d(\tilde{x})\right)\doteq-\iint p_d(x)p(\tilde{x}|x)\log q_{\theta}(x|\tilde{x})\,\mathrm{d}\tilde{x}\,\mathrm{d}x, \tag{12}$$

where $\tilde{p}_d(\tilde{x})=\int p_d(x)p(\tilde{x}|x)\,\mathrm{d}x$ and $\doteq$ denotes equality up to an additive constant that does not depend on $\theta$. The joint KL divergence in Equation 12 encourages $q_{\theta}(x|\tilde{x})$ to match the moments of the true posterior $p(x|\tilde{x})$, and it upper-bounds the marginal KL divergence Zhang et al. (2019)

$$\mathrm{KL}\left(p(\tilde{x}|x)p_d(x)\,\Vert\,q_{\theta}(x|\tilde{x})\tilde{p}_d(\tilde{x})\right)\geq\mathrm{KL}\left(p_d(x)\,\Vert\,q_{\theta}(x)\right), \tag{13}$$

where the model is implicitly defined as the marginal of the joint, $q_{\theta}(x)=\int q_{\theta}(x|\tilde{x})\tilde{p}_d(\tilde{x})\,\mathrm{d}\tilde{x}$. When $q_{\theta}(x|\tilde{x})$ is a consistent estimator of $p(x|\tilde{x})$, the asymptotic distribution of the Gibbs chain converges to the true data distribution $p_d(x)$ Bengio et al. (2013). For continuous data, the variational distribution is chosen to be Gaussian, $q_{\theta}(x|\tilde{x})=\mathcal{N}(\mu_{\theta}(\tilde{x}),\Sigma_{\theta}(\tilde{x}))$, where the mean $\mu_{\theta}(\cdot)$ and the covariance $\Sigma_{\theta}(\cdot)$ are parameterized by neural networks. We note that the only difference between the KL and DSM objectives (Equation 9) is that the KL objective additionally learns the covariance.
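Up to a $\theta$-independent constant, Equation 12 is the expected negative log-likelihood of the clean $x$ under $q_\theta(x|\tilde{x})$, with pairs $(x,\tilde{x})$ drawn from $p_d(x)p(\tilde{x}|x)$. The sketch below estimates this objective by Monte Carlo for a diagonal Gaussian $q_\theta$; `mean_fn` and `log_std_fn` are hypothetical stand-ins for the two network heads. Assuming 1D Gaussian data for illustration, the closed-form posterior of footnote 2 should attain a lower objective than a variance that is too small or too large, which previews the optimality result of Theorem 2.3.

```python
import numpy as np


def joint_kl_objective(x, xt, mean_fn, log_std_fn):
    """Monte Carlo estimate of -E[log q(x | xt)] for q = N(mean_fn(xt), diag(exp(log_std_fn(xt)))^2).

    Equals the joint KL of Equation 12 up to an additive constant that does not depend on q.
    """
    mu = mean_fn(xt)
    log_std = log_std_fn(xt)
    z = (x - mu) / np.exp(log_std)
    nll = 0.5 * z**2 + log_std + 0.5 * np.log(2 * np.pi)
    return np.mean(np.sum(nll, axis=-1))


rng = np.random.default_rng(0)
mu_d, sigma_d, sigma = 0.7, 1.3, 0.5                 # assumed 1D data and noise levels
n = 200_000
x = mu_d + sigma_d * rng.standard_normal((n, 1))     # x ~ p_d
xt = x + sigma * rng.standard_normal((n, 1))         # xt ~ p(xt | x) = N(x, sigma^2)

post_mean = lambda t: (sigma**2 * mu_d + sigma_d**2 * t) / (sigma**2 + sigma_d**2)
post_std = np.sqrt(sigma_d**2 * sigma**2 / (sigma**2 + sigma_d**2))

for scale in (0.5, 1.0, 2.0):
    # Scale the optimal standard deviation up or down to see the objective increase.
    log_std_fn = lambda t, s=scale: np.full_like(t, np.log(s * post_std))
    print(scale, joint_kl_objective(x, xt, post_mean, log_std_fn))
```

In a real implementation both heads would be trained jointly by minimizing this quantity over mini-batches; the sketch only evaluates it for fixed functions.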
We now show that the optimal covariance under KL minimization is exactly the proposed analytical covariance.

Theorem 2.3 (Optimal Gaussian Approximation).

Let $p(\tilde{x}|x)=\mathcal{N}(x,\sigma^{2}I)$ and assume a Gaussian $q_{\theta}(x|\tilde{x})=\mathcal{N}(\mu_q(\tilde{x}),\Sigma_q(\tilde{x}))$. Then the optimal $q^{*}$ such that

$$q^{*}=\arg\min_{q}\mathrm{KL}\left(p(\tilde{x}|x)p_d(x)\,\Vert\,q(x|\tilde{x})\tilde{p}_d(\tilde{x})\right) \tag{14}$$

has mean and covariance of the form

$$\mu_q^{*}(\tilde{x})=\langle x\rangle_{p(x|\tilde{x})},\qquad\Sigma_q^{*}(\tilde{x})=\sigma^{2}\nabla_{\tilde{x}}\mu_q^{*}(\tilde{x}), \tag{15}$$

see Appendix A.4 for the proof. Therefore, once the optimal mean function $\mu_{\theta}(\tilde{x})=\mu_q^{*}(\tilde{x})$ has been learned, the optimal covariance $\Sigma(\tilde{x})$ can be derived analytically, which makes learning $\Sigma_{\theta}(\tilde{x})$ redundant. In addition to the training inefficiency caused by the extra parameters, an amortized covariance network may generalize poorly Zhang et al. (2022a). Moreover, the KL objective is not well-defined for data distributions that lie on a low-dimensional manifold, e.g. MNIST; see Section 4 for a detailed discussion. In this case, the learned $\Sigma_{\theta}(\tilde{x})$ may be degenerate, making the Gaussian density $q(x|\tilde{x})$ ill-defined Zhang et al. (2020) and impeding training; see Figure 5 for an example.
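Theorem 2.3 implies that once a mean function (denoiser) is available, the optimal covariance is its Jacobian scaled by $\sigma^2$, so no separate covariance network is needed. The 1D sketch below illustrates this for a non-Gaussian $p_d$, assumed here to be a Laplace(0, 1) density purely for illustration. The posterior mean is computed by applying Bayes' rule on a quadrature grid, its derivative is taken by central finite differences (automatic differentiation would be used with a real network), and the result is compared with the posterior variance computed directly on the same grid. Grid size, noise level and names are hypothetical choices.

```python
import numpy as np

sigma = 0.5                               # assumed noise std
grid = np.linspace(-12.0, 12.0, 24_001)   # quadrature grid over clean x
log_pd = -np.abs(grid) - np.log(2.0)      # assumed non-Gaussian p_d: Laplace(0, 1)


def posterior_weights(xt):
    """Normalized p(x | xt) on the grid via Bayes' rule: p(x | xt) ∝ p(xt | x) p_d(x)."""
    log_post = log_pd - 0.5 * (xt - grid) ** 2 / sigma**2
    w = np.exp(log_post - log_post.max())
    return w / w.sum()


def posterior_mean(xt):
    return np.sum(posterior_weights(xt) * grid)


def posterior_var(xt):
    w = posterior_weights(xt)
    m = np.sum(w * grid)
    return np.sum(w * (grid - m) ** 2)


def var_from_mean(xt, h=1e-3):
    """Theorem 2.3 / Equation 10: Sigma(xt) = sigma^2 * d mu / d xt (central difference)."""
    return sigma**2 * (posterior_mean(xt + h) - posterior_mean(xt - h)) / (2 * h)


for xt in (-2.0, 0.0, 0.4, 3.0):
    print(xt, var_from_mean(xt), posterior_var(xt))
```

For the discretized prior the identity holds exactly, so the two columns differ only by finite-difference error; unlike the Gaussian case, the variance changes with $\tilde{x}$.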

Meng et al. (2021) propose a higher-order score-matching loss to simultaneously learn the first-order score $\nabla_{\tilde{x}}\log\tilde{q}_{\theta}(\tilde{x})$ and the second-order score $\nabla^{2}_{\tilde{x}}\log\tilde{q}_{\theta}(\tilde{x})$. However, our findings indicate that the mean function $\mu(\tilde{x})$ (or equivalently the first-order score $\nabla_{\tilde{x}}\log\tilde{p}_d(\tilde{x})$) already contains all the moment information of the underlying true distribution $p_d$, and the optimal covariance can be derived from the mean function. Learning the second-order score is therefore redundant and may lead to sub-optimal inference.

2.3 Connection to Analytic DDPM

Bao et al. (2022) consider a constrained variational family $q_{\theta}(x|\tilde{x})=\mathcal{N}(\mu_{\theta}(\tilde{x}),\sigma_q^{2}I)$ in the context of diffusion models and derive the optimal $\sigma_q^{*}$ as

$$\sigma_q^{*2}=\arg\min_{\sigma_q}\mathrm{KL}\left(p(\tilde{x}|x)p_d(x)\,\Vert\,q_{\theta}(x|\tilde{x})\tilde{p}_d(\tilde{x})\right)=\frac{1}{d}\left\langle\mathrm{Tr}\left(\mathrm{Cov}_{p(x|\tilde{x})}[x]\right)\right\rangle_{\tilde{p}_d(\tilde{x})}, \tag{16}$$

where $d$ is the data dimension. This can equivalently be written in terms of the score function:

$$\sigma_q^{*2}=\sigma^{2}-\frac{\sigma^{4}}{d}\left\langle\left\lVert s_{q_{\theta}}(\tilde{x})\right\rVert_{2}^{2}\right\rangle_{\tilde{p}_d(\tilde{x})}. \tag{17}$$

In Appendix B, we provide a detailed derivation to show how this approximation can be linked to our method using the Fisher information identity Fisher (1925). This approximation has two potential limitations: first, compared to full covariance moment matching, the assumed isotropic covariance structure may be insufficiently flexible to capture the true posterior; second, the covariance is independent of x~~𝑥\tilde{x}over~ start_ARG italic_x end_ARG.
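Equation 17 follows from Equation 16 by averaging the trace of the covariance identity (Equation 10) over $\tilde{p}_d$ and applying the Fisher information identity $\langle\nabla^2\log\tilde{p}_d\rangle=-\langle\nabla\log\tilde{p}_d\,\nabla\log\tilde{p}_d^{T}\rangle$. The sketch below checks this numerically on the 2D mixture of Section 2.4, assuming the exact noisy score stands in for $s_{q_\theta}$. It compares the Monte Carlo estimate of Equation 17 with the average per-dimension trace of the $\tilde{x}$-dependent covariance, using the same 100,000 samples from $\tilde{p}_d$ (an arbitrary choice); names are hypothetical.

```python
import numpy as np

MEANS = np.array([[-1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [1.0, -1.0]])
SIGMA_G, SIGMA = 0.2, 0.2
S2 = SIGMA_G**2 + SIGMA**2  # variance of each noisy component
D = MEANS.shape[1]


def score_and_laplacian(xt):
    """Batched score and Laplacian (trace of the Hessian) of log p_tilde_d for the noisy MoG."""
    diff = xt[:, None, :] - MEANS[None, :, :]        # (N, K, D)
    logits = -0.5 * np.sum(diff**2, axis=-1) / S2   # (N, K)
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)                # responsibilities
    g = -diff / S2                                   # per-component scores
    score = np.einsum("nk,nkd->nd", w, g)
    lap = np.einsum("nk,nk->n", w, -D / S2 + np.sum(g**2, axis=-1)) - np.sum(score**2, axis=1)
    return score, lap


rng = np.random.default_rng(0)
n = 100_000
comp = rng.integers(0, len(MEANS), size=n)
xt = MEANS[comp] + np.sqrt(S2) * rng.standard_normal((n, D))  # xt ~ p_tilde_d

score, lap = score_and_laplacian(xt)
# Equation 17: sigma^2 - sigma^4 / d * <||s(xt)||^2>
iso_var_fisher = SIGMA**2 - SIGMA**4 / D * np.mean(np.sum(score**2, axis=1))
# Equation 16 via Equation 10: (1/d) <Tr(sigma^4 * Hessian + sigma^2 I)>
iso_var_trace = SIGMA**2 + SIGMA**4 / D * np.mean(lap)
print(iso_var_fisher, iso_var_trace)
```

Both estimates collapse the $\tilde{x}$-dependent covariance into a single isotropic variance, which is exactly the information discarded by the two limitations listed above.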

[Figure 1 panels: (a) Gaussian mixture visualizations, showing $p_d(x)$ and $\tilde{p}_d(\tilde{x})=\int p_d(x)p(\tilde{x}|x)\,\mathrm{d}x$; (b) density of $p(x|\tilde{x}')$ (blue) and four $\tilde{x}'$ points (green); (c) analytical full-covariance moment matching; (d) analytical isotropic-covariance moment matching; (e) learned diagonal covariance.]
Figure 1: Figure (a) shows the clean data distribution $p_d(x)$ and the corresponding noisy distribution $\tilde{p}_d(\tilde{x})$. Figure (b) shows 4 conditioning samples in the noisy space. Figures (c, d, e) visualize the true posterior $p(x|\tilde{x})$ (green) and three posterior approximations (orange). We find that only the proposed $\tilde{x}$-dependent analytical full-covariance moment matching can capture the variance of the true posterior, whereas the other two methods underestimate the variance.

The second assumption (an $\tilde{x}$-independent covariance) holds only when $\mu(\tilde{x})$ is a linear function of $\tilde{x}$ (footnote 3), e.g. when $p_d(x)$ is Gaussian, and it does not hold for non-Gaussian $p_d(x)$. Our $\tilde{x}$-dependent full-covariance approximation therefore offers a more versatile approximation family, which can yield a more accurate estimate. However, in certain applications such as accelerating the sampling procedure of a diffusion model Bao et al. (2022), an $\tilde{x}$-independent isotropic covariance is advantageous because it is inexpensive to estimate. By contrast, our $\tilde{x}$-dependent covariance requires computing the Hessian for each $\tilde{x}$, which is inefficient for high-dimensional data. In Section 3, we explore approaches to mitigate this limitation.

Footnote 3: When $\mu(\tilde{x})$ is a linear function of $\tilde{x}$, Theorem 2.2 gives $\Sigma(\tilde{x})=\sigma^{2}\nabla_{\tilde{x}}\mu(\tilde{x})$, which does not depend on $\tilde{x}$; see also footnote 2 for an example.

2.4 Posterior Approximation Comparison

We now consider a toy example to compare the three denoising posterior approximations discussed above. Let $p_d(x)$ be a Mixture of Gaussians (MoG) $p_d(x)=\frac{1}{4}\sum_{k=1}^{4}g_k(x)$ whose components $g_{1:4}$ are 2D Gaussians with means $[-1,-1]$, $[-1,1]$, $[1,1]$, $[1,-1]$ and isotropic covariance $\sigma_g^{2}I$ with $\sigma_g=0.2$.
The noise distribution is $p(\tilde{x}|x)=\mathcal{N}(x,\sigma^{2}I)$ with $\sigma=0.2$, so $\tilde{p}_d(\tilde{x})=\int p_d(x)p(\tilde{x}|x)\,\mathrm{d}x$ is an MoG with the same component means and isotropic covariance $(\sigma_g^{2}+\sigma^{2})I$; see Figure 1(a) for a visualization. In this case the true posterior $p(x|\tilde{x})$ is itself a mixture of Gaussians rather than a single Gaussian, so none of the Gaussian approximations can match it exactly.
Given a noisy sample $\tilde{x}'$ and an evaluation point $x'$, the true density $p(x|\tilde{x}')$ can be evaluated with Bayes' rule: $p(x'|\tilde{x}=\tilde{x}')=p(\tilde{x}'|x=x')p_d(x')/\tilde{p}_d(\tilde{x}')$. Figure 1(b) shows the true posteriors for four different $\tilde{x}'$, evaluated on a grid in $x$-space to visualize the density.
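The Bayes'-rule evaluation above can be sketched as follows for the 2D mixture ($\sigma_g=\sigma=0.2$). Because $\tilde{p}_d(\tilde{x}')$ has a closed form here, no numerical normalization is needed; the grid extent and resolution are arbitrary choices and the function names are hypothetical. As a sanity check, summing the density over the grid times the cell area should give approximately one.

```python
import numpy as np

MEANS = np.array([[-1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [1.0, -1.0]])
SIGMA_G, SIGMA = 0.2, 0.2


def log_gauss_iso(z, mean, var):
    """log N(z; mean, var * I) for 2D points z with shape (..., 2)."""
    return -0.5 * np.sum((z - mean) ** 2, axis=-1) / var - np.log(2 * np.pi * var)


def log_mog(z, var):
    """log of the equal-weight mixture (1/4) sum_k N(z; m_k, var * I)."""
    comps = np.stack([log_gauss_iso(z, m, var) for m in MEANS], axis=0)
    return np.logaddexp.reduce(comps, axis=0) - np.log(len(MEANS))


def true_posterior_density(x_grid, xt_prime):
    """p(x' | xt') = p(xt' | x') p_d(x') / p_tilde_d(xt'), at every grid point x'."""
    log_lik = log_gauss_iso(xt_prime, x_grid, SIGMA**2)          # p(xt' | x = x')
    log_prior = log_mog(x_grid, SIGMA_G**2)                       # p_d(x')
    log_evidence = log_mog(xt_prime, SIGMA_G**2 + SIGMA**2)      # p_tilde_d(xt')
    return np.exp(log_lik + log_prior - log_evidence)


xs = np.linspace(-2.5, 2.5, 501)
X1, X2 = np.meshgrid(xs, xs, indexing="ij")
x_grid = np.stack([X1, X2], axis=-1)        # shape (501, 501, 2)
cell = (xs[1] - xs[0]) ** 2                 # area of one grid cell
for xt_prime in ([0.0, 0.0], [0.9, 1.1], [-0.5, 1.0], [1.0, -0.2]):
    dens = true_posterior_density(x_grid, np.array(xt_prime))
    print(xt_prime, dens.sum() * cell)      # close to 1 when the grid covers the posterior mass
```

Working in log space with `np.logaddexp` avoids underflow for grid points far from every mixture component.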

To train the models, we sample 10,000 data points from $p_d$ as training data. For the KL-trained Gibbs sampler described in Section 2.2, we use a network with 3 hidden layers of 400 units each, Swish activations Ramachandran et al. (2017), and output size 4, producing both the mean and the log standard deviation of the Gaussian approximation. For the moment-matching Gibbs samplers (both full and isotropic covariance), we use the same architecture with output size 1 to obtain a scalar energy, trained with the DSM objective. Both networks are trained with batch size 100 using the Adam optimizer Kingma and Ba (2014) with learning rate $1\times10^{-4}$ for 100 epochs. For the $\tilde{x}$-independent isotropic covariance, we estimate the variance with the Monte Carlo approximation of Bao et al. (2022), using 10,000 samples from $\tilde{p}_d(\tilde{x})$.

Table 1: MMD evaluations of a single chain (mean ± std)

| Data | Learned diag. | Analytic iso. | Analytic full |
|---|---|---|---|
| MoGs | 0.929 ± 0.343 | 0.724 ± 0.361 | **0.305** ± 0.141 |
| Rings | 0.364 ± 0.044 | 0.006 ± 0.002 | **0.005** ± 0.001 |
| Roll | 0.053 ± 0.011 | 0.030 ± 0.001 | **0.016** ± 0.002 |

Figure 1 visualizes the approximations to the denoising posterior $p(x|\tilde{x})$ estimated by each of the three methods described in the previous sections. Surprisingly, although the KL objective in Equation 12 encourages $q_{\theta}(x|\tilde{x})$ to match the moments of $p(x|\tilde{x})$, the learned covariance in Figure 1(e) still underestimates the variance of the posterior. This suggests that learning a redundant covariance can degrade the quality of the variational approximation. Additionally, the $\tilde{x}$-independent covariance cannot account for the location of $\tilde{x}'$, and because it is isotropic it cannot capture the elliptical shape of the posterior. In contrast, our $\tilde{x}$-dependent full-covariance approximation overcomes these limitations and more accurately captures the geometry of the posterior distribution.

[Figure 2 panels, in three groups of four: (a), (e), (i) true data; (b), (f), (j) learned diagonal covariance; (c), (g), (k) analytic isotropic covariance; (d), (h), (l) analytic full covariance.]
Figure 2: Samples from single-chain Gibbs sampling

We then use the estimated posterior to run (pseudo) Gibbs sampling to generate samples (footnote 4). Specifically, we initialize the first sample $x_0\sim\mathcal{N}(0,0.1)$ and run a single Markov chain for 10,000 steps to generate 10,000 samples. In addition to the mixture-of-Gaussians dataset, we also train on and generate samples from the 2D Swiss roll and two-ring datasets. For numerical evaluation, we compute the Maximum Mean Discrepancy (MMD) Gretton et al. (2012) between 10k samples generated by the single-chain Gibbs sampler and 10k samples from the training dataset. The MMD kernel is a sum of 5 Gaussian kernels with bandwidths $[2^{-2},2^{-1},2^{0},2^{1},2^{2}]$. The MMD results (mean and standard deviation) are computed over 5 random seeds. Gibbs sampling with the proposed analytical full covariance achieves the best results; numerical results are in Table 1, with a visual comparison in Figure 2.

Footnote 4: The code of the experiments can be found at https://github.com/zmtomorrow/MMDGS_NeurIPS.
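The single-chain protocol and MMD evaluation can be sketched as follows. This is not the authors' released code: the chain alternates $\tilde{x}\sim\mathcal{N}(x,\sigma^2 I)$ and $x\sim\mathcal{N}(\mu(\tilde{x}),\Sigma(\tilde{x}))$ with a diagonal covariance for brevity, and it is run on assumed 1D Gaussian data, where moment matching is exact (footnote 2), instead of a trained energy. The kernel form $\sum_h\exp(-\lVert a-b\rVert^2/(2h^2))$, the reading of 0.1 in $x_0\sim\mathcal{N}(0,0.1)$ as a variance, the biased squared-MMD estimator, and the 2,000-sample size (to keep kernel matrices small) are all assumptions.

```python
import numpy as np

BANDWIDTHS = 2.0 ** np.arange(-2, 3)  # 2^-2, ..., 2^2


def mmd2(x, y, bandwidths=BANDWIDTHS):
    """Biased (V-statistic) squared MMD with a sum of Gaussian kernels.

    Assumes k(a, b) = sum_h exp(-||a - b||^2 / (2 h^2)); other bandwidth conventions exist.
    """
    def kernel(a, b):
        d2 = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
        d2 = np.maximum(d2, 0.0)  # guard against tiny negative values from round-off
        return sum(np.exp(-d2 / (2.0 * h**2)) for h in bandwidths)

    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()


def pseudo_gibbs_chain(mean_fn, std_fn, sigma, x0, n_steps, rng):
    """One chain alternating xt ~ N(x, sigma^2 I) and x ~ N(mean_fn(xt), diag(std_fn(xt))^2).

    Diagonal covariance for brevity; a full covariance would sample via a Cholesky factor.
    """
    x = np.asarray(x0, dtype=float)
    samples = np.empty((n_steps,) + x.shape)
    for t in range(n_steps):
        xt = x + sigma * rng.standard_normal(x.shape)                # noising step
        x = mean_fn(xt) + std_fn(xt) * rng.standard_normal(x.shape)  # denoising step
        samples[t] = x
    return samples


# Assumed 1D Gaussian data, for which moment matching is exact (footnote 2).
mu_d, sigma_d, sigma = 0.7, 1.3, 1.0
post_var = sigma_d**2 * sigma**2 / (sigma**2 + sigma_d**2)
post_mean = lambda t: (sigma**2 * mu_d + sigma_d**2 * t) / (sigma**2 + sigma_d**2)
post_std = lambda t: np.full_like(t, np.sqrt(post_var))

rng = np.random.default_rng(0)
x0 = np.sqrt(0.1) * rng.standard_normal(1)  # x0 ~ N(0, 0.1), reading 0.1 as a variance (assumption)
chain = pseudo_gibbs_chain(post_mean, post_std, sigma, x0, 2_000, rng)
reference = mu_d + sigma_d * rng.standard_normal((2_000, 1))
print(mmd2(chain, reference), mmd2(chain, reference + 1.0))  # first value should be smaller
```

Since consecutive chain states are correlated, a single chain carries fewer effective samples than its length; the paper's 10,000-step chains and 5 seeds help average this out.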

3 Scalable Implementations for Image Data 

Scalable Diagonal Hessian Approximation. As discussed in Section 2.3, the proposed full-covariance Gaussian approximation in Equation 10 requires computing a $D\times D$ Hessian $\nabla^{2}_{\tilde{x}}\log\tilde{q}(\tilde{x})$ for each $D$-dimensional $\tilde{x}$, which poses both memory and computational difficulties for high-dimensional data. A naive diagonal Hessian method (using only the diagonal entries of the Hessian) addresses the memory bottleneck but still requires $D$ backward passes to compute the diagonal exactly Martens et al. (2012). In this paper, we use the following diagonal Hessian approximation Bekas et al. (2007):

$$\mathrm{Diag}(H) \approx \frac{1}{S}\sum_{s=1}^{S} v_s \odot H v_s, \qquad (18)$$

where $v_s \sim p(v)$ is a Rademacher random vector with entries $\pm 1$ and $\odot$ denotes the element-wise product. (This estimator should be distinguished from Hutchinson's trace estimator Hutchinson (1990), $\mathrm{Tr}(H) \approx \frac{1}{S}\sum_{s=1}^{S} v_s^{T} H v_s$, which uses a dot product between $v_s$ and $Hv_s$.) The estimator converges to the exact Hessian diagonal as $S \rightarrow \infty$ Bekas et al. (2007). Each term $v_s \odot Hv_s$ can be computed with two forward-backward passes. Our $\tilde{x}$-dependent diagonal moment matching approach offers flexibility comparable to the variational method proposed in Bengio et al. (2013), while eliminating the need to train the diagonal covariance separately. Furthermore, our method is more flexible than the isotropic $\tilde{x}$-independent moment matching method proposed by Bao et al. (2022).
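To make Equation 18 concrete, the sketch below applies it to a small explicit matrix. Plain matrix-vector products stand in for the Hessian-vector products that automatic differentiation would supply for a network. The same probe vectors also yield Hutchinson's trace estimate, which shows the difference between the element-wise and dot-product forms. The helper names and the toy matrix are hypothetical and purely illustrative.

```python
import random
from typing import Callable, List, Tuple

Vector = List[float]
MatVec = Callable[[Vector], Vector]


def rademacher(dim: int, rng: random.Random) -> Vector:
    """Vector with i.i.d. entries equal to -1 or +1 with probability 1/2 each."""
    return [rng.choice((-1.0, 1.0)) for _ in range(dim)]


def estimate_diag_and_trace(hvp: MatVec, dim: int, num_samples: int,
                            seed: int = 0) -> Tuple[Vector, float]:
    """Equation 18, Diag(H) ~ (1/S) sum_s v_s * (H v_s), plus Hutchinson's
    Tr(H) ~ (1/S) sum_s v_s^T H v_s computed from the same probe vectors."""
    rng = random.Random(seed)
    diag = [0.0] * dim
    trace = 0.0
    for _ in range(num_samples):
        v = rademacher(dim, rng)
        hv = hvp(v)  # one (Hessian-)vector product per probe
        prods = [vi * hvi for vi, hvi in zip(v, hv)]  # element-wise product v * Hv
        diag = [d + p for d, p in zip(diag, prods)]
        trace += sum(prods)  # dot product v^T H v
    return [d / num_samples for d in diag], trace / num_samples


def matvec_from_matrix(h: List[Vector]) -> MatVec:
    """Wrap an explicit (toy) matrix as a matrix-vector product function."""
    return lambda v: [sum(hij * vj for hij, vj in zip(row, v)) for row in h]


if __name__ == "__main__":
    H = [[2.0, 0.5, -0.3], [0.5, -1.0, 0.8], [-0.3, 0.8, 4.0]]
    diag_est, trace_est = estimate_diag_and_trace(matvec_from_matrix(H), 3, 20000)
    # diag_est approaches [2.0, -1.0, 4.0] and trace_est approaches 5.0 as S grows.
    print(diag_est, trace_est)
```

For a purely diagonal matrix a single probe is already exact, since $v_i^2 = 1$. Off-diagonal entries only add zero-mean noise $\sum_{j\ne i} H_{ij}v_iv_j$ to entry $i$, and this noise averages out as $S$ grows.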

Energy or Score Parameterization For the full-covariance moment matching in Equation 10, $\nabla^2_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})$ must be symmetric to obtain a valid Gaussian approximation. However, if we learn the score function $\nabla_{\tilde{x}}\log p(\tilde{x}) = s_\theta(\tilde{x})$ directly with a network $s_\theta(\cdot):\mathbb{R}^D\rightarrow\mathbb{R}^D$, its Jacobian is not guaranteed to be symmetric. In this case, we follow Saremi et al. (2018) and parameterize the density $\tilde{q}_\theta(\cdot)$ through a neural network $f_\theta(\cdot):\mathbb{R}^D\rightarrow\mathbb{R}$, setting the score to $\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x}) = -\nabla_{\tilde{x}}f_\theta(\tilde{x})$. This gradient can be obtained with automatic differentiation packages such as PyTorch Paszke et al. (2017), and the parameterization guarantees that $\nabla^2_{\tilde{x}}f_\theta(\tilde{x})$ is symmetric. The diagonal Hessian approximation (Equation 18) removes the symmetry requirement. We then only need the variances derived from $\mathrm{Diag}(H)$ to be positive to obtain a valid Gaussian approximation. The score parameterization therefore remains applicable in this case and trains more efficiently than the energy parameterization. Combining full or diagonal covariance with energy or score parameterization thus trades flexibility against speed. Full covariance with an energy parameterization is the most flexible option, while diagonal covariance with a score parameterization keeps training computationally efficient.
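The symmetry argument can be checked numerically. The toy sketch below is our own illustration: the energy, the vector field and the finite-difference step are arbitrary choices. It estimates the Jacobian of a score field by central differences. A score defined as the negative gradient of a scalar energy has a symmetric Jacobian (the negative Hessian). A generic map $\mathbb{R}^D\to\mathbb{R}^D$, such as a free-form score network might learn, need not.

```python
import math
from typing import Callable, List

Vector = List[float]
Field = Callable[[Vector], Vector]


def energy_score(x: Vector) -> Vector:
    """Score -grad f for the toy energy f(x) = x0^2 * x1 + x1^3 / 3 + cos(x0)."""
    x0, x1 = x
    grad_f = [2.0 * x0 * x1 - math.sin(x0), x0 ** 2 + x1 ** 2]
    return [-g for g in grad_f]


def free_form_score(x: Vector) -> Vector:
    """A vector field that is not the gradient of any scalar function."""
    x0, x1 = x
    return [x0 + 2.0 * x1, x0 ** 2]


def jacobian_fd(field: Field, x: Vector, h: float = 1e-5) -> List[Vector]:
    """Central finite-difference Jacobian J[i][j] = d field_i / d x_j."""
    dim = len(x)
    cols = []
    for j in range(dim):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        fp, fm = field(xp), field(xm)
        cols.append([(a - b) / (2.0 * h) for a, b in zip(fp, fm)])
    # cols[j][i] holds d field_i / d x_j; transpose to J[i][j].
    return [[cols[j][i] for j in range(dim)] for i in range(dim)]


def max_asymmetry(jac: List[Vector]) -> float:
    """Largest absolute difference between J[i][j] and J[j][i]."""
    n = len(jac)
    return max(abs(jac[i][j] - jac[j][i]) for i in range(n) for j in range(n))


if __name__ == "__main__":
    x = [0.3, -0.7]
    print(max_asymmetry(jacobian_fd(energy_score, x)))
    print(max_asymmetry(jacobian_fd(free_form_score, x)))
```

At $x = (0.3, -0.7)$ the asymmetry is essentially zero for `energy_score`. For `free_form_score` it is about 1.4, since that Jacobian is $\begin{pmatrix}1 & 2\\ 0.6 & 0\end{pmatrix}$.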

4 Image Generation with a Single Noise Level

(a) Diagonal $\tilde{x}$-dependent covariance moment matching (Ours)
(b) Isotropic $\tilde{x}$-independent covariance moment matching Bao et al. (2022)
(c) Diagonal covariance learned by KL minimization Bengio et al. (2013)
Figure 3: MNIST comparison of samples generated by pseudo-Gibbs sampling with three different choices of $q(x|\tilde{x})$. We plot samples from 25 independent Markov chains at time steps $t \in \{0, 1, 5, 10, 20\}$. The proposed analytical covariance moment matching with the diagonal approximation (a) gives the best sample quality.

We then apply the proposed method to the grey-scale MNIST dataset LeCun (1998). We use the standard U-Net architecture Song and Ermon (2019); Ronneberger et al. (2015) with a single fixed noise level $\sigma = 0.5$. The effect of varying $\sigma$ is explored in Appendix C.1. For the KL training objective, the network has 2 output channels so that it produces both the mean and the log standard deviation. For DSM training, we sum the U-Net output to obtain a scalar energy, which also relates to the product-of-experts model described in Saremi et al. (2018). We train both networks for 300 epochs with learning rate $1\times10^{-4}$ and batch size 100.

(a) $\tilde{x}$-dependent diagonal (Ours)
(b) $\tilde{x}$-dependent diagonal (KL)
(c) $\tilde{x}$-independent isotropic
Figure 4: Panels (a,b,c) visualize the covariance approximations for $q(x|\tilde{x} = x + \sigma\epsilon)$, $\epsilon \sim \mathcal{N}(0, I)$, on 25 samples of $\tilde{x}$. A sigmoid function maps the real-valued noise to grayscale pixels for visualization.

As discussed in Section 1.1, the KL divergence is not well-defined for manifold data distributions.

(a) KL loss
(b) DSM loss
Figure 5: Training loss of the two objectives, plotted at every iteration over a total of 300 epochs.

This limitation becomes evident on MNIST. The constant black pixels in the boundary areas cause the variance of $q_\theta(x|\tilde{x})$ to decrease rapidly towards 0 during training. Consequently, the log-likelihood $\log q_\theta(x|\tilde{x})$ tends to infinity, resulting in unstable training. In contrast, the DSM objective is well-defined for manifold data and provides a stable training process even in the presence of such boundary effects. Figure 5 compares the two training procedures and shows the improved stability of the DSM objective on manifold data distributions.
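The source of the divergence can be made explicit with a one-pixel example (our illustration). Suppose pixel $i$ takes the same value $c$ in every training image, as the black boundary pixels of MNIST do. A model that predicts mean $\mu_i(\tilde{x}) = c$ with standard deviation $s_i$ then attains the per-pixel log-likelihood

```latex
\log \mathcal{N}\!\left(x_i;\, c,\, s_i^2\right)
  = -\tfrac{1}{2}\log\!\left(2\pi s_i^2\right) - \frac{(x_i - c)^2}{2 s_i^2}
  = -\tfrac{1}{2}\log\!\left(2\pi s_i^2\right)
  \;\longrightarrow\; +\infty \quad \text{as } s_i \to 0 .
```

The KL (maximum-likelihood) objective therefore rewards shrinking $s_i$ without bound, and the objective is unbounded above.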

For sample generation, computing the full-covariance Gaussian posterior becomes challenging at this scale. We therefore use the scalable diagonal Hessian approximation described in Section 3 to approximate the diagonal Gaussian covariance of $p(x|\tilde{x})$. Approximation error in the diagonal Hessian estimate occasionally produces small negative values in the resulting diagonal covariance. We therefore apply $\max(\cdot, \epsilon)$ with $\epsilon > 0$ to keep the diagonal covariance positive. The $\tilde{x}$-independent isotropic covariance and the proposed $\tilde{x}$-dependent diagonal covariance share the same mean function.
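The sketch below shows how one diagonal-Gaussian denoising step might be assembled. It assumes the moment-matched posterior follows the standard second-order Tweedie identities for Gaussian noise $\tilde{x} = x + \sigma\epsilon$, which we take to correspond to Equation 10:

* the mean is $\tilde{x} + \sigma^2\nabla_{\tilde{x}}\log\tilde{q}(\tilde{x})$;
* the covariance is $\sigma^2 I + \sigma^4\nabla^2_{\tilde{x}}\log\tilde{q}(\tilde{x})$.

Only the Hessian diagonal is used, and each variance is clamped with $\max(\cdot,\epsilon)$ as described above. The function names and the default `eps` are hypothetical.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]


def diag_gaussian_posterior(x_noisy: Vector, score: Vector, hess_diag: Vector,
                            sigma: float, eps: float = 1e-6) -> Tuple[Vector, Vector]:
    """Moment-matched diagonal Gaussian q(x | x_noisy) for x_noisy = x + sigma * noise.

    Assumes second-order Tweedie moments (see lead-in):
      mean_i = x_noisy_i + sigma^2 * score_i
      var_i  = max(sigma^2 + sigma^4 * hess_diag_i, eps)  # clamp keeps variances positive
    `score` and `hess_diag` are the gradient and the (estimated) Hessian diagonal of
    log q~ evaluated at x_noisy.
    """
    s2 = sigma * sigma
    mean = [xn + s2 * g for xn, g in zip(x_noisy, score)]
    var = [max(s2 + s2 * s2 * h, eps) for h in hess_diag]
    return mean, var


def sample_diag_gaussian(mean: Vector, var: Vector, rng: random.Random) -> Vector:
    """Draw one sample from N(mean, diag(var))."""
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mean, var)]


if __name__ == "__main__":
    # Toy check: prior N(0, 1), sigma = 0.5, so the noisy marginal is N(0, 1.25).
    x_noisy = [1.0, -2.0]
    score = [-xn / 1.25 for xn in x_noisy]
    hess_diag = [-1.0 / 1.25] * 2
    mean, var = diag_gaussian_posterior(x_noisy, score, hess_diag, sigma=0.5)
    print(mean, var, sample_diag_gaussian(mean, var, random.Random(0)))
```

For a Gaussian prior $\mathcal{N}(0,1)$ and $\sigma = 0.5$, the exact posterior is $\mathcal{N}(0.8\,\tilde{x}, 0.2)$, and these formulas recover it. The clamp only becomes active when the Hessian estimate is too negative.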

(a) Data $x$
(b) 1 sample
(c) 10 samples
(d) 100 samples
Figure 6: Visualizations of the diagonal covariance of $q(x|\tilde{x} = x + \sigma\epsilon)$ with different numbers of Rademacher samples.
Figure 7: Samples from three Markov chains, plotted every 10 Gibbs steps.

We first visualize the covariance estimated by the three methods in Figure 4. We use 100 Rademacher samples to estimate the diagonal Hessian (Equation 18) and 50,000 samples to estimate the isotropic variance (Equation 17). Both $\tilde{x}$-dependent diagonal covariance approximations capture the posterior structure. In contrast, the isotropic $\tilde{x}$-independent covariance is just Gaussian noise, since the same variance is shared across all digits and pixel locations. Figure 3 compares samples from the three methods.

Since the isotropic covariance has the same variance in every dimension, the generated samples in Figure 3(b) contain white noise in the black background. The proposed diagonal-covariance sampler instead generates a clean black background (Figure 3(a)). The samples generated by the KL-trained Gibbs sampler (Figure 3(c)) have worse quality due to the unstable training.

We then apply the same method to the more complex CIFAR10 dataset Krizhevsky et al. (2009). We use the same U-Net architecture as Song and Ermon (2020) and directly parameterize the score function rather than the energy function to speed up training. The noise level is fixed at $\sigma = 0.3$. We train the model using the Adam optimizer with learning rate $1\times10^{-4}$ and batch size 100 for 1000 epochs. Figure 6 visualizes the diagonal covariance of the denoising posterior for different numbers of Rademacher samples (Equation 18). The covariance estimate improves as the number of samples increases. To balance efficiency and accuracy, we use 10 samples in the subsequent Gibbs sampling stage. Figure 7 shows three independent Markov chains with samples plotted every 10 Gibbs steps, demonstrating that sharp images can be generated even with a single fixed noise level.

Figure 8: FID evaluation with an increasing number of Gibbs steps. The FID increases after 40 Gibbs steps.

Limitation: In the CIFAR experiment, we observe mode collapse when running multiple independent Markov chains for a longer time. This is likely due to the small noise level $\sigma = 0.3$, which prevents the sampler from exploring the full space, as is commonly observed with MCMC methods Robert et al. (1999). Figure 8 illustrates this effect by reporting the Fréchet Inception Distance (FID) of 50,000 images sampled with varying numbers of Gibbs steps. Notably, the FID increases beyond 40 Gibbs steps, and mode collapse is visually evident (Figure 12). In the next section, we apply our method to settings with multiple noise levels, an approach that may help mitigate mode collapse.

5 Image Generation Using Multiple Noise Levels

Algorithm 1 Sampling with Langevin Dynamics
Input: $\{\sigma_t\}_{t=0}^{T}$, $\delta$, $K$
Initialize $x^0_T$
for $t \leftarrow T$ to $1$ do
    $\alpha_t \leftarrow \delta\,\sigma_t^2/\sigma_0^2$
    for $k \leftarrow 1$ to $K$ do
        Draw $z^k \sim \mathcal{N}(0, I)$
        $x^k_t \leftarrow x^{k-1}_t + \alpha_t\, s_\theta(x^{k-1}_t, \sigma_t) + \sqrt{2\alpha_t}\, z^k$
    end for
    $x^0_{t-1} \leftarrow x^K_t$
end for
return $x^0_0 + \sigma_0^2\, s_\theta(x^0_0, \sigma_0)$
Algorithm 2 Sampling with the proposed pseudo-Gibbs sampling
Input: $\{\sigma_t\}_{t=0}^{T}$, $K$
Initialize $x^0_T$
for $t \leftarrow T$ to $1$ do
    for $k \leftarrow 1$ to $K$ do
        Draw $x^k_{t+1} \sim p(x_{t+1}|x_t = x^{k-1}_t)$
        Draw $x^k_t \sim q_\theta(x_t|x_{t+1} = x^k_{t+1})$
    end for
    $x^0_{t-1} \leftarrow x^K_t$
end for
return $x^0_0 + \sigma_0^2\, s_\theta(x^0_0, \sigma_0)$

The success of diffusion models and lessons from prior work on score-based generative models point to the importance of using multiple noise levels Ho et al. (2020); Song and Ermon (2019) when modelling data with complex multi-modal distributions. Intuitively, a single network that learns to denoise data at a range of noise levels can capture both the fine and the global structure of the distribution. This in turn enables sampling algorithms that explore diverse modes efficiently (Song and Ermon, 2019). We therefore propose to adapt the denoising Gibbs sampling procedure to sample from distributions corrupted with multiple noise levels. For this purpose, we use a noise-conditioned score network trained by Song and Ermon (2020), who generated high-quality samples using a procedure inspired by annealed Langevin dynamics Welling and Teh (2011). This procedure generates samples from a sequence of distributions $p_T(x_T), \ldots, p_0(x_0)$ corrupted by progressively decreasing levels of Gaussian noise, parameterized by standard deviations $\sigma_t$ with $\sigma_T > \cdots > \sigma_0$.
At a given step $t$ in the sequence, Langevin dynamics is used to sample from the corresponding noised distribution $p_t(x_t)$. The score network $s_\theta(x_t, \sigma_t)$ approximates the gradient of the log-density of this distribution. The outputs of this Langevin run then initialize the same procedure at the next noise level. The sampling procedure thus converges gradually towards the data distribution as the noise level tends to zero (i.e. $p_0(x_0) \approx p_d(x)$ when $\sigma_0 \approx 0$). Algorithm 1 summarizes the annealed Langevin dynamics with multi-level noise used in Song and Ermon (2019, 2020).
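Algorithm 1 can be sketched on a toy problem where the noise-conditioned score is known in closed form. For illustration only, we assume one-dimensional data $x\sim\mathcal{N}(\mu, s^2)$. Then $p_t = \mathcal{N}(\mu, s^2+\sigma_t^2)$, and the exact score $-(x-\mu)/(s^2+\sigma_t^2)$ stands in for $s_\theta(x,\sigma_t)$. The geometric noise schedule, the initialization from $\mathcal{N}(0,\sigma_T^2)$, and the values of $\delta$ and $K$ are our own choices.

```python
import math
import random
from typing import Callable, List

Score = Callable[[float, float], float]  # (x, sigma) -> score of the sigma-noised density


def geometric_sigmas(sigma_min: float, sigma_max: float, num_levels: int) -> List[float]:
    """Noise levels sigma_0 < ... < sigma_T, geometrically spaced (needs num_levels >= 2)."""
    T = num_levels - 1
    return [sigma_min * (sigma_max / sigma_min) ** (t / T) for t in range(num_levels)]


def annealed_langevin(score: Score, sigmas: List[float], delta: float, K: int,
                      num_samples: int, seed: int = 0) -> List[float]:
    """Algorithm 1 for scalar particles: K Langevin steps at each level t = T, ..., 1."""
    rng = random.Random(seed)
    T = len(sigmas) - 1
    xs = [rng.gauss(0.0, sigmas[T]) for _ in range(num_samples)]  # x_T^0
    for t in range(T, 0, -1):
        alpha = delta * sigmas[t] ** 2 / sigmas[0] ** 2  # step size alpha_t
        for _ in range(K):
            xs = [x + alpha * score(x, sigmas[t])
                  + math.sqrt(2.0 * alpha) * rng.gauss(0.0, 1.0) for x in xs]
    # Return line of Algorithm 1: x_0^0 + sigma_0^2 * s(x_0^0, sigma_0)
    return [x + sigmas[0] ** 2 * score(x, sigmas[0]) for x in xs]


def toy_score(mu: float, s: float) -> Score:
    """Exact score of N(mu, s^2) data corrupted by N(0, sigma^2) noise."""
    return lambda x, sigma: -(x - mu) / (s * s + sigma * sigma)


if __name__ == "__main__":
    sigmas = geometric_sigmas(0.1, 5.0, 11)
    out = annealed_langevin(toy_score(2.0, 0.5), sigmas, delta=1e-3, K=100, num_samples=1000)
    print(sum(out) / len(out))  # should be close to the data mean 2.0
```

Here $\delta$ is kept small relative to $\sigma_0^2$ ($\delta/\sigma_0^2 = 0.1$) because $\alpha_t$ grows with $\sigma_t^2$. Too large a step makes the discretized Langevin update unstable at the highest noise levels.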

We show that the proposed Gibbs sampling scheme can be applied directly to a pre-trained score-based generative model as a drop-in replacement for Langevin dynamics MCMC in the generation stage. At each noise level, samples $x_{t+1}$ from the previous noise level initialise a Gibbs sampling chain. This chain targets the marginal distribution at the current noise level, $p_t(x_t)$, which now plays the role of the ‘clean’ distribution in Equation 5. Since the variances of independent Gaussian noise add, corrupting $x_t \sim p_t$ to the next noise level $\sigma_{t+1}$ corresponds to the Gaussian noise distribution $p(x_{t+1}|x_t) = \mathcal{N}(x_t, (\sigma_{t+1}^2 - \sigma_t^2) I)$. The optimal denoising distribution $q(x_t|x_{t+1})$ is thus a function of the noise level at step $t$ relative to the noise level at the previous step $t+1$. For generation efficiency, we employ 3 Gibbs steps at each noise level and use 3 Rademacher samples to approximate the diagonal Hessian. The sampling procedure is summarized in Algorithm 2.
Figures 9 and 10 show samples from models trained separately on CIFAR10 and CelebA. Further experimental details can be found in Appendix C.2.
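Algorithm 2 can be sketched on the same kind of toy problem. Each Gibbs step first noises the current sample to level $t+1$ with variance $\sigma_{t+1}^2-\sigma_t^2$. It then redraws $x_t$ from a diagonal Gaussian built with Tweedie-style moments, using the score and its derivative at level $t+1$. The sketch makes the following assumptions:

* The data are one-dimensional $\mathcal{N}(\mu, s^2)$.
* The toy's closed-form score derivative stands in for the Rademacher estimate of Equation 18.
* To avoid needing a $\sigma_{T+1}$, the chain runs at levels $t = T-1,\dots,0$, each using the next noisier level as its auxiliary variable. This indexing is our choice and not necessarily the paper's.

```python
import math
import random
from typing import Callable, List, Tuple

Score = Callable[[float, float], float]  # (x, sigma) -> value at noise level sigma


def pseudo_gibbs(score: Score, score_grad: Score, sigmas: List[float], K: int,
                 num_samples: int, eps: float = 1e-6, seed: int = 0) -> List[float]:
    """Multi-level pseudo-Gibbs sampling in the spirit of Algorithm 2 (scalar particles).

    sigmas[0] < ... < sigmas[T]; the chain at level t uses level t + 1 as its
    auxiliary noisy variable (an indexing assumption of this sketch).
    """
    rng = random.Random(seed)
    T = len(sigmas) - 1
    xs = [rng.gauss(0.0, sigmas[T]) for _ in range(num_samples)]  # initial samples
    for t in range(T - 1, -1, -1):
        tau2 = sigmas[t + 1] ** 2 - sigmas[t] ** 2  # variance of p(x_{t+1} | x_t)
        for _ in range(K):
            new_xs = []
            for x in xs:
                # Noising step: x_{t+1} ~ N(x_t, tau2)
                y = x + math.sqrt(tau2) * rng.gauss(0.0, 1.0)
                # Denoising step: moment-matched Gaussian q(x_t | x_{t+1} = y)
                mean = y + tau2 * score(y, sigmas[t + 1])
                var = max(tau2 + tau2 * tau2 * score_grad(y, sigmas[t + 1]), eps)
                new_xs.append(mean + math.sqrt(var) * rng.gauss(0.0, 1.0))
            xs = new_xs
    # Final denoising with the smallest noise level, as in the algorithm's return line.
    return [x + sigmas[0] ** 2 * score(x, sigmas[0]) for x in xs]


def toy_gaussian_fns(mu: float, s: float) -> Tuple[Score, Score]:
    """Exact score and score derivative of N(mu, s^2) convolved with N(0, sigma^2)."""
    def score(x: float, sigma: float) -> float:
        return -(x - mu) / (s * s + sigma * sigma)

    def score_grad(x: float, sigma: float) -> float:
        return -1.0 / (s * s + sigma * sigma)

    return score, score_grad


if __name__ == "__main__":
    sigmas = [0.1 * 50.0 ** (t / 10) for t in range(11)]  # geometric, 0.1 to 5.0
    score, score_grad = toy_gaussian_fns(mu=2.0, s=0.5)
    out = pseudo_gibbs(score, score_grad, sigmas, K=3, num_samples=2000)
    print(sum(out) / len(out))  # should be close to the data mean 2.0
```

Unlike the Langevin sketch, there is no step size to tune. For this Gaussian toy the moment-matched conditional is exact, so each Gibbs step leaves $p_t$ invariant, and $K = 3$ steps per level (as in the text) suffice here.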

Table 2: CIFAR10 Inception and FID scores (higher Inception and lower FID are better)
Model | Inception | FID
NCSNv2 (Langevin, Song and Ermon, 2020) | 8.40 ± 0.07 | 10.87
NCSNv2 (Gibbs, Ours) | 8.28 ± 0.07 | 10.75
DDPM (Ho et al., 2020) | 9.46 ± 0.11 | 3.17
NCSN++ (Song et al., 2021) | 9.89 | 2.20

For direct comparison with the CIFAR10 results of Song and Ermon (2020), we retain the same schedule of noise levels they used to generate samples with Langevin dynamics. We generate 50,000 samples with this approach and report FID and Inception scores in Table 2. Our multi-level Gibbs sampling scheme produces samples of comparable quality to the multi-level Langevin dynamics of Song and Ermon (2020), with a slightly lower Inception score and a slightly better FID. This confirms its applicability to complex natural image data. The FID is also notably better than that of single-noise-level Gibbs sampling, and the samples exhibit significant visual diversity (Figure 9). This underlines the importance of employing multi-level noise in our approach. Recent sampling strategies for score-based models that leverage stochastic differential equations (Song et al., 2021) have further improved generation quality significantly, as shown in Table 2. We leave the application of our method within this framework to future work.

Figure 9: CIFAR 10 Samples
Figure 10: CelebA Samples

6 Conclusion

This paper addresses the inconsistency problem in training energy-based models (EBMs) with denoising score matching. Specifically, we identify the presence of an underlying clean model within a ‘noisy’ EBM and propose an efficient sampling scheme for this clean model. We demonstrate how the method can be applied to high-dimensional data and show image generation results in both single- and multi-level noise settings. More broadly, we hope our more accurate denoising posterior opens new avenues for future work on score-based methods in machine learning.

References

  • Arjovsky et al. (2017) M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein generative adversarial networks. In International conference on machine learning, pages 214–223. PMLR, 2017.
  • Bao et al. (2022) F. Bao, C. Li, J. Zhu, and B. Zhang. Analytic-dpm: an analytic estimate of the optimal reverse variance in diffusion probabilistic models. arXiv preprint arXiv:2201.06503, 2022.
  • Bekas et al. (2007) C. Bekas, E. Kokiopoulou, and Y. Saad. An estimator for the diagonal of a matrix. Applied numerical mathematics, 57(11-12):1214–1229, 2007.
  • Bengio et al. (2013) Y. Bengio, L. Yao, G. Alain, and P. Vincent. Generalized denoising auto-encoders as generative models. Advances in neural information processing systems, 26, 2013.
  • Bishop and Nasrabadi (2006) C. M. Bishop and N. M. Nasrabadi. Pattern recognition and machine learning, volume 4. Springer, 2006.
  • Bose et al. (2018) A. J. Bose, H. Ling, and Y. Cao. Adversarial contrastive estimation. arXiv preprint arXiv:1805.03642, 2018.
  • Du and Mordatch (2019) Y. Du and I. Mordatch. Implicit generation and modeling with energy based models. Advances in Neural Information Processing Systems, 32, 2019.
  • Du et al. (2020) Y. Du, S. Li, J. Tenenbaum, and I. Mordatch. Improved contrastive divergence training of energy based models. arXiv preprint arXiv:2012.01316, 2020.
  • Efron (2011) B. Efron. Tweedie’s formula and selection bias. Journal of the American Statistical Association, 106(496):1602–1614, 2011.
  • Fisher (1925) R. A. Fisher. Theory of statistical estimation. In Mathematical proceedings of the Cambridge philosophical society, volume 22, pages 700–725. Cambridge University Press, 1925.
  • Gao et al. (2018) R. Gao, Y. Lu, J. Zhou, S.-C. Zhu, and Y. N. Wu. Learning generative convnets via multi-grid modeling and sampling. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 9155–9164, 2018.
  • Gretton et al. (2012) A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. Smola. A kernel two-sample test. The Journal of Machine Learning Research, 13(1):723–773, 2012.
  • Hinton (2002) G. E. Hinton. Training products of experts by minimizing contrastive divergence. Neural computation, 14(8):1771–1800, 2002.
  • Ho et al. (2020) J. Ho, A. Jain, and P. Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851, 2020.
  • Hutchinson (1990) M. F. Hutchinson. A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines. Communications in Statistics-Simulation and Computation, 19(2):433–450, 1990.
  • Hyvärinen (2005) A. Hyvärinen. Estimation of non-normalized statistical models by score matching. Journal of Machine Learning Research, 6(4), 2005.
  • Jolicoeur-Martineau et al. (2021) A. Jolicoeur-Martineau, R. Piché-Taillefer, I. Mitliagkas, and R. T. des Combes. Adversarial score matching and improved sampling for image generation. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=eLfqMl3z3lq.
  • Kim and Bengio (2016) T. Kim and Y. Bengio. Deep directed generative models with energy-based probability estimation. arXiv preprint arXiv:1606.03439, 2016.
  • Kingma and Ba (2014) D. P. Kingma and J. Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Krizhevsky et al. (2009) A. Krizhevsky, G. Hinton, et al. Learning multiple layers of features from tiny images. 2009.
  • LeCun (1998) Y. LeCun. The mnist database of handwritten digits. http://yann.lecun.com/exdb/mnist/, 1998.
  • Martens et al. (2012) J. Martens, I. Sutskever, and K. Swersky. Estimating the hessian by back-propagating curvature. arXiv preprint arXiv:1206.6464, 2012.
  • Meng et al. (2021) C. Meng, Y. Song, W. Li, and S. Ermon. Estimating high order gradients of the data distribution by denoising. Advances in Neural Information Processing Systems, 34:25359–25369, 2021.
  • Minka (2013) T. P. Minka. Expectation propagation for approximate bayesian inference. arXiv preprint arXiv:1301.2294, 2013.
  • Ngiam et al. (2011) J. Ngiam, Z. Chen, P. W. Koh, and A. Y. Ng. Learning deep energy models. In Proceedings of the 28th international conference on machine learning (ICML-11), pages 1105–1112, 2011.
  • Nijkamp et al. (2019) E. Nijkamp, M. Hill, S.-C. Zhu, and Y. N. Wu. Learning non-convergent non-persistent short-run mcmc toward energy-based model. Advances in Neural Information Processing Systems, 32, 2019.
  • Paszke et al. (2017) A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, and A. Lerer. Automatic differentiation in pytorch. 2017.
  • Ramachandran et al. (2017) P. Ramachandran, B. Zoph, and Q. V. Le. Searching for activation functions. arXiv preprint arXiv:1710.05941, 2017.
  • Robbins (1992) H. E. Robbins. An empirical bayes approach to statistics. In Breakthroughs in Statistics: Foundations and basic theory, pages 388–394. Springer, 1992.
  • Robert et al. (1999) C. P. Robert and G. Casella. Monte Carlo statistical methods, volume 2. Springer, 1999.
  • Ronneberger et al. (2015) O. Ronneberger, P. Fischer, and T. Brox. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. Springer, 2015.
  • Salimans and Ho (2021) T. Salimans and J. Ho. Should ebms model the energy or the score? In Energy Based Models Workshop-ICLR 2021, 2021.
  • Saremi (2019) S. Saremi. On approximating ∇f with neural networks. arXiv preprint arXiv:1910.12744, 2019.
  • Saremi et al. (2018) S. Saremi, A. Mehrjou, B. Schölkopf, and A. Hyvärinen. Deep energy estimator networks. arXiv preprint arXiv:1805.08306, 2018.
  • Song and Ermon (2019) Y. Song and S. Ermon. Generative modeling by estimating gradients of the data distribution. Advances in Neural Information Processing Systems, 32, 2019.
  • Song and Ermon (2020) Y. Song and S. Ermon. Improved techniques for training score-based generative models. In H. Larochelle, M. Ranzato, R. Hadsell, M. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020. URL https://proceedings.neurips.cc/paper/2020/hash/92c3b916311a5517d9290576e3ea37ad-Abstract.html.
  • Song and Kingma (2021) Y. Song and D. P. Kingma. How to train your energy-based models. arXiv preprint arXiv:2101.03288, 2021.
  • Song et al. (2020) Y. Song, S. Garg, J. Shi, and S. Ermon. Sliced score matching: A scalable approach to density and score estimation. In Uncertainty in Artificial Intelligence, pages 574–584. PMLR, 2020.
  • Song et al. (2021) Y. Song, J. Sohl-Dickstein, D. P. Kingma, A. Kumar, S. Ermon, and B. Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=PxTIG12RRHS.
  • Tao (2011) T. Tao. An introduction to measure theory, volume 126. American Mathematical Society, Providence, 2011.
  • Vincent (2011) P. Vincent. A connection between score matching and denoising autoencoders. Neural Computation, 23(7):1661–1674, 2011.
  • Wang et al. (2020) Z. Wang, S. Cheng, Y. Li, J. Zhu, and B. Zhang. A Wasserstein minimum velocity approach to learning unnormalized models. In International Conference on Artificial Intelligence and Statistics, pages 3728–3738. PMLR, 2020.
  • Welling and Teh (2011) M. Welling and Y. W. Teh. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pages 681–688, 2011.
  • Wendland (2004) H. Wendland. Scattered data approximation, volume 17. Cambridge University Press, 2004.
  • Wenliang and Kanagawa (2020) L. Wenliang and H. Kanagawa. Blindness of score-based methods to isolated components and mixing proportions. arXiv preprint arXiv:2008.10087, 2020.
  • Xie et al. (2016) J. Xie, Y. Lu, S.-C. Zhu, and Y. Wu. A theory of generative ConvNet. In International Conference on Machine Learning, pages 2635–2644. PMLR, 2016.
  • Zhai et al. (2016) S. Zhai, Y. Cheng, R. Feris, and Z. Zhang. Generative adversarial networks as variational training of energy based models. arXiv preprint arXiv:1611.01799, 2016.
  • Zhang et al. (2019) M. Zhang, T. Bird, R. Habib, T. Xu, and D. Barber. Variational f-divergence minimization. arXiv preprint arXiv:1907.11891, 2019.
  • Zhang et al. (2020) M. Zhang, P. Hayes, T. Bird, R. Habib, and D. Barber. Spread divergence. In International Conference on Machine Learning, pages 11106–11116. PMLR, 2020.
  • Zhang et al. (2022a) M. Zhang, P. Hayes, and D. Barber. Generalization gap in amortized inference. Advances in Neural Information Processing Systems, 2022a.
  • Zhang et al. (2022b) M. Zhang, O. Key, P. Hayes, D. Barber, B. Paige, and F.-X. Briol. Towards healing the blindness of score matching. arXiv preprint arXiv:2209.07396, 2022b.

Appendix A Proofs and Derivations

A.1 Proof of Theorem 2.1

The existence is straightforward: since $\mathrm{FD}(\tilde{p}_d\|\tilde{q}_{\theta^*})=0 \Rightarrow \tilde{p}_d=\tilde{q}_{\theta^*}$, we can simply let $q(x)=p_d(x)$, which gives $\int q(x)p(\tilde{x}|x)\,\mathrm{d}x=\int p_d(x)p(\tilde{x}|x)\,\mathrm{d}x=\tilde{p}_d$. To show uniqueness, we denote the noise density by $k(\epsilon)=\mathcal{N}(0,\sigma^2 I)$, so that $\tilde{q}_\theta(\tilde{x})$ and $\tilde{p}_d(\tilde{x})$ can be written as convolutions

$$\tilde{q}_\theta(\tilde{x})=q*k,\quad \tilde{p}_d(\tilde{x})=p_d*k, \tag{19}$$

we then have

$$\tilde{p}_d=\tilde{q}_\theta \;\Leftrightarrow\; q*k=p_d*k \;\Leftrightarrow\; \mathcal{F}(q)\,\mathcal{F}(k)=\mathcal{F}(p_d)\,\mathcal{F}(k), \tag{20}$$

where $\mathcal{F}$ denotes the Fourier transform. Since the Fourier transform of a Gaussian density is again a Gaussian function, $\mathcal{F}(k)>0$ everywhere, and we have

$$\tilde{p}_d=\tilde{q}_{\theta^*} \;\Leftrightarrow\; \mathcal{F}(q)\cancel{\mathcal{F}(k)}=\mathcal{F}(p_d)\cancel{\mathcal{F}(k)} \;\Leftrightarrow\; \mathcal{F}(q)=\mathcal{F}(p_d) \;\Leftrightarrow\; q=p_d. \tag{21}$$

Therefore, $q=p_d$ is the unique distribution that makes $\tilde{p}_d=\tilde{q}_\theta$. This technique has also been used to construct the spread KL divergence (denoted $\widetilde{\mathrm{KL}}$) Zhang et al. [2020], defined as $\widetilde{\mathrm{KL}}(p_d\|q_\theta)\equiv\mathrm{KL}(p_d*k\|q_\theta*k)$ with $k(\epsilon)=\mathcal{N}(0,\sigma^2 I)$, for training implicit models $q_\theta$.
Unlike the DSM setting, when $\widetilde{\mathrm{KL}}(p_d\|q_\theta)=0$ the underlying model $q_\theta=p_d$ is directly available, whereas the EBM $\tilde{q}_\theta$ trained by DSM learns the noisy distribution $\tilde{q}_\theta=p_d*k$.
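As a small illustration of the spread KL divergence, the sketch below assumes one-dimensional Gaussians, chosen only because convolving $\mathcal{N}(\mu,s^2)$ with $k=\mathcal{N}(0,\sigma^2)$ gives $\mathcal{N}(\mu,s^2+\sigma^2)$ in closed form, so $\widetilde{\mathrm{KL}}$ reduces to an ordinary Gaussian KL with inflated variances. The function names are ours. A point mass ($s=0$) is included to show that the spread KL stays finite where the ordinary KL from a point mass to a Gaussian is infinite.

```python
import math


def gaussian_kl(mu_p: float, var_p: float, mu_q: float, var_q: float) -> float:
    """KL(N(mu_p, var_p) || N(mu_q, var_q)) for 1D Gaussians."""
    return 0.5 * (math.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)


def spread_kl(mu_p: float, s_p: float, mu_q: float, s_q: float, sigma: float) -> float:
    """Spread KL: KL(p * k || q * k) with k = N(0, sigma^2).

    Convolving N(mu, s^2) with k yields N(mu, s^2 + sigma^2), so the spread KL
    is an ordinary Gaussian KL with variances inflated by sigma^2.
    """
    return gaussian_kl(mu_p, s_p**2 + sigma**2, mu_q, s_q**2 + sigma**2)


sigma = 0.5
print(spread_kl(0.0, 1.0, 0.0, 1.0, sigma))  # identical distributions
print(spread_kl(0.0, 0.0, 1.0, 1.0, sigma))  # point mass at 0 vs N(1, 1): finite
```

The spread KL is zero only when the two noised distributions coincide, which by Equation (21) means the clean distributions coincide as well.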

A.2 General Conditions Characterising the Existence of the Clean Model

In the previous section, we assumed that, for a flexible neural-network parameterization $f_\theta$, the energy-based model $\tilde{q}_\theta(\tilde{x})=\exp(-f_\theta(\tilde{x}))/Z(\theta)$ trained by Equation 4 can recover the target noisy data distribution, $\tilde{q}_{\theta^*}=\tilde{p}_d$, so that there exists an underlying model $q$ with $\tilde{q}_{\theta^*}=q*k$ and $q=p_d$. This assumption is common in the literature on score-based methods. For example, in score-based diffusion models Song and Ermon [2019], Ho et al. [2020], Bao et al. [2022], for data $x\in\mathbb{R}^D$ the score function $\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})$ is usually parameterized directly by a neural network $\mathrm{NN}_\theta(\cdot):\mathbb{R}^D\rightarrow\mathbb{R}^D$. However, this parameterization cannot guarantee that $\mathrm{NN}_\theta(\tilde{x})$ is a conservative vector field; in other words, there may not exist a distribution $\tilde{q}_\theta(\tilde{x})$ such that $\mathrm{NN}_\theta(\tilde{x})=\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})$, because such a field would need a symmetric Jacobian $\nabla_{\tilde{x}}\mathrm{NN}_\theta(\tilde{x})=\nabla^2_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})$, which a generic network does not have Salimans and Ho [2021], Saremi [2019]. Therefore, perfect score estimation, $\nabla_{\tilde{x}}\log\tilde{p}_d(\tilde{x})=\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})$, is implicitly assumed in order to allow an EBM interpretation.
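A quick way to see the conservativeness issue: a smooth field can equal $\nabla_{\tilde{x}}\log\tilde{q}(\tilde{x})$ only if its Jacobian is symmetric. The sketch below compares a gradient field (the score of a zero-mean Gaussian) with a field that adds a rotation, using central finite differences. Both fields are toy stand-ins for $\mathrm{NN}_\theta$ that we chose for illustration, not networks from the paper.

```python
import numpy as np


def jacobian_fd(field, x, eps=1e-5):
    """Central finite-difference Jacobian J[i, j] = dF_i / dx_j at x."""
    x = np.asarray(x, dtype=float)
    d = x.size
    jac = np.zeros((d, d))
    for j in range(d):
        step = np.zeros(d)
        step[j] = eps
        jac[:, j] = (field(x + step) - field(x - step)) / (2 * eps)
    return jac


def gradient_field(x):
    # Score of a zero-mean Gaussian with this precision matrix: conservative.
    precision = np.array([[2.0, 0.5], [0.5, 1.0]])
    return -precision @ x


def rotational_field(x):
    # -x plus a 90-degree rotation of x; the rotation part makes the
    # Jacobian non-symmetric, so this is not the gradient of any log-density.
    return np.array([-x[0] - x[1], x[0] - x[1]])


x0 = np.array([0.3, -0.7])
for name, field in [("gradient", gradient_field), ("rotational", rotational_field)]:
    jac = jacobian_fd(field, x0)
    asym = np.max(np.abs(jac - jac.T))
    print(f"{name}: max |J - J^T| = {asym:.2e}")
```

A symmetric Jacobian is necessary for a score interpretation; checking it on a trained network this way is only a local diagnostic at the chosen points.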

However, the underlying clean model does not always exist for an imperfect model $\tilde{q}_\theta\neq\tilde{p}_d$. Here we provide necessary and sufficient conditions that guarantee the existence of the underlying clean model.

Theorem A.1 (Necessary and sufficient conditions for the existence of the underlying clean model).

For a model $\tilde{q}_\theta$ with convolutional noise distribution $k(\epsilon)=\mathcal{N}(0,\sigma^2 I)$, there exists an underlying model $q$ such that $q*k=\tilde{q}_\theta$ if and only if $\mathcal{F}(\tilde{q}_\theta)/\mathcal{F}(k)$ is positive semi-definite [Footnote 6: A continuous function $f:\mathbb{R}^d\rightarrow\mathbb{C}$ is positive semi-definite if, for all $N\in\mathbb{N}$, all sets of pairwise distinct centers $X=\{x_1,\dots,x_N\}\subset\mathbb{R}^d$, and all $\alpha\in\mathbb{C}^N$, $\sum_{i=1}^{N}\sum_{j=1}^{N}\alpha_i\overline{\alpha_j}f(x_i-x_j)\geq 0$; see Wendland [2004, Definition 6.1].]. Additionally, the underlying distribution $q$ can be written as

$$q=\mathcal{F}^{-1}\big(\mathcal{F}(\tilde{q}_\theta)/\mathcal{F}(k)\big), \tag{22}$$

where $\mathcal{F}^{-1}$ is the inverse Fourier transform. This theorem is a straightforward corollary of Bochner's theorem [Footnote 7: Bochner's theorem (Wendland [2004, Theorem 6.6]): a continuous function $f:\mathbb{R}^d\rightarrow\mathbb{C}$ is positive semi-definite if and only if it is the Fourier transform of a finite non-negative Borel measure on $\mathbb{R}^d$.]. However, for the energy model $\tilde{q}_\theta(\tilde{x})\propto\exp(-f_\theta(\tilde{x}))$, it is difficult to design a practical family of functions $f_\theta$ that satisfies the positive semi-definite condition and has a tractable score function at the same time [Footnote 8: For example, one can define a noisy energy-based model $\tilde{q}_\theta(\tilde{x})\propto\exp(-f_\theta(\tilde{x}))$ with $-f_\theta(\tilde{x})=\log\int\exp\big(-g_\theta(x)-\tfrac{1}{2\sigma^2}\|\tilde{x}-x\|_2^2\big)\,\mathrm{d}x$, which always admits an underlying clean energy-based model $q_\theta(x)\propto\exp(-g_\theta(x))$ such that $\tilde{q}_\theta=q_\theta*k$ with $k(\epsilon)=\mathcal{N}(0,\sigma^2 I)$. However, the score function $\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})=-\nabla_{\tilde{x}}f_\theta(\tilde{x})$ is intractable in this case, since it involves the integral over $x$.]. We thus leave the design of better energy function parameterizations as a promising future direction.
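Equation (22) can be illustrated numerically in one dimension. The sketch assumes a clean Gaussian-mixture density $q$ (our choice, so that $\tilde{q}=q*k$ is available in closed form), evaluates $\tilde{q}$ on a grid, divides its discrete Fourier transform by the analytic $\mathcal{F}(k)(\omega)=\exp(-\sigma^2\omega^2/2)$, and inverts. Because $\mathcal{F}(k)$ decays rapidly, the division amplifies round-off at high frequencies, so the sketch drops frequencies where $\mathcal{F}(k)$ falls below a cutoff; this regularization is our addition and is not part of the theorem. The last lines check the necessary condition $|f(\omega)|\le f(0)$ of positive semi-definite functions for a noisy Gaussian narrower than $k$.

```python
import numpy as np


def mixture_pdf(x, means, stds, weights):
    """Density of a 1D Gaussian mixture evaluated at the points x."""
    x = np.asarray(x, dtype=float)[:, None]
    comps = np.exp(-0.5 * ((x - means) / stds) ** 2) / (stds * np.sqrt(2 * np.pi))
    return comps @ weights


def deconvolve(q_tilde_vals, dx, sigma, cutoff=1e-8):
    """Approximate q = F^{-1}(F(q_tilde) / F(k)) for k = N(0, sigma^2) on a periodic grid.

    Frequencies where F(k) < cutoff are dropped (our regularization), since
    dividing by a tiny F(k) amplifies floating-point round-off.
    """
    n = q_tilde_vals.size
    omega = 2 * np.pi * np.fft.fftfreq(n, d=dx)  # angular frequency of each DFT bin
    f_k = np.exp(-0.5 * sigma**2 * omega**2)  # analytic Fourier transform of k
    spectrum = np.fft.fft(q_tilde_vals)
    ratio = np.zeros_like(spectrum)
    keep = f_k > cutoff
    ratio[keep] = spectrum[keep] / f_k[keep]
    return np.fft.ifft(ratio).real


sigma = 0.5
means = np.array([-1.5, 1.5])
stds = np.array([0.5, 0.5])
weights = np.array([0.5, 0.5])
grid = np.linspace(-8.0, 8.0, 256, endpoint=False)
dx = grid[1] - grid[0]

q_clean = mixture_pdf(grid, means, stds, weights)
# Convolving each component with k = N(0, sigma^2) adds sigma^2 to its variance.
q_noisy = mixture_pdf(grid, means, np.sqrt(stds**2 + sigma**2), weights)
q_recovered = deconvolve(q_noisy, dx, sigma)
print("max |q_recovered - q_clean| =", np.max(np.abs(q_recovered - q_clean)))

# For q_tilde = N(0, v) with v < sigma^2, F(q_tilde)/F(k) = exp((sigma^2 - v) w^2 / 2)
# exceeds its value at w = 0, violating |f(w)| <= f(0), so no clean q exists.
v = 0.1
w = np.linspace(-5.0, 5.0, 11)
ratio_bad = np.exp(0.5 * (sigma**2 - v) * w**2)
print("PSD condition violated:", bool(np.any(np.abs(ratio_bad) > ratio_bad[5])))
```

In higher dimensions the same division is ill-conditioned and the grid grows exponentially, which is consistent with the text's point that Equation (22) characterizes existence rather than offering a practical sampler.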

A.3 Proof of Theorem 2.2

Derivation of the Mean Identity

We let $\tilde{q}_\theta(\tilde{x})=\int k(\tilde{x}|x)\,q_\theta(x)\,\mathrm{d}x$, where $k(\tilde{x}|x)=\mathcal{N}(\tilde{x};x,\sigma^2 I)$. We then have

$$
\begin{aligned}
\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x}) &= \frac{\nabla_{\tilde{x}}\tilde{q}_\theta(\tilde{x})}{\tilde{q}_\theta(\tilde{x})} = \frac{\int\nabla_{\tilde{x}}k(\tilde{x}|x)\,q_\theta(x)\,\mathrm{d}x}{\tilde{q}_\theta(\tilde{x})}\\
&= -\frac{1}{\sigma^2}\int(\tilde{x}-x)\,\frac{k(\tilde{x}|x)\,q_\theta(x)}{\tilde{q}_\theta(\tilde{x})}\,\mathrm{d}x\\
\Longrightarrow\quad \sigma^2\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})+\tilde{x} &= \int x\,\frac{k(\tilde{x}|x)\,q_\theta(x)}{\tilde{q}_\theta(\tilde{x})}\,\mathrm{d}x = \langle x\rangle_{q_\theta(x|\tilde{x})},
\end{aligned}
$$

where we define the model denoising posterior via Bayes' rule, $q_\theta(x|\tilde{x})\equiv k(\tilde{x}|x)\,q_\theta(x)/\tilde{q}_\theta(\tilde{x})$. The step from the first to the second line uses the following property of the Gaussian kernel (written in one dimension; the isotropic $D$-dimensional case holds coordinate-wise):

$$\nabla_{\tilde{x}}k(\tilde{x}|x)=\frac{1}{\sqrt{2\pi\sigma^2}}\nabla_{\tilde{x}}e^{-\frac{(\tilde{x}-x)^2}{2\sigma^2}}=-\frac{\tilde{x}-x}{\sigma^2}\,\frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(\tilde{x}-x)^2}{2\sigma^2}}=-\frac{\tilde{x}-x}{\sigma^2}\,k(\tilde{x}|x). \tag{23}$$
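The mean identity can be checked numerically. The sketch below assumes a one-dimensional Gaussian-mixture clean model $q_\theta$ (our choice, because then both $\tilde{q}_\theta=q_\theta*k$ and the posterior $q_\theta(x|\tilde{x})$ are available in closed form), estimates the score $\nabla_{\tilde{x}}\log\tilde{q}_\theta$ by central finite differences, and compares $\sigma^2\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})+\tilde{x}$ with the exact posterior mean. The constants and function names are illustrative.

```python
import numpy as np

SIGMA = 0.7  # noise level of k(x_tilde | x) = N(x, SIGMA^2)
WEIGHTS = np.array([0.3, 0.7])
MEANS = np.array([-2.0, 1.0])
STDS = np.array([0.5, 0.8])


def log_noisy_density(x_tilde):
    """log q_tilde(x_tilde), where q_tilde is the mixture convolved with k."""
    var = STDS**2 + SIGMA**2  # convolution adds SIGMA^2 to each component variance
    comps = WEIGHTS * np.exp(-0.5 * (x_tilde - MEANS) ** 2 / var) / np.sqrt(2 * np.pi * var)
    return np.log(comps.sum())


def posterior_mean(x_tilde):
    """Exact E[x | x_tilde] under the mixture prior and Gaussian likelihood k."""
    var = STDS**2 + SIGMA**2
    # Posterior component weights are proportional to w_i N(x_tilde; mu_i, s_i^2 + sigma^2).
    resp = WEIGHTS * np.exp(-0.5 * (x_tilde - MEANS) ** 2 / var) / np.sqrt(var)
    resp = resp / resp.sum()
    # Each component's posterior mean follows from Gaussian conjugacy.
    comp_means = (SIGMA**2 * MEANS + STDS**2 * x_tilde) / var
    return float(resp @ comp_means)


def tweedie_mean(x_tilde, eps=1e-5):
    """SIGMA^2 * d/dx_tilde log q_tilde(x_tilde) + x_tilde, with a central-difference score."""
    score = (log_noisy_density(x_tilde + eps) - log_noisy_density(x_tilde - eps)) / (2 * eps)
    return SIGMA**2 * score + x_tilde


for x_tilde in [-3.0, -0.5, 0.0, 2.0]:
    print(f"x_tilde={x_tilde:+.1f}  exact={posterior_mean(x_tilde):.6f}  identity={tweedie_mean(x_tilde):.6f}")
```

In practice the finite-difference score would be replaced by the learned score $-\nabla_{\tilde{x}}f_\theta(\tilde{x})$; the mixture here only serves as a model with a known answer.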

Derivation of the Analytical Full Covariance Identity

We have derived the mean identity

$$\mu_q(\tilde{x})\equiv\langle x\rangle_{q_\theta(x|\tilde{x})}=\sigma^2\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})+\tilde{x}. \tag{24}$$

Taking the gradient with respect to $\tilde{x}$ on both sides and multiplying by $\sigma^2$, we have

$$\sigma^2\nabla_{\tilde{x}}\mu_q(\tilde{x})=\sigma^4\nabla^2_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})+\sigma^2 I. \tag{25}$$
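Equation (25) can be sanity-checked in a setting where every term is available in closed form. The sketch below assumes a Gaussian clean model $q_\theta=\mathcal{N}(m,S)$ with a full covariance $S$ (our illustrative choice; $m$ drops out of every quantity computed), so that $\tilde{q}_\theta=\mathcal{N}(m,S+\sigma^2 I)$. It compares $\sigma^2\nabla_{\tilde{x}}\mu_q(\tilde{x})$ with $\sigma^4\nabla^2_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})+\sigma^2 I$, and both with the exact covariance of the Gaussian posterior $q_\theta(x|\tilde{x})$, the quantity the full covariance identity targets.

```python
import numpy as np

sigma = 0.6
S = np.array([[1.0, 0.8], [0.8, 2.0]])  # full covariance of the clean model N(m, S)
eye = np.eye(2)
noisy_cov = S + sigma**2 * eye  # q_tilde = q * k = N(m, S + sigma^2 I)

# Right-hand side of Eq. (25): sigma^4 * Hessian of log q_tilde + sigma^2 I.
# For a Gaussian, the Hessian of log q_tilde is -(S + sigma^2 I)^{-1} everywhere.
hessian = -np.linalg.inv(noisy_cov)
rhs = sigma**4 * hessian + sigma**2 * eye

# Left-hand side: the posterior mean is m + S (S + sigma^2 I)^{-1} (x_tilde - m),
# so its Jacobian with respect to x_tilde is S (S + sigma^2 I)^{-1}.
lhs = sigma**2 * S @ np.linalg.inv(noisy_cov)

# Exact covariance of q(x | x_tilde) from Gaussian conjugacy.
post_cov = np.linalg.inv(np.linalg.inv(S) + eye / sigma**2)

print(np.allclose(lhs, rhs), np.allclose(rhs, post_cov))
```

The agreement between `rhs` and `post_cov` is the Woodbury identity $(S^{-1}+\sigma^{-2}I)^{-1}=\sigma^2 I-\sigma^4(S+\sigma^2 I)^{-1}$; for non-Gaussian models the Hessian varies with $\tilde{x}$ and must come from the trained energy.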

We can also expand the Hessian of $\log\tilde{q}_\theta(\tilde{x})$:

$$
\begin{aligned}
\nabla^2_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x}) &= -\frac{1}{\sigma^2}\int\nabla_{\tilde{x}}\left((\tilde{x}-x)\,\frac{k(\tilde{x}|x)\,q_\theta(x)}{\tilde{q}_\theta(\tilde{x})}\right)\mathrm{d}x\\
&= -\frac{1}{\sigma^2}\int\frac{k(\tilde{x}|x)\,q_\theta(x)}{\tilde{q}_\theta(\tilde{x})}\,\mathrm{d}x\; I - \frac{1}{\sigma^2}\int(\tilde{x}-x)\,\frac{\big(\nabla_{\tilde{x}}k(\tilde{x}|x)\,\tilde{q}_\theta(\tilde{x})-\nabla_{\tilde{x}}\tilde{q}_\theta(\tilde{x})\,k(\tilde{x}|x)\big)^{\top}q_\theta(x)}{\tilde{q}^2_\theta(\tilde{x})}\,\mathrm{d}x\\
\Longrightarrow\quad \sigma^2\nabla^2_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})+I &= -\int(\tilde{x}-x)\,\frac{\big(\nabla_{\tilde{x}}k(\tilde{x}|x)-\nabla_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})\,k(\tilde{x}|x)\big)^{\top}q_\theta(x)}{\tilde{q}_\theta(\tilde{x})}\,\mathrm{d}x\\
&= -\int(\tilde{x}-x)\,\frac{\big(-\frac{1}{\sigma^2}(\tilde{x}-x)+\frac{1}{\sigma^2}(\tilde{x}-\langle x\rangle_{q_\theta(x|\tilde{x})})\big)^{\top}k(\tilde{x}|x)\,q_\theta(x)}{\tilde{q}_\theta(\tilde{x})}\,\mathrm{d}x
\end{aligned}
$$
σ4x~2logq~θ(x~)+σ2Iabsentsuperscript𝜎4superscriptsubscript~𝑥2subscript~𝑞𝜃~𝑥superscript𝜎2𝐼\displaystyle\Longrightarrow\sigma^{4}\nabla_{\tilde{x}}^{2}\log\tilde{q}_{% \theta}(\tilde{x})+\sigma^{2}I⟹ italic_σ start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT over~ start_ARG italic_x end_ARG end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT roman_log over~ start_ARG italic_q end_ARG start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( over~ start_ARG italic_x end_ARG ) + italic_σ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_I =((x~x)2+(x~x)(x~xqθ(x|x~)))qθ(x|x~)dxabsentsuperscript~𝑥𝑥2~𝑥𝑥~𝑥subscriptdelimited-⟨⟩𝑥subscript𝑞𝜃conditional𝑥~𝑥subscript𝑞𝜃conditional𝑥~𝑥differential-d𝑥\displaystyle=\int\left(-(\tilde{x}-x)^{2}+(\tilde{x}-x)(\tilde{x}-\langle x% \rangle_{q_{\theta}(x|\tilde{x})})\right)q_{\theta}(x|\tilde{x})\mathop{}\!% \mathrm{d}{x}= ∫ ( - ( over~ start_ARG italic_x end_ARG - italic_x ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT + ( over~ start_ARG italic_x end_ARG - italic_x ) ( over~ start_ARG italic_x end_ARG - ⟨ italic_x ⟩ start_POSTSUBSCRIPT italic_q start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( italic_x | over~ start_ARG italic_x end_ARG ) end_POSTSUBSCRIPT ) ) italic_q start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( italic_x | over~ start_ARG italic_x end_ARG ) roman_d italic_x
=x2qθ(x|x~)xqθ(x|x~)2Σq(x~)absentsubscriptdelimited-⟨⟩superscript𝑥2subscript𝑞𝜃conditional𝑥~𝑥subscriptsuperscriptdelimited-⟨⟩𝑥2subscript𝑞𝜃conditional𝑥~𝑥subscriptΣ𝑞~𝑥\displaystyle=\langle x^{2}\rangle_{q_{\theta}(x|\tilde{x})}-\langle x\rangle^% {2}_{q_{\theta}(x|\tilde{x})}\equiv\Sigma_{q}(\tilde{x})= ⟨ italic_x start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ⟩ start_POSTSUBSCRIPT italic_q start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( italic_x | over~ start_ARG italic_x end_ARG ) end_POSTSUBSCRIPT - ⟨ italic_x ⟩ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_q start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( italic_x | over~ start_ARG italic_x end_ARG ) end_POSTSUBSCRIPT ≡ roman_Σ start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ( over~ start_ARG italic_x end_ARG )

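Here we used $\nabla_{\tilde{x}}k(\tilde{x}|x)=-\frac{1}{\sigma^{2}}(\tilde{x}-x)k(\tilde{x}|x)$ for the Gaussian kernel and $\nabla_{\tilde{x}}\log\tilde{q}_{\theta}(\tilde{x})=-\frac{1}{\sigma^{2}}(\tilde{x}-\mu_{q}(\tilde{x}))$ with $\mu_{q}(\tilde{x})\equiv\langle x\rangle_{q_{\theta}(x|\tilde{x})}$. Differentiating the latter once more gives $\sigma^{4}\nabla^{2}_{\tilde{x}}\log\tilde{q}_{\theta}(\tilde{x})+\sigma^{2}I=\sigma^{2}\nabla_{\tilde{x}}\mu_{q}(\tilde{x})$, so we obtain the analytical full covariance identity: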

$$\Sigma_{q}(\tilde{x})=\sigma^{2}\nabla_{\tilde{x}}\mu_{q}(\tilde{x}).\tag{26}$$
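As a quick sanity check of Equation 26, the following Python sketch (an illustration under assumed settings, not the paper's code) takes a hypothetical 1D two-component Gaussian mixture as the clean model $q_\theta(x)$ and the Gaussian kernel $k(\tilde{x}|x)=\mathcal{N}(\tilde{x};x,\sigma^2)$, and computes the posterior variance three ways: directly by quadrature, as $\sigma^2\,\partial_{\tilde{x}}\mu_q(\tilde{x})$, and as $\sigma^4\partial^2_{\tilde{x}}\log\tilde{q}_\theta(\tilde{x})+\sigma^2$. The mixture parameters, the grid, and all helper names are assumptions chosen for illustration.

```python
import math


def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


# Hypothetical clean model q_theta(x): a 1D two-component Gaussian mixture.
WEIGHTS, MEANS, STDS = (0.3, 0.7), (-2.0, 1.5), (0.5, 0.8)
DX = 0.005
GRID = [-12.0 + DX * i for i in range(4801)]  # integration grid over x in [-12, 12]


def q_theta(x):
    return sum(w * normal_pdf(x, m, s) for w, m, s in zip(WEIGHTS, MEANS, STDS))


Q_ON_GRID = [q_theta(x) for x in GRID]  # precomputed prior values on the grid


def noisy_posterior(xt, sigma):
    """Return (q~_theta(xt), posterior mean, posterior variance) of the posterior
    q_theta(x | xt), proportional to N(xt; x, sigma^2) q_theta(x), via Riemann sums."""
    w = [normal_pdf(xt, x, sigma) * q for x, q in zip(GRID, Q_ON_GRID)]
    z = sum(w) * DX
    mean = sum(wi * x for wi, x in zip(w, GRID)) * DX / z
    second = sum(wi * x * x for wi, x in zip(w, GRID)) * DX / z
    return z, mean, second - mean ** 2


def covariance_three_ways(xt, sigma, h=1e-3):
    """Posterior variance (1) directly, (2) as sigma^2 * d mu_q / d xt (Eq. 26),
    and (3) as sigma^4 * d^2 log q~ / d xt^2 + sigma^2, using central differences."""
    z0, _, var = noisy_posterior(xt, sigma)
    z_plus, mu_plus, _ = noisy_posterior(xt + h, sigma)
    z_minus, mu_minus, _ = noisy_posterior(xt - h, sigma)
    from_mean = sigma ** 2 * (mu_plus - mu_minus) / (2.0 * h)
    log_hess = (math.log(z_plus) - 2.0 * math.log(z0) + math.log(z_minus)) / h ** 2
    from_hess = sigma ** 4 * log_hess + sigma ** 2
    return var, from_mean, from_hess


for xt in (-1.0, 0.3, 2.0):
    print(xt, covariance_three_ways(xt, sigma=0.5))
```

The derivation only uses the Gaussian form of the kernel, so the identity also holds exactly for the discretized prior on the grid; the three numbers should therefore agree up to finite-difference error.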

A.4 Proof of Theorem 2.3

Lemma A.2 (KL to Gaussian Bao et al. [2022]).

Let $p(x)$ be a distribution with mean $\mu_{p}$ and covariance $\Sigma_{p}$, and let $q(x)=\mathcal{N}(\mu_{q},\Sigma_{q})$. Denoting the differential entropy by $\mathrm{H}(p)\equiv-\int p(x)\log p(x)\,\mathrm{d}x$, we have

$$\mathrm{KL}(p\,\|\,q)=\mathrm{KL}(\mathcal{N}(\mu_{p},\Sigma_{p})\,\|\,q)+\mathrm{H}(\mathcal{N}(\mu_{p},\Sigma_{p}))-\mathrm{H}(p).\tag{27}$$
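Lemma A.2 can be checked numerically in one dimension. The sketch below assumes a hypothetical bimodal mixture for $p$ and an arbitrary Gaussian $q$ (both chosen only for illustration), evaluates $\mathrm{KL}(p\|q)$ and $\mathrm{H}(p)$ by quadrature, and uses the closed-form 1D Gaussian KL and entropy for the remaining terms of Equation 27.

```python
import math


def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def kl_gaussians(m1, s1, m2, s2):
    """Closed-form KL(N(m1, s1^2) || N(m2, s2^2)) in 1D."""
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2.0 * s2 ** 2) - 0.5


def gaussian_entropy(s):
    """Differential entropy of a 1D Gaussian with standard deviation s."""
    return 0.5 * math.log(2.0 * math.pi * math.e * s ** 2)


def p(x):
    """Hypothetical non-Gaussian p: a two-component mixture."""
    return 0.4 * normal_pdf(x, -1.0, 0.6) + 0.6 * normal_pdf(x, 2.0, 0.9)


Q_MEAN, Q_STD = 0.5, 1.5  # hypothetical Gaussian q
DX = 0.002
GRID = [-15.0 + DX * i for i in range(15001)]
P = [p(x) for x in GRID]

# Moments of p and the two quadrature-based terms
mu_p = sum(pi * x for pi, x in zip(P, GRID)) * DX
var_p = sum(pi * (x - mu_p) ** 2 for pi, x in zip(P, GRID)) * DX
kl_pq = sum(pi * (math.log(pi) - math.log(normal_pdf(x, Q_MEAN, Q_STD)))
            for pi, x in zip(P, GRID) if pi > 0.0) * DX
entropy_p = -sum(pi * math.log(pi) for pi in P if pi > 0.0) * DX

lhs = kl_pq
rhs = (kl_gaussians(mu_p, math.sqrt(var_p), Q_MEAN, Q_STD)
       + gaussian_entropy(math.sqrt(var_p)) - entropy_p)
print(lhs, rhs)
```

Since a Gaussian has maximal entropy among distributions with a given covariance, $\mathrm{H}(\mathcal{N}(\mu_p,\Sigma_p))\ge\mathrm{H}(p)$, so the lemma also implies $\mathrm{KL}(p\|q)\ge\mathrm{KL}(\mathcal{N}(\mu_p,\Sigma_p)\|q)$.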

The proof can be found in Lemma 2 of Bao et al. [2022].

We can then prove Theorem 2.3. Since $p(\tilde{x}|x)p_{d}(x)=p(x|\tilde{x})\tilde{p}_{d}(\tilde{x})$, where $\tilde{p}_{d}(\tilde{x})=\int p_{d}(x)p(\tilde{x}|x)\,\mathrm{d}x$, we have

$$\mathrm{KL}\big(p(\tilde{x}|x)p_{d}(x)\,\|\,q(x|\tilde{x})\tilde{p}_{d}(\tilde{x})\big)=\Big\langle\mathrm{KL}\big(p(x|\tilde{x})\,\|\,q(x|\tilde{x})\big)\Big\rangle_{\tilde{p}_{d}(\tilde{x})}\tag{28}$$

Assume a Gaussian $q(x|\tilde{x})=\mathcal{N}(\mu_{q}(\tilde{x}),\Sigma_{q}(\tilde{x}))$ and denote the mean and covariance of the true posterior $p(x|\tilde{x})$ by $\mu_{p}(\tilde{x})$ and $\Sigma_{p}(\tilde{x})$. Then the optimal $q^{*}$ is

\begin{align}
q^{*}&=\arg\min_{q}\mathrm{KL}\big(p(\tilde{x}|x)p_{d}(x)\,\|\,q(x|\tilde{x})\tilde{p}_{d}(\tilde{x})\big)\tag{29}\\
&=\arg\min_{q}\Big\langle\mathrm{KL}\big(p(x|\tilde{x})\,\|\,q(x|\tilde{x})\big)\Big\rangle_{\tilde{p}_{d}(\tilde{x})}\tag{30}\\
&=\arg\min_{q}\Big\langle\mathrm{KL}\big(\mathcal{N}(\mu_{p},\Sigma_{p})\,\|\,q(x|\tilde{x})\big)+\mathrm{H}(\mathcal{N}(\mu_{p},\Sigma_{p}))-\mathrm{H}(p(x|\tilde{x}))\Big\rangle_{\tilde{p}_{d}(\tilde{x})}\tag{31}\\
&=\arg\min_{q}\Big(\Big\langle\mathrm{KL}\big(\mathcal{N}(\mu_{p},\Sigma_{p})\,\|\,q(x|\tilde{x})\big)\Big\rangle_{\tilde{p}_{d}(\tilde{x})}+\text{const}\Big).\tag{32}
\end{align}

The constant collects the entropy terms in (31), which do not depend on $q$. Since the KL divergence between two Gaussians is non-negative and vanishes if and only if their means and covariances coincide, the expectation in (32) is minimized by matching moments at every $\tilde{x}$. Therefore, the optimal $q(x|\tilde{x})=\mathcal{N}(\mu_{q}(\tilde{x}),\Sigma_{q}(\tilde{x}))$ under the joint KL has mean $\mu^{*}_{q}(\tilde{x})=\mu_{p}(\tilde{x})$ and covariance $\Sigma^{*}_{q}(\tilde{x})=\Sigma_{p}(\tilde{x})$.
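For reference, the moment-matching step relies on the standard closed form of the KL divergence between two $d$-dimensional Gaussians, written here for a fixed $\tilde{x}$:

$$\mathrm{KL}\big(\mathcal{N}(\mu_{p},\Sigma_{p})\,\|\,\mathcal{N}(\mu_{q},\Sigma_{q})\big)=\frac{1}{2}\Big[\mathrm{Tr}\big(\Sigma_{q}^{-1}\Sigma_{p}\big)+(\mu_{q}-\mu_{p})^{\top}\Sigma_{q}^{-1}(\mu_{q}-\mu_{p})-d+\log\frac{\det\Sigma_{q}}{\det\Sigma_{p}}\Big].$$

The mean enters only through the non-negative quadratic term, which vanishes at $\mu_{q}=\mu_{p}$; setting the gradient with respect to $\Sigma_{q}$ to zero gives $\Sigma_{q}^{-1}-\Sigma_{q}^{-1}\Sigma_{p}\Sigma_{q}^{-1}=0$, i.e. $\Sigma_{q}=\Sigma_{p}$.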

Appendix B Connection to Analytical DDPM

Bao et al. [2022] consider the constrained variational family $q_{\theta}(x|\tilde{x})=\mathcal{N}(\mu_{\theta}(\tilde{x}),\sigma_{q}^{2}I)$ and derive the optimal $\sigma_{q}^{*}$ as

$$\sigma_{q}^{*2}=\arg\min_{\sigma_{q}}\mathrm{KL}\big(p(\tilde{x}|x)p_{d}(x)\,\|\,q_{\theta}(x|\tilde{x})\tilde{p}_{d}(\tilde{x})\big)=\frac{1}{d}\Big\langle\mathrm{Tr}\big(\mathrm{Cov}_{q(x|\tilde{x})}[x]\big)\Big\rangle_{\tilde{p}_{d}(\tilde{x})},\tag{33}$$

which can equivalently be written in terms of the score function:

$$\sigma_{q}^{*2}=\sigma^{2}-\frac{\sigma^{4}}{d}\Big\langle\big\lVert s_{q_{\theta}}(\tilde{x})\big\rVert_{2}^{2}\Big\rangle_{\tilde{p}_{d}(\tilde{x})}.\tag{34}$$

To make the connection explicit, we can plug our analytical full covariance (Equation 10) into Equation 16:

$$
\begin{aligned}
\sigma_{q}^{*2}&=\sigma^{2}+\frac{\sigma^{4}}{d}\mathrm{Tr}\Big\langle\nabla^{2}_{\tilde{x}}\log\tilde{q}_{\theta}(\tilde{x})\Big\rangle_{\tilde{p}_{d}(\tilde{x})}\\
&=\sigma^{2}-\frac{\sigma^{4}}{d}\mathrm{Tr}\Big\langle s_{q_{\theta}}(\tilde{x})s_{q_{\theta}}(\tilde{x})^{\top}\Big\rangle_{\tilde{p}_{d}(\tilde{x})}=\sigma^{2}-\frac{\sigma^{4}}{d}\Big\langle\big\lVert s_{q_{\theta}}(\tilde{x})\big\rVert_{2}^{2}\Big\rangle_{\tilde{p}_{d}(\tilde{x})},
\end{aligned}\tag{35}
$$

which recovers Equation 17. The step from the first to the second line uses the well-known Fisher information identity Fisher [1925], $\langle\nabla^{2}\log\tilde{q}\rangle_{\tilde{q}}=-\langle\nabla\log\tilde{q}\,\nabla\log\tilde{q}^{\top}\rangle_{\tilde{q}}$, which applies here when the model matches the noisy data distribution, $\tilde{q}_{\theta}=\tilde{p}_{d}$.
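The equivalence of Equations 33 and 34 can be illustrated in one dimension ($d=1$). The sketch below assumes a hypothetical two-component Gaussian mixture as the data distribution $p_d$, for which the noisy marginal $\tilde{p}_d$, its exact score, and the exact posterior variance are all available in closed form. It then compares $\langle\mathrm{Var}[x|\tilde{x}]\rangle_{\tilde{p}_d}$ with $\sigma^2-\sigma^4\langle s(\tilde{x})^2\rangle_{\tilde{p}_d}$ by quadrature over $\tilde{x}$. Using the exact score stands in for the idealized case $\tilde{q}_\theta=\tilde{p}_d$; all parameter values are illustrative.

```python
import math


def normal_pdf(x, mean, var):
    """Density of N(mean, var) at x (note: parameterized by the variance)."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


# Hypothetical 1D data distribution p_d: mixture weights, means, variances.
W, M, V = (0.5, 0.5), (-2.0, 2.0), (0.3, 0.6)
SIGMA = 0.7  # noise scale of p(x~ | x) = N(x~; x, SIGMA^2)


def noisy_quantities(xt, sigma=SIGMA):
    """Return p~_d(xt), its score s(xt), and Var[x | xt] under the exact posterior."""
    comps = [w * normal_pdf(xt, m, v + sigma ** 2) for w, m, v in zip(W, M, V)]
    dens = sum(comps)
    resp = [c / dens for c in comps]  # posterior component responsibilities
    score = sum(r * (-(xt - m) / (v + sigma ** 2)) for r, m, v in zip(resp, M, V))
    # Each mixture component has a Gaussian posterior with these moments
    post_means = [m + v / (v + sigma ** 2) * (xt - m) for m, v in zip(M, V)]
    post_vars = [v * sigma ** 2 / (v + sigma ** 2) for v in V]
    mean = sum(r * mu for r, mu in zip(resp, post_means))
    second = sum(r * (pv + mu ** 2) for r, pv, mu in zip(resp, post_vars, post_means))
    return dens, score, second - mean ** 2


DX = 0.01
GRID = [-15.0 + DX * i for i in range(3001)]
avg_var = 0.0       # <Var[x | x~]> under p~_d, i.e. Eq. (33) with d = 1
avg_sq_score = 0.0  # <s(x~)^2> under p~_d
for xt in GRID:
    dens, score, var = noisy_quantities(xt)
    avg_var += dens * var * DX
    avg_sq_score += dens * score ** 2 * DX

sigma_sq_cov = avg_var                                      # Eq. (33)
sigma_sq_score = SIGMA ** 2 - SIGMA ** 4 * avg_sq_score     # Eq. (34)
print(sigma_sq_cov, sigma_sq_score)
```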

Appendix C Experiments

All experiments in this paper were run on a single NVIDIA RTX 3090 GPU.

C.1 Effect of the Single Noise Choice on MNIST

Figure 11 shows samples generated by our method from EBMs trained with different noise scales $\sigma\in\{0.3,0.5,0.8\}$ in the noise distribution $p(\tilde{x}|x)$. Image quality also depends heavily on the choice of noise scale, and $\sigma=0.5$ achieves the best visual quality, so we use this value in the subsequent comparisons.

Figure 11: Sample comparisons with different $\sigma$ values. Panels: (a) $\sigma=0.3$, (b) $\sigma=0.5$, (c) $\sigma=0.8$.
Figure 12: Mode-collapse visualization for 25 Markov chains, with samples plotted every 20 Gibbs steps. Fewer modes are covered the longer Gibbs sampling is run.

C.2 Multi-level Noise Details

For full details on the architecture and noise schedule used in the multi-level noise experiments in Section 5, we refer to Appendix B of [Song and Ermon, 2020]. For our multi-level Gibbs sampling procedure, we used 3 Gibbs steps at each noise level and 3 Rademacher samples for each diagonal Hessian estimate. Following [Song and Ermon, 2020], we used a total of 232 noise levels distributed according to their proposed geometric schedule, and applied a final denoising step that returns the mean of the clean distribution conditioned on the final output of the sampling procedure (this output is a sample from the noisy distribution at the smallest noise level). This denoising step was previously found to significantly improve FID scores [Jolicoeur-Martineau et al., 2021].
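The procedure above can be sketched as follows. This is a toy reading under stated assumptions, not the authors' implementation. Each Gibbs step draws $\tilde{x}\sim\mathcal{N}(x,\sigma^2 I)$ and then $x\sim\mathcal{N}(\mu_q(\tilde{x}),\mathrm{diag}\,\Sigma_q(\tilde{x}))$, where $\mu_q(\tilde{x})=\tilde{x}+\sigma^2 s(\tilde{x})$ and the diagonal of $\Sigma_q=\sigma^2(I+\sigma^2\nabla_{\tilde{x}}s(\tilde{x}))$ is estimated from Rademacher probes $v$ via $\mathrm{diag}(J)\approx\frac{1}{M}\sum v\odot Jv$. The trained noise-conditional model is replaced by a hypothetical Gaussian toy (`toy_score_at`) whose noisy score is known exactly; Jacobian-vector products use central finite differences instead of automatic differentiation; and the schedule endpoints (50 and 0.01) are placeholders. Only the 232 levels, 3 Gibbs steps per level, 3 Rademacher samples, and the final posterior-mean denoising step come from the text.

```python
import math
import random


def geometric_sigmas(sigma_max, sigma_min, n_levels):
    """Noise levels decreasing geometrically from sigma_max to sigma_min (n_levels >= 2)."""
    ratio = (sigma_min / sigma_max) ** (1.0 / (n_levels - 1))
    return [sigma_max * ratio ** i for i in range(n_levels)]


def rademacher_diag(score_fn, x, n_samples, rng, eps=1e-3):
    """Estimate diag(J), J = Jacobian of score_fn at x (= Hessian of log q~),
    as the mean of v * (J v) over Rademacher v; J v via central differences."""
    est = [0.0] * len(x)
    for _ in range(n_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in x]
        s_plus = score_fn([xi + eps * vi for xi, vi in zip(x, v)])
        s_minus = score_fn([xi - eps * vi for xi, vi in zip(x, v)])
        for i, vi in enumerate(v):
            est[i] += vi * (s_plus[i] - s_minus[i]) / (2.0 * eps) / n_samples
    return est


def toy_score_at(mean, var):
    """Hypothetical stand-in for a trained noise-conditional model: clean data
    N(mean, diag(var)), so the noisy score is -(xt - mean) / (var + sigma^2)."""
    def score_at(sigma):
        return lambda xt: [-(a - m) / (v + sigma ** 2) for a, m, v in zip(xt, mean, var)]
    return score_at


def multilevel_gibbs(score_at, x0, sigmas, n_gibbs=3, n_rademacher=3, seed=0):
    """Annealed pseudo-Gibbs sampling with moment-matched Gaussian steps."""
    rng = random.Random(seed)
    x = list(x0)
    for sigma in sigmas:
        score = score_at(sigma)
        for _ in range(n_gibbs):
            xt = [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]  # xt ~ N(x, sigma^2 I)
            mu = [a + sigma ** 2 * s for a, s in zip(xt, score(xt))]  # Tweedie mean
            h = rademacher_diag(score, xt, n_rademacher, rng)
            # diag of Sigma_q = sigma^2 (I + sigma^2 * Hessian), clipped to stay positive
            var = [max(sigma ** 2 * (1.0 + sigma ** 2 * hi), 1e-12) for hi in h]
            x = [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mu, var)]
    # Final denoising: noise once at the smallest level, return the posterior mean
    sigma = sigmas[-1]
    xt = [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]
    return [a + sigma ** 2 * s for a, s in zip(xt, score_at(sigma)(xt))]


sigmas = geometric_sigmas(50.0, 0.01, 232)  # 232 levels as in the text; endpoints are placeholders
score_at = toy_score_at(mean=[1.0, -2.0], var=[0.25, 4.0])
samples = [multilevel_gibbs(score_at, [0.0, 0.0], sigmas, seed=i) for i in range(50)]
print([sum(s[j] for s in samples) / len(samples) for j in range(2)])  # compare with [1.0, -2.0]
```

For this Gaussian toy each Gibbs step samples the exact posterior, so the sample mean and variance should match the toy data distribution; with a learned score the Gaussian step is only a moment-matched approximation of the true posterior.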