ScoreFusion: Fusing Score-based Generative Models via Kullback–Leibler Barycenters

Hao Liu, Stanford University, CA 94305, US.    Junze (Tony) Ye, Stanford University, CA 94305, US.    Jose Blanchet, Stanford University, CA 94305, US.    Nian Si, The University of Chicago, IL 60637, US.
(June 28, 2024)
Abstract

We study the problem of fusing pre-trained (auxiliary) generative models to enhance the training of a target generative model. We propose using KL-divergence weighted barycenters as an optimal fusion mechanism, in which the barycenter weights are optimally trained to minimize a suitable loss for the target population. While computing the optimal KL-barycenter weights can be challenging, we demonstrate that this process can be efficiently executed using diffusion score training when the auxiliary generative models are also trained based on diffusion score methods. Moreover, we show that our fusion method has a dimension-free sample complexity in total variation distance, provided that the auxiliary models are well fitted for their own tasks and the auxiliary tasks combined capture the target well. The main takeaway is that if the auxiliary models are well trained and can borrow features from each other that are present in the target, our fusion method significantly improves the training of generative models. We provide a concise computational implementation of the fusion algorithm and validate its efficiency in the low-data regime with numerical experiments involving mixture models and image datasets.

1 Introduction

In recent advancements within the field of generative models, diffusion models [47, 24, 44] have emerged as a potent framework for synthesizing high-quality and diverse outputs across domains such as imagery, audio, and textual content [23, 4, 39, 37, 5, 57]. Successful commercial examples include DALL·E [38], Stable Diffusion [40], and Imagen [41]. The underlying mechanism of diffusion models involves progressively adding noise to a data sample until it approximates a Gaussian distribution, followed by a learned reverse process that reconstructs the original data by gradually denoising it.

Diffusion models rely on large datasets of high-dimensional data to accurately model the complex distributions needed for tasks like image generation and data augmentation [42, 55, 28, 43]. Without sufficient training data, diffusion models struggle to produce high-quality, diverse outputs and can overfit to the limited data they have been trained on [55, 59, 58].

Figure 1: Quality of a diffusion model deteriorates noticeably as $n$, the training data size, decreases.

However, in practice, data scarcity can hinder the performance of generative models, especially in domains where data is limited due to high costs, privacy concerns, or proprietary restrictions by companies that treat their data as a competitive advantage. As a result, even as the demand for powerful generative models grows, the scarcity of usable data can significantly limit their development and effectiveness. To demonstrate this phenomenon, Figure 1 shows the generative performance on a digit for different training sample sizes. The quality of the diffusion model deteriorates noticeably as $n$, the training data size, decreases.

To address the issue of data scarcity, researchers and practitioners often use transfer learning [34, 56, 51, 60, 49]. Transfer learning is a machine learning technique in which a model developed for one task is reused as the starting point for a model on another task. This approach allows a model trained on large, common datasets to be adapted to a different but related problem or dataset with less data available. Many recent works develop transfer learning algorithms to fine-tune diffusion models and achieve empirical success in areas such as image generation [52, 31, 33].

In this paper, we develop a fusion method for diffusion models. Specifically, our goal is to build a generative diffusion model for a target distribution with limited data availability, using the assistance of multiple pretrained diffusion models. These pretrained models have been trained on several common datasets, allowing them to capture a broad range of features and patterns. By leveraging these pretrained models, we aim to enhance the performance of our diffusion model on the target distribution despite the scarcity of training data. The key difference from transfer learning is that transfer learning retrains the parameters of the diffusion models, whereas our method freezes the pretrained neural network weights and creates a new neural network with an extra linear layer.

Our method is based on fusing diffusion processes through the computation of an optimal barycenter. Given a set of weights, a barycenter is typically defined as a probability measure that minimizes the weighted sum of distances (or divergences) to a set of reference measures. The most common barycenter problem among distributions is the Wasserstein barycenter [1, 16, 36, 45, 13, 26]. However, computing Wasserstein barycenters is generally challenging [35, 7, 48, 22]. Therefore, we use a Kullback–Leibler (KL) barycenter [14, 6], which has an analytical solution for any choice of weights (both when the reference measures are supported on Euclidean spaces and when they can be represented as diffusion processes). In turn, the weights of our KL barycenter formulation are optimized according to a suitable class of training losses, as we explain in the sequel.

Our goal is to find the optimal weights to approximate the target dataset. We formulate two convex optimization problems, leading to two fusion methods. The first method is intuitive but requires an estimate of the reference densities and numerical integration, which is usually challenging in high-dimensional contexts. The second method is computationally cheaper since it becomes a linear regression problem after being embedded into the diffusion space, and it still achieves good theoretical and empirical performance.
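To make the linear-regression reduction more concrete, here is a minimal sketch of how a weight fit of this type could look. It is an illustration under our own simplifying assumptions, not the procedure developed in Section 3: two hypothetical one-dimensional Gaussian references $\mathcal{N}(\mp 2, 1)$ whose noised scores are known exactly (standing in for pretrained score networks), an OU forward process with $a=1$ and $\sigma^2=2$, and the conditional score $\nabla\log p_{t|0}(X(t)\,|\,X(0))$ of noised target samples used as the regression response. The helper names `reference_scores` and `fit_weights` are hypothetical. With $k=2$ and $\lambda_2 = 1-\lambda_1$, the fit is a one-dimensional least-squares problem, and clipping to $[0,1]$ keeps the weights on the simplex.

```python
import numpy as np

# Assumed setup (illustration only): OU forward process dX = -a X dt + sigma dW
# with a = 1, sigma^2 = 2, so the stationary law is N(0, 1).
A, SIGMA2, T, T_MIN = 1.0, 2.0, 3.0, 0.05
M = np.array([-2.0, 2.0])  # means of two hypothetical reference Gaussians N(m_i, 1)


def reference_scores(x, t):
    """Exact scores of the noised references N(m_i e^{-at}, 1); shape (n, k)."""
    # With unit initial variance and stationary variance 1, the marginal variance stays 1.
    return -(x[:, None] - M[None, :] * np.exp(-A * t)[:, None])


def fit_weights(x0, rng):
    """Least-squares fit of lambda_1 (with lambda_2 = 1 - lambda_1) against denoising targets."""
    n = x0.shape[0]
    t = rng.uniform(T_MIN, T, size=n)
    mean_coef = np.exp(-A * t)
    var_t = SIGMA2 * (1.0 - np.exp(-2.0 * A * t)) / (2.0 * A)  # Var[X(t) | X(0)]
    xt = mean_coef * x0 + np.sqrt(var_t) * rng.standard_normal(n)
    y = -(xt - mean_coef * x0) / var_t  # conditional score used as the regression response
    s = reference_scores(xt, t)
    # Substitute lambda_2 = 1 - lambda_1: y - s_2 = lambda_1 (s_1 - s_2) + noise.
    u = s[:, 0] - s[:, 1]
    lam1 = np.dot(y - s[:, 1], u) / np.dot(u, u)
    lam1 = float(np.clip(lam1, 0.0, 1.0))  # 1-D convex problem, so clipping gives the constrained optimum
    return np.array([lam1, 1.0 - lam1])


rng = np.random.default_rng(0)
# Hypothetical target: the barycenter with lambda = (0.3, 0.7), i.e. N(0.8, 1).
target = 0.3 * M[0] + 0.7 * M[1] + rng.standard_normal(20_000)
print(fit_weights(target, rng))  # close to [0.3, 0.7]
```

Because the target here is itself a barycenter of the two references, the population minimizer is exactly $\lambda=(0.3,0.7)$ at every time $t$; with real data and learned scores, the fitted weights would only be an approximation.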

The main contributions of our work are summarized as follows:

1) We demonstrate that KL barycenter fusion of auxiliary models can be efficiently implemented when the auxiliary models are trained based on score diffusion. In this case, the optimal score is linear in the auxiliary scores.

2) We provide generalization bounds that split the error into four components. The first is the error between the optimal KL barycenter and the target at time zero (whose direct implementation is difficult due to numerical integration). The second corresponds to the sample complexity $\mathcal{O}(n^{-1/4})$, and the third is the approximation error introduced by the diffusion embedding (which facilitates training). The fourth reflects the quality of the auxiliary score estimates.

3) We numerically demonstrate the performance of our proposed fusion method. Specifically, we found that our method outperforms the basic diffusion method when the training sample size is small.

The rest of the paper is organized as follows. Section 2 reviews the background of KL barycenter and diffusion models. Section 3 details our proposed fusion methods. Section 4 provides convergence results for our methods. Section 5 presents numerical results. Finally, Section 6 concludes the paper with future directions. All proofs are relegated to the appendix.

2 Preliminaries and setup

2.1 Notations

The following notation will be used. Given two functions $f,g:D\to\mathbb{R}$, we write $f\lesssim g$ if there exists a constant $C>0$ such that $f(x)\leq Cg(x)$ for all $x\in D$. When $x\to a$, where $a\in[-\infty,\infty]$, we write $f(x)=\mathcal{O}(g(x))$ if there exists a constant $C>0$ such that $|f(x)|\leq Cg(x)$ for all $x$ close enough to $a$. In asymptotic cases, we use $\mathcal{O}$ and $\lesssim$ interchangeably. We write $f\sim g$ if and only if $f\lesssim g$ and $g\lesssim f$. $C([0,T]:\mathbb{R}^d)$ denotes the space of all continuous functions from $[0,T]$ to $\mathbb{R}^d$, equipped with the uniform topology. In this paper, we consider a Polish space $S$, which could be $\mathbb{R}^d$ or $C([0,T]:\mathbb{R}^d)$. For a Polish space $S$ equipped with its Borel $\sigma$-algebra $\mathcal{B}(S)$, we denote by $\mathcal{P}(S)$ the space of probability measures on $S$, equipped with the topology of weak convergence.
In a normed vector space $(X,\lVert\cdot\rVert)$, $\lVert\cdot\rVert$ denotes the corresponding norm, and $\lVert\cdot\rVert_p$ denotes the standard $L^p$ norm. Given a matrix $\boldsymbol{A}$, we use $\boldsymbol{A}^T$ to denote its transpose. We write $\boldsymbol{\lambda}=(\lambda_1,\ldots,\lambda_k)^T\in[0,1]^k$. We use $\Delta_k$ to denote the $k$-dimensional probability simplex, i.e., $\Delta_k=\{\boldsymbol{\lambda}\in[0,1]^k:\sum_{i=1}^k\lambda_i=1\}$.

2.2 Barycenter problems and Kullback–Leibler divergence

Given a set of probability measures $P_1,\ldots,P_k\in\mathcal{P}(S)$ on a Polish space $S$ and a measure of dissimilarity $D$ (e.g., a metric or a divergence) between two elements of $\mathcal{P}(S)$, we define the barycenter problem with respect to $D$ and weight $\boldsymbol{\lambda}$ as the optimization problem

$$\min_{\mu\in\mathcal{P}(S)}\sum_{i=1}^{k}\lambda_i D(\mu,P_i)\quad\text{s.t. }\boldsymbol{\lambda}\in\Delta_k,$$

where $P_1,\ldots,P_k$ are called the reference measures. For a fixed choice of weights and reference measures, the solution of the barycenter problem is denoted by $\mu_{\boldsymbol{\lambda}}$.

Recall the definition of the Kullback–Leibler (KL) divergence: for $P,Q\in\mathcal{P}(S)$, $D_{\text{KL}}(P\parallel Q)=\int\log\left(\frac{dP}{dQ}\right)dP$ if $P\ll Q$, and $D_{\text{KL}}(P\parallel Q)=\infty$ otherwise, where $\frac{dP}{dQ}$ is the Radon–Nikodym derivative of $P$ with respect to $Q$. In particular, if $S=\mathbb{R}^d$ and $P$ and $Q$ are absolutely continuous (with respect to the Lebesgue measure) with densities $p$ and $q$, respectively, then $D_{\text{KL}}(P\parallel Q)=\int p(x)\log\left(\frac{p(x)}{q(x)}\right)dx$. If $D(\mu,P_i)=D_{\text{KL}}(\mu\parallel P_i)$, we recover the KL barycenter problem [14]. In fact, for any Polish space $S$, the KL barycenter problem is strictly convex and hence has at most one solution.
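As a quick numerical check of the density formula for $D_{\text{KL}}$, the sketch below approximates $\int p\log(p/q)\,dx$ by a Riemann sum on a one-dimensional grid and compares it with the well-known closed form for two Gaussians. The grid bounds and the helper names `kl_on_grid`, `gaussian_pdf`, and `kl_gaussian_1d` are our own choices for illustration; a grid approach like this does not scale to high dimensions, which is part of why density-based computations are hard in practice.

```python
import numpy as np


def kl_on_grid(p, q, dx):
    """Riemann-sum approximation of D_KL(P || Q) = ∫ p log(p/q) dx for densities on a grid."""
    mask = p > 0  # the integrand is 0 where p = 0
    if np.any(q[mask] == 0):
        return np.inf  # P is not absolutely continuous w.r.t. Q on the grid
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])) * dx)


def gaussian_pdf(x, m, s):
    return np.exp(-0.5 * ((x - m) / s) ** 2) / (s * np.sqrt(2.0 * np.pi))


def kl_gaussian_1d(m1, s1, m2, s2):
    """Closed form D_KL(N(m1, s1^2) || N(m2, s2^2))."""
    return np.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2.0 * s2**2) - 0.5


x = np.linspace(-12.0, 12.0, 24_001)
dx = x[1] - x[0]
p, q = gaussian_pdf(x, 0.0, 1.0), gaussian_pdf(x, 1.0, 2.0)
print(kl_on_grid(p, q, dx), kl_gaussian_1d(0.0, 1.0, 1.0, 2.0))  # both ≈ 0.4431
```

Note the asymmetry: swapping `p` and `q` gives a different value, which is why the order $D_{\text{KL}}(\mu\parallel P_i)$ in the barycenter objective matters.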

2.3 Background on diffusion models

Our score fusion method builds on the generative diffusion models driven by stochastic differential equations (SDEs) developed in Song et al. [47], Ho et al. [24], and Sohl-Dickstein et al. [44]. In this section, we review the background on generative diffusion models.

2.3.1 Forward process: adding noise

We begin with the unsupervised learning setup. Given an unlabeled dataset drawn i.i.d. from a distribution $p_0$, the forward diffusion process is defined by the SDE

$$dX(t)=f(t,X(t))\,dt+g(t)\,dW(t),\quad X(0)\sim p_0,\tag{1}$$

where $f:\mathbb{R}_+\times\mathbb{R}^d\to\mathbb{R}^d$ is a vector-valued function, $g:\mathbb{R}_+\to\mathbb{R}$ is a scalar function, and $W(t)$ denotes a standard $d$-dimensional Brownian motion. From now on, we assume that the marginal density of $X(t)$ exists and denote it by $p_t(x)$, and we let $p_{t|s}(X(t)\,|\,X(s))$ be the transition kernel from $X(s)$ to $X(t)$ for $0\leq s\leq t\leq T<\infty$, where $T$ is the terminal time (time horizon) of the forward process. If $f(t,x)=-ax$ and $g(t)=\sigma$ with $a>0$ and $\sigma>0$, then Equation (1) becomes a linear SDE with Gaussian transition kernels,

$$dX(t)=-aX(t)\,dt+\sigma\,dW(t),\quad X(0)\sim p_0,\tag{2}$$

which is an Ornstein–Uhlenbeck (OU) process. If $T$ is large enough, then $p_T$ is close to $\pi=\mathcal{N}\left(\mathbf{0},\frac{\sigma^2}{2a}\mathbf{I}\right)$, the Gaussian distribution with mean $\mathbf{0}$ and covariance matrix $\frac{\sigma^2}{2a}\mathbf{I}$. The forward process can be viewed as the following dynamic: given the data distribution, we gradually add noise to it so that it approaches a known distribution in the long run.
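Because the OU process (2) is linear, the forward noising step can be sampled exactly rather than simulated: given $X(0)=x_0$, $X(t)\sim\mathcal{N}\left(e^{-at}x_0,\ \frac{\sigma^2}{2a}(1-e^{-2at})\mathbf{I}\right)$. The sketch below (the function name `ou_forward_sample` and the uniform "data" are hypothetical) uses this kernel and checks that, for large $t$, the samples forget their initial distribution and match the stationary law $\pi$.

```python
import numpy as np


def ou_forward_sample(x0, t, a, sigma, rng):
    """Sample X(t) | X(0) = x0 for dX = -a X dt + sigma dW using the exact Gaussian kernel."""
    mean = np.exp(-a * t) * x0
    var = sigma**2 * (1.0 - np.exp(-2.0 * a * t)) / (2.0 * a)
    return mean + np.sqrt(var) * rng.standard_normal(np.shape(x0))


rng = np.random.default_rng(0)
x0 = rng.uniform(3.0, 5.0, size=(100_000, 2))  # far-from-Gaussian hypothetical "data", d = 2
a, sigma = 0.5, 1.0
xT = ou_forward_sample(x0, t=20.0, a=a, sigma=sigma, rng=rng)
print(xT.mean(axis=0), xT.var(axis=0), sigma**2 / (2 * a))  # mean ≈ 0, var ≈ 1.0
```

The residual dependence on the data decays like $e^{-aT}$, which is what makes replacing $p_T$ by $\pi$ a reasonable approximation for large $T$.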

2.3.2 Backward process: denoising

If we reverse a diffusion process in time, then under some mild conditions (see, for example, Cattiaux et al. [10] and Föllmer [20]), which are satisfied by all processes considered in this work, we still obtain a diffusion process. More precisely, we want a process $\tilde{X}$ such that $\tilde{X}(t)=X(T-t)$ for $t\in[0,T]$. From the Fokker–Planck equation and the log trick [3], the reverse process corresponding to Process (1) is

$$d\tilde{X}(t)=\left(-f(T-t,\tilde{X}(t))+g^2(T-t)\nabla\log p_{T-t}(\tilde{X}(t))\right)dt+g(T-t)\,dW(t),\quad\tilde{X}(0)\sim p_T,\tag{3}$$

where $\nabla$ denotes the derivative with respect to the space variable $x$. We call $\nabla\log p_t(x)$ the (Stein) score function. If the forward process is an OU process, then the reverse process is

$$d\tilde{X}(t)=\left(a\tilde{X}(t)+\sigma^2\nabla\log p_{T-t}(\tilde{X}(t))\right)dt+\sigma\,dW(t),\quad\tilde{X}(0)\sim p_T.\tag{4}$$

If the backward SDE can be simulated (typically via the Euler–Maruyama method; see Appendix A.2 for details), we can generate samples from the distribution $p_0$. Simulating the backward SDE can be viewed as denoising, going from pure noise to the ground-truth distribution.
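A minimal Euler–Maruyama sketch of the reverse OU process (4) follows. To keep it self-contained, we assume a hypothetical one-dimensional Gaussian data distribution $p_0=\mathcal{N}(3, 0.5^2)$, for which every $p_t$ is Gaussian and the score is available in closed form; in a real model this exact score would be replaced by a trained estimator $s_{t,\theta}$. The parameters $a=1$, $\sigma^2=2$, $T=5$ and the step count are illustrative choices, and $\tilde{X}(0)$ is drawn from $\pi$ in place of $p_T$.

```python
import numpy as np

A, SIGMA, T = 1.0, np.sqrt(2.0), 5.0  # stationary law N(0, sigma^2 / (2a)) = N(0, 1)
M0, S0 = 3.0, 0.5                     # hypothetical data distribution p_0 = N(3, 0.25)


def score(t, x):
    """Exact score of p_t = N(M0 e^{-at}, v_t), available only because p_0 is Gaussian."""
    decay = np.exp(-A * t)
    v_t = S0**2 * decay**2 + SIGMA**2 * (1.0 - decay**2) / (2.0 * A)
    return -(x - M0 * decay) / v_t


def reverse_sde_euler_maruyama(n_samples, n_steps, rng):
    """Simulate Eq. (4) from X~(0) ~ pi, using the stationary Gaussian in place of p_T."""
    h = T / n_steps
    x = np.sqrt(SIGMA**2 / (2.0 * A)) * rng.standard_normal(n_samples)
    for l in range(n_steps):
        t = l * h  # reverse time; the score is evaluated at forward time T - t
        drift = A * x + SIGMA**2 * score(T - t, x)
        x = x + drift * h + SIGMA * np.sqrt(h) * rng.standard_normal(n_samples)
    return x


rng = np.random.default_rng(0)
samples = reverse_sde_euler_maruyama(50_000, 1_000, rng)
print(samples.mean(), samples.std())  # approximately 3.0 and 0.5
```

The last step evaluates the score at forward time $h>0$ rather than $0$, which avoids the singularity of the score at $t=0$ for non-smooth data distributions.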

2.3.3 Score estimation

The only remaining task is to estimate the score $\nabla\log p_t(x)$. There are many ways to achieve this, and some of them are equivalent up to constants that are independent of the training parameters. In this paper, we choose the time-dependent score matching loss used in Song et al. [46]:

$$\mathcal{L}(\theta;\gamma):=\mathbb{E}_{t\sim\mathcal{U}[0,T]}\left[\gamma(t)\,\mathbb{E}_{X(t)\sim p_t}\left[\left\lVert s_{t,\theta}(X(t))-\nabla\log p_t(X(t))\right\rVert_2^2\right]\right],\tag{5}$$

where $\gamma:[0,T]\to\mathbb{R}_+$ is a weighting function and $s_{t,\theta}:\mathbb{R}^d\to\mathbb{R}^d$ is a score estimator, usually chosen to be a neural network. Score estimation is then carried out by minimizing the empirical version of this loss with SGD [29].
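Loss (5) involves the unknown score $\nabla\log p_t$. One standard equivalent form, denoising score matching, replaces it with the conditional score $\nabla\log p_{t|0}(X(t)\,|\,X(0))$, which is known for the OU kernel and changes the objective only by a $\theta$-independent constant. The sketch below illustrates this at a single fixed time $t$ (so $\gamma$ plays no role) with a hypothetical linear score model $s_\theta(x)=\theta_0+\theta_1x$ and a hypothetical Gaussian $p_0$; least squares stands in for SGD. The fitted coefficients should match those of the exact score, which is linear in this Gaussian case.

```python
import numpy as np

A, SIGMA2, M0, S0 = 1.0, 2.0, 3.0, 0.5  # OU parameters and hypothetical p_0 = N(3, 0.25)


def dsm_linear_fit(t, n, rng):
    """Fit s_theta(x) = theta_0 + theta_1 x at a fixed time t by denoising score matching."""
    decay = np.exp(-A * t)
    cond_var = SIGMA2 * (1.0 - decay**2) / (2.0 * A)
    x0 = M0 + S0 * rng.standard_normal(n)
    xt = decay * x0 + np.sqrt(cond_var) * rng.standard_normal(n)
    y = -(xt - decay * x0) / cond_var  # grad log p_{t|0}(x_t | x_0): needs no access to p_t
    design = np.column_stack([np.ones(n), xt])
    theta, *_ = np.linalg.lstsq(design, y, rcond=None)
    return theta


def true_linear_score(t):
    """Coefficients (intercept, slope) of the exact score of p_t = N(M0 e^{-at}, v_t)."""
    decay = np.exp(-A * t)
    v_t = S0**2 * decay**2 + SIGMA2 * (1.0 - decay**2) / (2.0 * A)
    return np.array([M0 * decay / v_t, -1.0 / v_t])


rng = np.random.default_rng(0)
print(dsm_linear_fit(0.5, 200_000, rng), true_linear_score(0.5))  # nearly equal
```

The agreement follows from $\mathbb{E}[\nabla\log p_{t|0}(X(t)\,|\,X(0))\,|\,X(t)]=\nabla\log p_t(X(t))$, so both losses share the same minimizer.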

There are many ways to measure the quality of a generative model. Suppose $D(\cdot,\cdot)$ is a measure of dissimilarity on $\mathcal{P}(S)$; we then call $D(\mu,\hat{\mu})$ the generalization error with respect to $D$, where $\mu$ is the target distribution and $\hat{\mu}$ is the distribution of the generated samples.

Several recent works analyze the generative properties of diffusion models; however, even for compactly supported target distributions with sufficient smoothness, the basic diffusion model suffers from the curse of dimensionality. Therefore, a large amount of target data is needed to generate high-quality samples. For a detailed discussion, see Appendix A.3.

3 KL barycenters and fusion methods

In Section 3.1, we propose and analytically solve two types of KL barycenter problems. These solutions lead to our fusion methods, which are detailed in Section 3.2.

3.1 KL barycenter problems

Theorem 1.

Suppose $\{\mu_1,\ldots,\mu_k\}\subset\mathcal{P}(\mathbb{R}^d)$ and, for each $i=1,\ldots,k$, $\mu_i$ is absolutely continuous with respect to the Lebesgue measure, with densities $p_1,\ldots,p_k$, respectively. Then the distribution-level KL barycenter $\mu_{\boldsymbol{\lambda}}$ is unique, with density
$$p_{\boldsymbol{\lambda}}(x)=\frac{\prod_{i=1}^k p_i(x)^{\lambda_i}}{\int_{\mathbb{R}^d}\prod_{i=1}^k p_i(x)^{\lambda_i}\,dx}.$$
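For Gaussian reference measures, Theorem 1 can be checked directly: the weighted product $\prod_i p_i^{\lambda_i}$ of $\mathcal{N}(m_i,s_i^2)$ densities is again Gaussian, with precision $\sum_i\lambda_i/s_i^2$ and precision-weighted mean. The sketch below (helper names are hypothetical; the three reference Gaussians and weights are arbitrary illustrative values) computes the normalized geometric mixture on a one-dimensional grid, with the denominator obtained by numerical integration, and compares it to that closed form.

```python
import numpy as np


def barycenter_density_on_grid(log_densities, lam, dx):
    """Normalized geometric mixture prod_i p_i^{lambda_i} from Theorem 1, via numerical integration.

    log_densities: array (k, n_grid) of log p_i on a uniform grid with spacing dx.
    """
    log_unnorm = lam @ log_densities          # sum_i lambda_i log p_i(x)
    log_unnorm -= log_unnorm.max()            # stabilize before exponentiating
    unnorm = np.exp(log_unnorm)
    return unnorm / (unnorm.sum() * dx)       # Riemann-sum normalizing constant


def gaussian_barycenter(means, stds, lam):
    """Closed form for Gaussian references: precisions add with weights lambda_i."""
    precision = np.sum(lam / stds**2)
    mean = np.sum(lam * means / stds**2) / precision
    return mean, 1.0 / np.sqrt(precision)


x = np.linspace(-15.0, 15.0, 30_001)
dx = x[1] - x[0]
means, stds, lam = np.array([-2.0, 1.0, 4.0]), np.array([1.0, 0.5, 2.0]), np.array([0.2, 0.5, 0.3])
log_p = -0.5 * ((x[None, :] - means[:, None]) / stds[:, None]) ** 2 - np.log(stds[:, None] * np.sqrt(2 * np.pi))
p_lam = barycenter_density_on_grid(log_p, lam, dx)
m, s = gaussian_barycenter(means, stds, lam)
p_closed = np.exp(-0.5 * ((x - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))
print(np.max(np.abs(p_lam - p_closed)))  # tiny (discretization error only)
```

The grid normalization is only feasible in very low dimension; this is the numerical-integration bottleneck that motivates working with scores instead of densities.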

Our second barycenter problem is posed on the space of continuous functions, i.e., $S=C([0,T]:\mathbb{R}^d)$. This setting yields a process-level KL barycenter. When the underlying measures are represented by SDEs, Theorem 2 gives a closed-form solution for the process-level KL barycenter.

Theorem 2.

Suppose that, for each $i=1,2,\ldots,k$, the $i$-th SDE has the form

$$dX_i(t)=\left[c(t,X_i(t))+\sigma(t)^2a_i(t,X_i(t))\right]dt+\sigma(t)\,dW_i(t),\quad X_i(0)\sim\mu_i,$$

and has a unique strong solution. Denote the law of the solution of the $i$-th SDE by $P_i\in\mathcal{P}(C([0,T]:\mathbb{R}^d))$. If we further assume that, for each $i=1,2,\ldots,k$, $\mu_i$ is absolutely continuous with respect to the Lebesgue measure and $a_i$ is uniformly bounded, then the process-level KL barycenter can be represented as the SDE

$$dX(t)=\left[c(t,X(t))+\sigma(t)^2a(t,X(t))\right]dt+\sigma(t)\,dW(t),\quad X(0)\sim\mu,$$

where $a(t,x)=\sum_{i=1}^k\lambda_i a_i(t,x)$, $\mu$ is the distribution-level KL barycenter of the reference measures $\mu_1,\ldots,\mu_k$, and $W$ is a standard Brownian motion.
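In code, Theorem 2 says that fusing drifts of this form only requires a weighted sum of the $a_i$. The sketch below builds such a fused function from a list of callables; the names `fuse_scores` and `gaussian_score` are hypothetical, and the simplex check mirrors $\boldsymbol{\lambda}\in\Delta_k$. As a sanity check, it uses time-independent Gaussian scores: since $\nabla\log\prod_i p_i^{\lambda_i}=\sum_i\lambda_i\nabla\log p_i$, the fused score at a single time equals the score of the Theorem 1 barycenter, here a Gaussian with precision $0.55$ and mean $-0.2/0.55$.

```python
import numpy as np
from typing import Callable, Sequence

ScoreFn = Callable[[float, np.ndarray], np.ndarray]


def fuse_scores(score_fns: Sequence[ScoreFn], lam: np.ndarray) -> ScoreFn:
    """Return (t, x) -> sum_i lambda_i s_i(t, x), the fused term a(t, x) of Theorem 2."""
    lam = np.asarray(lam, dtype=float)
    if np.any(lam < 0) or not np.isclose(lam.sum(), 1.0):
        raise ValueError("lambda must lie in the probability simplex")

    def fused(t: float, x: np.ndarray) -> np.ndarray:
        return sum(w * s(t, x) for w, s in zip(lam, score_fns))

    return fused


def gaussian_score(m: float, s: float) -> ScoreFn:
    # Score of N(m, s^2); time-independent here, so t is ignored.
    return lambda t, x: -(x - m) / s**2


fused = fuse_scores([gaussian_score(-2.0, 1.0), gaussian_score(4.0, 2.0)], np.array([0.4, 0.6]))
x = np.linspace(-3.0, 3.0, 7)
# Barycenter of the two Gaussians (Theorem 1): precision 0.4/1 + 0.6/4 = 0.55,
# mean (0.4 * -2 + 0.6 * 4 / 4) / 0.55 = -0.2 / 0.55.
print(np.allclose(fused(0.0, x), -(x - (-0.2 / 0.55)) * 0.55))  # True
```

The wrapper never touches the individual score functions, which matches the idea of freezing the auxiliary models and learning only the weights.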

In this paper, fusing $k$ distributions is viewed as computing a KL barycenter with optimized weights. This naturally connects to the idea of transfer learning. Given $k$ well-trained reference generative models, our fusion method optimizes the weights $\lambda_1,\ldots,\lambda_k$ to approximate a target distribution.
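One simple way to choose weights when reference densities are available is to maximize the average log-likelihood of target samples under $p_{\boldsymbol{\lambda}}$ from Theorem 1. The sketch below does this by grid search over $\Delta_2$, computing the normalizing constant by numerical integration. The likelihood criterion, the hypothetical references $\mathcal{N}(\mp 2,1)$, and the helper names are our own assumptions for illustration, not necessarily the loss used in this paper. The target samples are drawn from the barycenter with $\boldsymbol{\lambda}=(0.3,0.7)$, so the search should recover weights near those values.

```python
import numpy as np

x_grid = np.linspace(-12.0, 12.0, 4_801)
dx = x_grid[1] - x_grid[0]
MEANS, STDS = np.array([-2.0, 2.0]), np.array([1.0, 1.0])  # hypothetical reference Gaussians


def log_ref(x):
    """log p_i(x) for each reference; shape (k, len(x))."""
    z = (x[None, :] - MEANS[:, None]) / STDS[:, None]
    return -0.5 * z**2 - np.log(STDS[:, None] * np.sqrt(2.0 * np.pi))


def avg_log_likelihood(lam, target):
    """Mean log p_lambda(target), with the Theorem 1 normalizer computed numerically."""
    log_z = np.log(np.sum(np.exp(lam @ log_ref(x_grid))) * dx)
    return float(np.mean(lam @ log_ref(target)) - log_z)


def fit_weights_by_likelihood(target, n_grid=101):
    """Grid search over the simplex Delta_2, parameterized as lambda = (l, 1 - l)."""
    candidates = [np.array([l, 1.0 - l]) for l in np.linspace(0.0, 1.0, n_grid)]
    scores = [avg_log_likelihood(lam, target) for lam in candidates]
    return candidates[int(np.argmax(scores))]


rng = np.random.default_rng(1)
target = 0.8 + rng.standard_normal(2_000)  # N(0.8, 1): the barycenter with lambda = (0.3, 0.7)
print(fit_weights_by_likelihood(target))  # close to [0.3, 0.7]
```

Grid search and grid-based normalization both become impractical as $k$ and $d$ grow, which is why the score-based route is attractive.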

3.2 Fusion methods

Recall that in our task, we are given $k$ datasets with abundant samples, and our goal is to generate samples for a target dataset with limited available data. Therefore, in this section, we denote the target measure by $\nu$ and assume that we are given $k$ reference diffusion generative models that generate samples from $k$ different reference measures $\mu_1,\ldots,\mu_k$, respectively. Specifically, each reference measure corresponds to an auxiliary backward diffusion process

$$d\tilde{X}_i(t)=\left(a\tilde{X}_i(t)+\sigma^2 s^i_{T-t,\theta^*}(\tilde{X}_i(t))\right)dt+\sigma\,dW_i(t),\quad\tilde{X}_i(0)\sim p^i_T,\tag{6}$$

where $s^i_{T-t,\theta^*}$ is a well-trained score function for the $i$-th reference measure. In what follows, we introduce two fusion algorithms and the related generalization error results.

In practice, a discretized version of the SDE (6) is used. Specifically, we employ a small time-discretization step $h$ and a total of $N$ time steps (hence $T=Nh$). Since $p_T^i$ is close to the Gaussian distribution $\pi$, the SDE (6) is approximated by $\hat{X}_i(0)\sim\pi$ and

$$d\hat{X}_i(t) = \left(a\hat{X}_i(t) + \sigma^2 s^i_{T-lh,\theta^*}\big(\hat{X}_i(lh)\big)\right)dt + \sigma\, dW(t), \qquad t\in[lh,(l+1)h]. \tag{7}$$

Then, given a weight vector $\boldsymbol{\lambda}\in\Delta_k$, Theorem 2 implies that the corresponding process-level KL barycenter follows the SDE:

$$d\hat{Y}(t) = \left(a\hat{Y}(t) + \sigma^2\sum_{i=1}^{k}\lambda_i s^i_{T-lh,\theta^*}\big(\hat{Y}(lh)\big)\right)dt + \sigma\, dW(t), \qquad t\in[lh,(l+1)h], \quad \hat{Y}(0)\sim\pi. \tag{8}$$

We denote the distribution of the terminal variable $\hat{Y}(T)$ by $\hat{p}_{\boldsymbol{\lambda}}$; it will later serve as the distribution of the generated samples.
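To make (8) concrete, the following is a minimal one-dimensional sketch of the sampling step, written under assumptions of our own rather than the paper's setup: the forward process is taken to be $dX=-aX\,dt+\sigma\,dW$ with $a=1$ and $\sigma=\sqrt{2}$ (so that $\pi=\mathcal{N}(0,1)$), the two reference measures are $\mathcal{N}(-2,1)$ and $\mathcal{N}(2,1)$, and their exact scores stand in for the trained networks $s^i_{t,\theta^*}$. The update is plain Euler-Maruyama, which also freezes the linear drift $a\hat{Y}$ at $lh$ and is therefore a slight simplification of (8). The helper names `reference_score` and `sample_fused` are hypothetical.

```python
import math
import random


def reference_score(mean):
    """Exact score of N(mean, 1) under the assumed forward process dX = -X dt + sqrt(2) dW.

    This process keeps the variance at 1 and shrinks the mean to mean * exp(-t),
    so grad log p_t(x) = -(x - mean * exp(-t)).
    """
    return lambda t, x: -(x - mean * math.exp(-t))


def sample_fused(scores, lam, n_samples=2000, T=5.0, N=100, a=1.0, sigma=math.sqrt(2.0), seed=0):
    """Euler-Maruyama discretization of the fused backward SDE (8)."""
    rng = random.Random(seed)
    h = T / N
    samples = []
    for _ in range(n_samples):
        y = rng.gauss(0.0, 1.0)  # Y(0) ~ pi = N(0, 1)
        for l in range(N):
            tau = T - l * h  # the scores are evaluated at forward time T - lh
            fused = sum(w * s(tau, y) for w, s in zip(lam, scores))  # weighted score
            y += (a * y + sigma ** 2 * fused) * h + sigma * math.sqrt(h) * rng.gauss(0.0, 1.0)
        samples.append(y)
    return samples


samples = sample_fused([reference_score(-2.0), reference_score(2.0)], lam=[0.25, 0.75])
mean = sum(samples) / len(samples)
var = sum((y - mean) ** 2 for y in samples) / (len(samples) - 1)
print(round(mean, 2), round(var, 2))  # both should be close to 1
```

Because the two references share unit variance, the normalized product of their densities raised to the weights $(1/4, 3/4)$ is again Gaussian, namely $\mathcal{N}(1,1)$, so the sample mean and variance should both be close to 1 up to discretization and Monte Carlo error.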

The key step in our method is to find an optimal $\boldsymbol{\lambda}^*$ such that $\hat{p}_{\boldsymbol{\lambda}^*}$ is as close to the target measure $\nu$ as possible. To this end, we propose two fusion methods based on two different optimization problems.

The first method, based on Theorem 1, optimizes directly over probability measures on Euclidean space. Namely, we consider the following convex problem:

$$\min_{\boldsymbol{\lambda}\in\Delta_k} D_{\mathrm{KL}}(\nu\,\|\,\mu_{\boldsymbol{\lambda}}) = \min_{\boldsymbol{\lambda}\in\Delta_k} \left\{\mathbb{E}_{\nu}\left[\log q(X) - \sum_{i=1}^{k}\lambda_i\log p_i(X)\right] + \log\left(\int\prod_{i=1}^{k}p_i(y)^{\lambda_i}\,dy\right)\right\}, \tag{9}$$

where $p_1,\ldots,p_k$ denote the densities of the reference measures and $q(x)$ denotes the density of the target distribution $\nu$. We refer to this fusion method as vanilla fusion. Suppose we have accurate estimates of the densities $p_i$. We then use the Frank-Wolfe method to solve Problem (9) and obtain an optimal $\boldsymbol{\lambda}^*$. In the Frank-Wolfe method, the gradient can be approximated by sample-mean estimators computed from the target data drawn from $\nu$ (see Remark 2 in Appendix C.1.3). To generate samples, we plug $\boldsymbol{\lambda}^*$ into (8) and simulate the SDE.
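A minimal one-dimensional sketch of vanilla fusion with Frank-Wolfe follows, assuming the reference log-densities are available in closed form (two unit-variance Gaussians, a hypothetical choice) and computing the normalizing integral in (9) by a Riemann sum on a grid. Dropping the constant $\mathbb{E}_\nu[\log q(X)]$, the partial derivative of the objective in $\lambda_i$ is $-\mathbb{E}_\nu[\log p_i(X)] + \mathbb{E}_{\mu_{\boldsymbol{\lambda}}}[\log p_i(Y)]$; the first expectation is a sample mean over the target data and the second is a grid average. The step size $2/(t+2)$ is the textbook Frank-Wolfe choice and not necessarily the one used by the authors; `log_gauss` and `vanilla_fusion_fw` are hypothetical names.

```python
import math
import random


def log_gauss(x, m, v=1.0):
    """Log-density of N(m, v) at x."""
    return -0.5 * (x - m) ** 2 / v - 0.5 * math.log(2.0 * math.pi * v)


def vanilla_fusion_fw(target_samples, log_densities, grid, iters=500):
    """Frank-Wolfe on the simplex for objective (9), without the constant E_nu[log q]."""
    k = len(log_densities)
    # Sample-mean estimates of E_nu[log p_i(X)] from the target data.
    a = [sum(lp(x) for x in target_samples) / len(target_samples) for lp in log_densities]
    # log p_i evaluated once on the quadrature grid.
    lp_grid = [[lp(y) for y in grid] for lp in log_densities]
    lam = [1.0 / k] * k
    for t in range(iters):
        # Unnormalized log-density of the barycenter mu_lambda on the grid.
        logw = [sum(lam[i] * lp_grid[i][j] for i in range(k)) for j in range(len(grid))]
        mx = max(logw)
        w = [math.exp(v - mx) for v in logw]
        z = sum(w)  # the grid spacing cancels after normalization
        # d/d lambda_i = -E_nu[log p_i] + E_{mu_lambda}[log p_i]
        grad = [-a[i] + sum(wj * lij for wj, lij in zip(w, lp_grid[i])) / z for i in range(k)]
        # Linear minimization oracle over the simplex: the vertex with the smallest gradient.
        best = min(range(k), key=lambda i: grad[i])
        gamma = 2.0 / (t + 2.0)
        lam = [(1.0 - gamma) * lam[i] + (gamma if i == best else 0.0) for i in range(k)]
    return lam


rng = random.Random(0)
data = [rng.gauss(1.0, 1.0) for _ in range(2000)]  # hypothetical target nu = N(1, 1)
refs = [lambda x: log_gauss(x, -2.0), lambda x: log_gauss(x, 2.0)]
grid = [-8.0 + 16.0 * j / 600 for j in range(601)]
lam = vanilla_fusion_fw(data, refs, grid)
print([round(v, 2) for v in lam])  # expected to be near [0.25, 0.75]
```

For equal-variance Gaussian references the objective reduces to half the squared distance between the target sample mean and $\sum_i\lambda_i m_i$, so the weights should land near $(0.25, 0.75)$ here.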

An idea similar to vanilla fusion, namely fusing component distributions via a KL barycenter, was proposed by Claici et al. [14], who use the average KL divergence as a metric to recover the mean-field approximation of the posterior distribution of the fused global model. Both methods solve a two-layer optimization problem: finding the barycenter and finding the optimal weights. Moreover, both methods introduce a convex optimization problem to help find the optimizers. The difference is that vanilla fusion solves the barycenter problem first (since the barycenter is available in almost analytical form) and its main task is to find the optimal weights, whereas Claici et al. [14] find both optimizers simultaneously, and their convex problem is only a relaxation of the original objective.

However, diffusion generative models usually cannot directly estimate the densities $p_1,\ldots,p_k$, so vanilla fusion is hard to apply directly to complicated high-dimensional distributions. We therefore propose a practical, process-level alternative called ScoreFusion. The numerical results in Section 5 were generated by employing Algorithm 1.

In our second method, we first build a forward process starting from the target dataset, according to (2). We denote this forward process by $\tilde{X}^{\nu}(t)$ and the corresponding density by $p_t^{\nu}(x)$. Then, we modify the loss function (5) into the following linear regression problem in $\boldsymbol{\lambda}$:

$$\tilde{\mathcal{L}}(\boldsymbol{\lambda};\theta^*,\gamma) = \mathbb{E}_{t\sim\mathcal{U}[0,\tilde{T}]}\left[\gamma(t)\,\mathbb{E}_{X(t)\sim p^{\nu}_t}\left[\left\|\sum_{i=1}^{k}\lambda_i s^i_{t,\theta^*}\big(X(t)\big) - \nabla\log p^{\nu}_t\big(X(t)\big)\right\|_2^2\right]\right] \tag{10}$$

where we choose $\tilde{T}\ll T$. The intuition behind this choice is that we want to learn an optimal $\boldsymbol{\lambda}^*$ such that $p_{\boldsymbol{\lambda}^*}$ is close to the target $\nu$. When $\tilde{T}\ll T$, the forward process has not yet injected much noise, so the weights $\hat{\boldsymbol{\lambda}}$ obtained from training are less affected by the noise. In theory, choosing $\tilde{T}=0$ is optimal, but this is hard to implement: Algorithm 1 with $\tilde{T}=0$ can be viewed as a variant of vanilla fusion, since learning then takes place only at the level of the distribution $p_0$, and an extremely small $\tilde{T}$ causes numerical instability in practice. This is unsurprising, given the numerical integration and density estimation that vanilla fusion requires. The optimization problem associated with our second method is $\min_{\boldsymbol{\lambda}\in\Delta_k}\tilde{\mathcal{L}}(\boldsymbol{\lambda};\theta^*,\gamma)$. The details are given in Algorithm 1.
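Expanding the square in (10) shows why it is a linear regression in $\boldsymbol{\lambda}$, with the auxiliary scores as features. In the short derivation below (added for illustration, with every expectation over $t\sim\mathcal{U}[0,\tilde{T}]$ and $X(t)\sim p^{\nu}_t$), $G$ is a Gram matrix and hence positive semidefinite for a nonnegative weighting $\gamma$, so $\min_{\boldsymbol{\lambda}\in\Delta_k}\tilde{\mathcal{L}}$ is a convex quadratic program in only $k$ variables, whatever the data dimension $d$.

```latex
\begin{aligned}
\tilde{\mathcal{L}}(\boldsymbol{\lambda};\theta^*,\gamma)
  &= \boldsymbol{\lambda}^{\top} G\,\boldsymbol{\lambda} - 2\,\boldsymbol{b}^{\top}\boldsymbol{\lambda} + c,\\
G_{ij} &= \mathbb{E}_{t}\,\mathbb{E}_{X(t)}\Big[\gamma(t)\,\big\langle s^i_{t,\theta^*}(X(t)),\, s^j_{t,\theta^*}(X(t))\big\rangle\Big],\\
b_i &= \mathbb{E}_{t}\,\mathbb{E}_{X(t)}\Big[\gamma(t)\,\big\langle s^i_{t,\theta^*}(X(t)),\, \nabla\log p^{\nu}_t(X(t))\big\rangle\Big],\\
c &= \mathbb{E}_{t}\,\mathbb{E}_{X(t)}\Big[\gamma(t)\,\big\|\nabla\log p^{\nu}_t(X(t))\big\|_2^2\Big].
\end{aligned}
```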

Algorithm 1 ScoreFusion
1: Input: Training data $\mathcal{D}$, pre-trained score functions $s^1_{t,\theta^*},\ldots,s^k_{t,\theta^*}$. Hyperparameter $\tilde{T}$.
2: Output: Generated samples from a distribution $\hat{\nu}_P$.
3: I. Training Phase
4: Randomly initialize non-negative $\lambda_1,\ldots,\lambda_k$ s.t. $\sum_{i=1}^{k}\lambda_i=1$.
5: repeat
6:     Run the forward process $\tilde{X}^{\nu}(t)$ using a mini-batch from $\mathcal{D}$.
7:     Evaluate the loss function (10) and back-propagate onto $\lambda_1,\ldots,\lambda_k$.
8:     ▷ The $\lambda_i$'s are softmaxed to enforce the probability simplex constraint.
9: until converged. Save the optimal $\boldsymbol{\lambda}^*=\{\lambda^*_1,\lambda^*_2,\ldots,\lambda^*_k\}$.
10: II. Sampling Phase
11: $s_{t,\boldsymbol{\lambda}^*}(\hat{Y}(t)) := \sum_{i=1}^{k}\lambda^*_i\, s^i_{t,\theta^*}(\hat{Y}(t))$.
12: Simulate the backward SDE (8) with $s_{t,\boldsymbol{\lambda}^*}(\cdot)$, starting from a Gaussian prior, and generate samples.
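The training phase of Algorithm 1 can be sketched in one dimension without a deep-learning framework. This is a sketch under stated assumptions, not the authors' implementation: the forward process is assumed to be $dX=-X\,dt+\sqrt{2}\,dW$, the unknown target score $\nabla\log p^\nu_t$ in (10) is replaced by the denoising score-matching target $-Z/\sqrt{1-e^{-2t}}$ (which changes the loss only by a $\boldsymbol{\lambda}$-independent constant), $\gamma(t)=1-e^{-2t}$, and the gradient is derived by hand instead of by back-propagation. The function names and all hyperparameter values are hypothetical.

```python
import math
import random


def softmax(w):
    """Map unconstrained logits to the probability simplex (step 8 of Algorithm 1)."""
    mx = max(w)
    e = [math.exp(v - mx) for v in w]
    s = sum(e)
    return [v / s for v in e]


def fit_scorefusion_weights(data, scores, t_max=0.1, steps=200, batch=512, lr=0.25, seed=0):
    """SGD on a Monte Carlo estimate of loss (10), parametrized by softmax logits.

    Assumptions (ours): X(t) = e^{-t} X(0) + sqrt(1 - e^{-2t}) Z, the unknown target
    score is replaced by the denoising target -Z / sqrt(1 - e^{-2t}), gamma(t) = 1 - e^{-2t}.
    """
    rng = random.Random(seed)
    k = len(scores)
    w = [0.0] * k  # logits; lambda = softmax(w) starts at uniform weights
    for _ in range(steps):
        lam = softmax(w)
        g = [0.0] * k  # gradient of the loss with respect to lambda
        for _ in range(batch):
            x0 = rng.choice(data)  # mini-batch draw from the target dataset
            t = rng.uniform(0.0, t_max)  # t ~ U[0, T_tilde]
            std = math.sqrt(1.0 - math.exp(-2.0 * t))
            z = rng.gauss(0.0, 1.0)
            xt = math.exp(-t) * x0 + std * z  # forward-process sample
            s = [f(t, xt) for f in scores]
            fused = sum(l * si for l, si in zip(lam, s))
            # gamma(t) * (fused - target) with gamma = std^2 and target = -z / std
            weighted_resid = std * std * fused + std * z
            for i in range(k):
                g[i] += 2.0 * weighted_resid * s[i] / batch
        # Chain rule through the softmax: dL/dw_j = lam_j * (g_j - sum_i lam_i g_i).
        g_bar = sum(l * gi for l, gi in zip(lam, g))
        w = [wj - lr * lj * (gj - g_bar) for wj, lj, gj in zip(w, lam, g)]
    return softmax(w)


def ou_gauss_score(mean):
    # Hypothetical auxiliary: exact score of N(mean, 1) under the assumed forward process.
    return lambda t, x: -(x - mean * math.exp(-t))


rng = random.Random(1)
target = [rng.gauss(1.0, 1.0) for _ in range(512)]  # small hypothetical target dataset
lam = fit_scorefusion_weights(target, [ou_gauss_score(-2.0), ou_gauss_score(2.0)])
print([round(v, 2) for v in lam])  # expected to be near [0.25, 0.75]
```

With references centered at $\pm 2$, the minimizer of this loss places the fused mean $\sum_i\lambda_i m_i$ at the target sample mean, so the weights should come out near $(0.25, 0.75)$. Parametrizing $\boldsymbol{\lambda}$ through a softmax keeps every iterate on the simplex without a projection step.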

4 Convergence results

This section details the convergence results for our proposed fusion methods. We focus on sample complexities, quantified by the necessary number of samples in the target dataset, in terms of total variation distance. We show that the sample complexities of our methods are dimension-free, given that the auxiliary processes are accurately fitted to their reference distributions and together offer adequate information for the target distribution. To begin with, we assume all distributions are compactly supported.

Assumption 1.

The target and reference distributions are all compactly supported in $\mathbb{K}\subset\mathbb{R}^d$ with absolutely continuous densities. We assume that their second moments are bounded by $M\in(0,\infty)$.

Proposition 1.

Under Assumption 1, Problem (9) is convex in $\boldsymbol{\lambda}$.
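One standard way to see why Proposition 1 holds (a sketch of the reasoning, not necessarily the proof in the appendix): in (9), $\mathbb{E}_\nu[\log q(X)]$ is constant and $-\sum_i\lambda_i\mathbb{E}_\nu[\log p_i(X)]$ is linear in $\boldsymbol{\lambda}$, so convexity reduces to that of the log-normalizer $A(\boldsymbol{\lambda})=\log\int\prod_{i=1}^k p_i(y)^{\lambda_i}\,dy$. Assuming differentiation under the integral sign is permitted, and writing $\mu_{\boldsymbol{\lambda}}$ for the measure with density proportional to $\prod_i p_i^{\lambda_i}$:

```latex
\begin{aligned}
\frac{\partial A}{\partial \lambda_i}(\boldsymbol{\lambda})
  &= \mathbb{E}_{Y\sim\mu_{\boldsymbol{\lambda}}}\big[\log p_i(Y)\big],\\
\frac{\partial^2 A}{\partial \lambda_i\,\partial \lambda_j}(\boldsymbol{\lambda})
  &= \operatorname{Cov}_{Y\sim\mu_{\boldsymbol{\lambda}}}\big(\log p_i(Y),\,\log p_j(Y)\big),\\
\boldsymbol{v}^{\top}\nabla^2 A(\boldsymbol{\lambda})\,\boldsymbol{v}
  &= \operatorname{Var}_{Y\sim\mu_{\boldsymbol{\lambda}}}\Big(\sum_{i=1}^{k} v_i \log p_i(Y)\Big) \;\ge\; 0
  \quad \text{for all } \boldsymbol{v}\in\mathbb{R}^k.
\end{aligned}
```

The first line is also the barycenter-side term that appears in the gradient of (9).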

Proposition 1 implies that Problem (9) is easy to solve provided that the reference densities can be estimated. We further require Assumption 2 below, which guarantees that each auxiliary process is accurately trained, in the sense that the score function at each time step is well fitted.

Assumption 2.

For each $i=1,2,\ldots,k$ and all $t\in[0,T]$, $\nabla\log p^i_t$ is $L$-Lipschitz with $L\geq 1$, and the step size $h=T/N$ satisfies $h\lesssim 1/L$; for each $i=1,\ldots,k$ and $l=0,1,\ldots,N$, $\mathbb{E}_{p^i_{lh}}\left[\left\|s^i_{lh,\theta^*}-\nabla\log p^i_{lh}\right\|_2^2\right]\leq\epsilon_{\text{score}}^2$ with small $\epsilon_{\text{score}}$.

Assumption 2 is widely used in the diffusion model literature (see, for example, Chen et al. [12]).

To proceed, we denote by $\boldsymbol{\lambda}^{*}$ and $\boldsymbol{\Lambda}^{*}$ the solutions of Problems (9) and (10), respectively. Furthermore, the corresponding barycenters are denoted by $\mu_{\boldsymbol{\lambda}^{*}}$ and $\mu_{\boldsymbol{\Lambda}^{*}}$. Assumption 3 below states that the theoretically optimal barycenters are close to the target measure, which ensures that the reference distributions together provide sufficient information about the target distribution.

Assumption 3.

$D_{\mathrm{KL}}(\nu\,\|\,\mu_{\boldsymbol{\lambda}^*})\leq\epsilon_0^2$ and $D_{\mathrm{KL}}(\nu\,\|\,\mu_{\boldsymbol{\Lambda}^*})\leq\epsilon_1^2$, with small $\epsilon_0$ and $\epsilon_1$.

Based on Assumptions 1, 2 and 3, we provide convergence results for the vanilla fusion and ScoreFusion (Algorithm 1) in Theorems 3 and 4, respectively.

Theorem 3.

Suppose that Assumptions 1, 2, and 3 are satisfied. We further assume that for each fixed $\boldsymbol{\lambda}\in\Delta_k$, $\mathrm{TV}(\mu_{\boldsymbol{\lambda}},\hat{\mu}_{\boldsymbol{\lambda}})\leq\epsilon_2$, where $\hat{\mu}_{\boldsymbol{\lambda}}$ is the barycenter of the output distributions of the $k$ auxiliary processes. Then, for $0<\delta\ll 1$, the output distribution $\hat{\nu}_D$ of the vanilla fusion method satisfies, with probability at least $1-\delta$,

$$\mathrm{TV}(\nu,\hat{\nu}_D) \lesssim \underbrace{\epsilon_0}_{\text{quality of combined auxiliaries}} + \underbrace{\mathcal{O}\left(\left(\log\frac{1}{\delta}\right)^{1/4} n^{-1/4}\right)}_{\text{mean estimation error}} + \underbrace{\epsilon_2}_{\text{auxiliary density estimation}} + \mathrm{SE},$$

where SE is the error of auxiliary score estimation, defined as

$$\mathrm{SE} = \exp(-T)\max_{i=1,2,\ldots,k}\sqrt{D_{\mathrm{KL}}\left(p^i_T\,\|\,\pi\right)} + \sigma\sqrt{kT}\left(\epsilon_{\text{score}} + L\sqrt{dh} + Lh\sqrt{M}\right).$$
Theorem 4.

Suppose that Assumptions 1, 2, and 3 are satisfied. Then, for $0<\delta\ll 1$, the output distribution $\hat{\nu}_P$ of Algorithm 1 satisfies, with probability at least $1-\delta$,

$$\mathrm{TV}(\nu,\hat{\nu}_P) \lesssim \underbrace{\sigma\epsilon_1}_{\text{quality of combined auxiliaries}} + \underbrace{\mathcal{O}\left(\sigma\left(\log\frac{1}{\delta}\right)^{1/4} n^{-1/4}\right)}_{\text{sampling errors}} + \underbrace{\sigma\sqrt{k}\,\mathcal{O}\left(\tilde{T}^{1/4}\right)}_{\text{approximation of time }0} + \mathrm{SE}.$$

Theorems 3 and 4 demonstrate dimension-free sample complexities, provided that the auxiliaries are well approximated and, combined, capture the features of the target well. More specifically, each bound in Theorems 3 and 4 has four terms, each representing a different source of error.
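To read off the dimension-free sample complexity, ignore constants and ask how many target samples $n$ make the sample-dependent term of each bound at most a tolerance $\eta\in(0,1)$ (a back-of-the-envelope calculation added for illustration; $\eta$ is our notation):

```latex
\Big(\log\tfrac{1}{\delta}\Big)^{1/4} n^{-1/4} \le \eta
  \iff n \ge \frac{\log(1/\delta)}{\eta^{4}}
  \quad\text{(Theorem 3)},
\qquad
\sigma\Big(\log\tfrac{1}{\delta}\Big)^{1/4} n^{-1/4} \le \eta
  \iff n \ge \frac{\sigma^{4}\log(1/\delta)}{\eta^{4}}
  \quad\text{(Theorem 4)}.
```

Neither threshold involves the dimension $d$; $d$ enters only through SE, which reflects how well the auxiliaries were trained rather than the number of target samples.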

The quality of the combined auxiliaries is the essential assumption in both Theorems 3 and 4. The sampling-error term in Theorem 4 reflects the fact that, with the help of diffusion models, the optimization becomes a linear regression on the scores, which makes the problem easier and escapes the curse of dimensionality. The approximation-of-time-0 term is the price of replacing vanilla fusion with a version that injects a small, controllable amount of noise, which makes the implementation much easier. It is worth noting that there is a tradeoff in choosing $\tilde{T}$: the smaller $\tilde{T}$, the more accurate the optimal weights, but the more likely the algorithm is to encounter numerical instability. Finally, the auxiliary score-estimation term SE can be made small with a careful choice of discretization time steps and accurate auxiliary score approximation (see Remark 3 in Appendix C.2).
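As a rough illustration of how the auxiliary score-estimation term behaves, the helper below evaluates the expression for SE with the hidden constants of $\lesssim$ set to one (our simplification). The function name and all parameter values in the usage lines are hypothetical.

```python
import math


def score_estimation_term(T, N, d, L, M, eps_score, kl_max, sigma, k):
    """Evaluate SE from Theorem 3 with every hidden constant set to 1 (an assumption).

    kl_max plays the role of max_i D_KL(p_T^i || pi); h = T / N is the step size.
    """
    h = T / N
    # Initialization error: the prior pi is used instead of p_T^i.
    init_term = math.exp(-T) * math.sqrt(kl_max)
    # Score error plus discretization error, accumulated over the horizon T.
    path_term = sigma * math.sqrt(k * T) * (eps_score + L * math.sqrt(d * h) + L * h * math.sqrt(M))
    return init_term + path_term


# Hypothetical one-dimensional setting with two auxiliaries: refine the grid at fixed T.
values = [score_estimation_term(T=5.0, N=N, d=1, L=1.0, M=1.0, eps_score=0.01,
                                kl_max=1.0, sigma=math.sqrt(2.0), k=2)
          for N in (100, 1000, 10000)]
print([round(v, 3) for v in values])  # shrinks as N grows, toward a floor set by eps_score
```

Increasing $T$ shrinks the $e^{-T}$ initialization term but inflates the $\sqrt{kT}$ factor, so $T$ and $h$ have to be chosen jointly. Since the constants are dropped, only the trends, not the absolute values, are meaningful.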

5 Numerical results

We implement the ScoreFusion model and examine its performance on both synthetic and real-world image datasets. The auxiliary score functions use the same U-Net backbone as the score-based diffusion code repository of Song et al. [47]. Our experiments vary the quantity of training data available to ScoreFusion and to the baseline, which is a regular score-based diffusion model. We aim to demonstrate that in the low-data regime, ScoreFusion outperforms training a score model from scratch. This section summarizes the key experimental findings, leaving implementation details and additional data to Appendix D.

5.1 Bimodal Gaussian mixture distributions

We test ScoreFusion's ability to approximate a one-dimensional bimodal Gaussian mixture distribution using two auxiliary distributions. Since the data are synthetic, we have access to the true density functions of the target and auxiliary distributions, shown on the right of Figure 2; the ground-truth distribution is in grey. Table 1 gives the 1-Wasserstein distance $\mathcal{W}_1$ between the distribution learned by each model (ScoreFusion and the baseline) and the ground-truth distribution, calculated using SciPy.

Table 1: 1-Wasserstein distance from the ground-truth distribution; columns give the number of target training samples. The standard error is calculated from the $\mathcal{W}_1$ distances of 10 random draws of 8096 samples from each generator.
| Model | $2^5$ | $2^6$ | $2^7$ | $2^9$ | $2^{10}$ |
|---|---|---|---|---|---|
| Baseline | $106.93 \pm 1.43$ | $13.46 \pm 0.28$ | $16.74 \pm 0.27$ | $0.55 \pm 0.04$ | $\mathbf{0.15 \pm 0.02}$ |
| ScoreFusion | $\mathbf{0.39 \pm 0.02}$ | $\mathbf{0.51 \pm 0.03}$ | $\mathbf{0.36 \pm 0.02}$ | $\mathbf{0.38 \pm 0.02}$ | $0.30 \pm 0.02$ |
 
Figure 2: Left: Histograms of 8096 ScoreFusion samples and 8096 baseline samples; both models are trained on 64 samples. Right: Density functions of the ground truth vs. the auxiliary distributions.

Using only 64 training samples, ScoreFusion already learns a good representation of the ground-truth distribution. In contrast, the standard diffusion model is overly spread out and fails to capture the modes of the Gaussian mixture. Moreover, ScoreFusion consistently outperforms the baseline in $\mathcal{W}_1$ distance when the number of training samples is fewer than $2^{10}$.
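The $\mathcal{W}_1$ metric in Table 1 is computed with SciPy, which provides `scipy.stats.wasserstein_distance` for one-dimensional samples (including unequal sizes and weights). For two equal-size one-dimensional samples, $\mathcal{W}_1$ between the empirical distributions reduces to the mean absolute difference of the sorted samples; the dependency-free sketch below uses that identity and reports mean ± standard error over 10 draws, mirroring the protocol in the caption of Table 1. The mixture parameters and the stand-in "generator" are hypothetical, not the paper's setup.

```python
import math
import random


def w1_equal_size(xs, ys):
    """1-Wasserstein distance between two 1D empirical distributions of equal size."""
    if len(xs) != len(ys):
        raise ValueError("this shortcut requires equal sample sizes")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def mean_and_se(values):
    """Mean and standard error of the mean."""
    n = len(values)
    m = sum(values) / n
    sd = math.sqrt(sum((v - m) ** 2 for v in values) / (n - 1))
    return m, sd / math.sqrt(n)


def bimodal(rng, n, w=0.5, m1=-3.0, m2=3.0, s=1.0):
    # Hypothetical ground truth: w * N(m1, s^2) + (1 - w) * N(m2, s^2).
    return [rng.gauss(m1, s) if rng.random() < w else rng.gauss(m2, s) for _ in range(n)]


rng = random.Random(0)
dists = []
for _ in range(10):  # 10 random draws, as in the caption of Table 1
    truth = bimodal(rng, 8096)
    generated = bimodal(rng, 8096)  # stand-in for samples from a trained generator
    dists.append(w1_equal_size(truth, generated))
m, se = mean_and_se(dists)
print(f"W1 = {m:.3f} +/- {se:.3f}")
```

Because the stand-in generator samples from the ground truth itself, the reported distance here only reflects finite-sample noise.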

5.2 EMNIST with heterogeneous digits mix

We further demonstrate our algorithm on the EMNIST dataset [15], an augmentation of the original MNIST dataset comprising handwritten digits in 1×28×28 format. We selected five subsets ($D_i$, $i=1,\ldots,5$) from EMNIST, focusing on the digits 7 and 9 with varying frequencies (7's, 9's): (10%, 90%), (30%, 70%), (70%, 30%), (90%, 10%), and (60%, 40%). $D_1,\ldots,D_4$ serve as auxiliary datasets for training the auxiliary scores, each rich in training samples to ensure adequate training of the auxiliaries. $D_5$ is used as the target dataset for training both ScoreFusion and the baseline, with the training and validation data varied to compare test performance.
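A hypothetical sketch of how subsets with prescribed 7/9 proportions, such as $D_1,\ldots,D_5$, could be drawn from a labeled pool; the pool, the helper name `make_mixture_subset` and the subset size are assumptions, not the authors' data pipeline.

```python
import random


def make_mixture_subset(labels, proportions, size, seed=0):
    """Return indices of a subset whose class counts follow `proportions`.

    labels: integer class labels of the pool (hypothetical stand-in for EMNIST).
    proportions: dict mapping class -> fraction; fractions must sum to 1.
    """
    if abs(sum(proportions.values()) - 1.0) > 1e-9:
        raise ValueError("proportions must sum to 1")
    rng = random.Random(seed)
    by_class = {c: [i for i, y in enumerate(labels) if y == c] for c in proportions}
    classes = sorted(proportions)
    counts = {c: int(round(proportions[c] * size)) for c in classes}
    counts[classes[-1]] = size - sum(counts[c] for c in classes[:-1])  # absorb rounding
    subset = []
    for c in classes:
        subset += rng.sample(by_class[c], counts[c])  # sample without replacement
    rng.shuffle(subset)
    return subset


pool = [7] * 5000 + [9] * 5000  # hypothetical pool containing only 7's and 9's
target = make_mixture_subset(pool, {7: 0.6, 9: 0.4}, size=256)  # like D_5
print(sum(pool[i] == 7 for i in target), sum(pool[i] == 9 for i in target))  # 154 102
```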

Two metrics are used to evaluate image samples generated by different models. The first is the negative log-likelihood (NLL), measured on a held-out test dataset and expressed in bits per dimension (bpd) [50]; a smaller NLL means that test images are more likely under the trained generative model, and NLL is a standard metric for evaluating diffusion models [24, 47]. Table 2 displays the results for target sample sizes ranging from $2^6$ to $2^{12}$, showing that ScoreFusion generalizes to test samples much better than the baseline diffusion model in the low-data regime.
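Bits per dimension is the negative log-likelihood in nats divided by the number of dimensions times $\ln 2$. A minimal helper follows, assuming the per-image NLL is already available in nats for a 1×28×28 image; how the likelihood itself is computed, including any dequantization offset, is outside this sketch, and `bits_per_dim` is a hypothetical name.

```python
import math


def bits_per_dim(nll_nats, shape=(1, 28, 28)):
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    dims = math.prod(shape)  # 784 dimensions for a 1x28x28 image
    return nll_nats / (dims * math.log(2.0))


print(bits_per_dim(784 * math.log(2.0)))  # 1.0
```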

Table 2: Mean test NLL in bpd (lower is better) for different numbers of target training samples.

| Sample size | $2^{6}$ | $2^{8}$ | $2^{10}$ | $2^{12}$ |
|---|---|---|---|---|
| Baseline | 7.186 ± 0.019 | 6.235 ± 0.016 | 5.725 ± 0.024 | 4.979 ± 0.028 |
| Single auxiliary | 4.768 ± 0.024 (single value reported) | | | |
| ScoreFusion | 4.733 ± 0.029 | 4.733 ± 0.018 | 4.718 ± 0.022 | 4.715 ± 0.021 |

The second metric examines the digit-class distribution of generated samples, a discrete distribution over ten classes; this metric is related to the notion of sample diversity discussed by Naeem et al. [32]. To estimate the proportion of each digit in the samples, we train an image classifier, SpinalNet [27], on all EMNIST digit classes, achieving 99.5% classification accuracy. At evaluation time, we sample 1024 images from a trained generative model, feed them into the pre-trained SpinalNet, and average the outputs (i.e., the mean of 1024 softmaxed length-10 logit vectors) to approximate the generative model's digit distribution. Table 3 gives the comparison. ScoreFusion consistently tracks the proportions of 7's and 9's in the ground-truth dataset, where the baseline struggles; this is notable because the metric was not explicitly optimized when training ScoreFusion.
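The averaging step can be sketched as follows. SpinalNet itself is not reproduced: the sketch assumes the classifier's raw length-10 logit vectors are already available, and the function names are hypothetical. The last helper collapses the ten classes into the 7 / 9 / Others breakdown used in Table 3.

```python
import math


def softmax(logits):
    """Numerically stable softmax of one logit vector."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def estimate_digit_distribution(batch_logits):
    """Average softmaxed classifier outputs over all generated samples."""
    probs = [softmax(logits) for logits in batch_logits]
    num_classes = len(probs[0])
    return [sum(p[k] for p in probs) / len(probs) for k in range(num_classes)]


def seven_nine_breakdown(dist):
    """Collapse a 10-class distribution into the 7 / 9 / Others rows of Table 3."""
    return {"7": dist[7], "9": dist[9], "Others": 1.0 - dist[7] - dist[9]}


# Usage with two toy logit vectors standing in for classifier outputs:
# one sample confidently classified as 7 and one as 9.
toy = [[0.0] * 10 for _ in range(2)]
toy[0][7] = 50.0
toy[1][9] = 50.0
print(seven_nine_breakdown(estimate_digit_distribution(toy)))
```

Averaging probabilities rather than counting arg-max labels keeps the classifier's uncertainty on ambiguous samples, which matches the description above.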

Table 3: Digit distribution estimated by SpinalNet. For each target sample size, the Fusion column is the breakdown for ScoreFusion. The "Others" row is the average probability mass assigned to digits other than 7 and 9.

| Digit | True | $2^{6}$ Baseline | $2^{6}$ Fusion | $2^{8}$ Baseline | $2^{8}$ Fusion | $2^{10}$ Baseline | $2^{10}$ Fusion | $2^{12}$ Baseline | $2^{12}$ Fusion |
|---|---|---|---|---|---|---|---|---|---|
| 7 | 60% | 47.9% | 55.6% | 66.8% | 57.5% | 65.5% | 56.6% | 66.7% | 59.8% |
| 9 | 40% | 10.3% | 39.4% | 23.8% | 38.0% | 26.7% | 39.8% | 27.9% | 36.7% |
| Others | 0% | 41.8% | 5.0% | 9.4% | 4.5% | 7.8% | 3.6% | 5.4% | 3.5% |

Finally, we present generated images from the baseline diffusion model and ScoreFusion in Figure 3. With only 64 training images, ScoreFusion can already produce high-quality digits, while the baseline diffusion method generates unrecognizable images. ScoreFusion also outperforms the baseline with 256 training images, producing clearer and more accurate digits.

Figure 3: Samples created by the baseline and ScoreFusion trained on 64 and 256 images.

6 Conclusion

In this paper, we propose a fusion method based on KL barycenters that is easy to implement when the auxiliary score estimates are obtained from diffusion models. We provide a theoretical analysis of its sample complexity, showing that it is dimension-free given accurate auxiliary score estimates and closeness between the optimal KL barycenter and the target distribution. The numerical experiments further demonstrate that our fusion method performs much better than the basic diffusion model in the low-data regime. This work is a starting point for approximating a target distribution from limited data via fusion, with diffusion models making the implementation much easier. More broadly, fusion may be applied to other variants in the diffusion model family, including models with different assumptions on the initial distribution [17, 18, 8], other neural network architectures [11], and Schrödinger bridges [30, 54, 18].

Acknowledgments

This work is supported generously by the NSF grants CCF-2312204 and CCF-2312205 and Air Force Office of Scientific Research FA9550-20-1-0397. Additional support is gratefully acknowledged from NSF 1915967, 2118199, and 2229012.

References

  • Agueh and Carlier [2011] Martial Agueh and Guillaume Carlier. Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2):904–924, 2011.
  • Ambrosio et al. [2005] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré. Gradient Flows. Springer Science & Business Media, 2005.
  • Anderson [1982] Brian D. O. Anderson. Reverse-time diffusion equation models. Stochastic Processes and their Applications, 12(3):313–326, 1982.
  • Austin et al. [2021] Jacob Austin, Daniel D Johnson, Jonathan Ho, Daniel Tarlow, and Rianne van den Berg. Structured denoising diffusion models in discrete state-spaces. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021). NeurIPS, 2021.
  • Avrahami et al. [2022] Omri Avrahami, Dani Lischinski, and Ohad Fried. Blended diffusion for text-driven editing of natural images. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022.
  • Banerjee et al. [2005] Arindam Banerjee, Inderjit S. Dhillon, Joydeep Ghosh, and Suvrit Sra. Clustering on the unit hypersphere using von Mises-Fisher distributions. Journal of Machine Learning Research, 6:1345–1382, 2005.
  • Benamou et al. [2015] Jean-David Benamou, Guillaume Carlier, Marco Cuturi, Luca Nenna, and Gabriel Peyré. Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2):A1111–A1138, 2015.
  • Block et al. [2022] Alexander Block, Youssef Mroueh, and Alexander Rakhlin. Generative modeling with denoising auto-encoders and Langevin sampling. arXiv e-prints, 2022.
  • Braun et al. [2022] Gábor Braun, Alejandro Carderera, Cyrille W. Combettes, Hamed Hassani, Amin Karbasi, Aryan Mokhtari, and Sebastian Pokutta. Conditional gradient methods. arXiv preprint arXiv:2211.14103, 2022.
  • Cattiaux et al. [2022] Patrick Cattiaux, Giovanni Conforti, Ivan Gentil, and Christian Léonard. Time reversal of diffusion processes under a finite entropy condition. arXiv preprint arXiv:2104.07708, 2022.
  • Chen et al. [2023a] Minshuo Chen, Kaixuan Huang, Tuo Zhao, and Mengdi Wang. Score approximation, estimation and distribution recovery of diffusion models on low-dimensional data. In Proceedings of the 40th International Conference on Machine Learning, volume 202, pages 4672–4712. PMLR, 2023a.
  • Chen et al. [2023b] Sitan Chen, Sinho Chewi, Jerry Li, Yuanzhi Li, Adil Salim, and Anru R. Zhang. Sampling is as easy as learning the score: theory for diffusion models with minimal data assumptions. arXiv preprint arXiv:2209.1121, 2023b.
  • Claici et al. [2018] Sebastian Claici, Edward Chien, and Justin Solomon. Stochastic Wasserstein barycenters. In International Conference on Machine Learning, pages 1141–1150, 2018.
  • Claici et al. [2020] Sebastian Claici, Mikhail Yurochkin, Soumya Ghosh, and Justin Solomon. Model fusion with Kullback-Leibler divergence. In International Conference on Machine Learning. PMLR, 2020.
  • Cohen et al. [2017] Gregory Cohen, Saeed Afshar, Jonathan Tapson, and Andre Van Schaik. EMNIST: Extending MNIST to handwritten letters. In 2017 International Joint Conference on Neural Networks (IJCNN), pages 2921–2926. IEEE, 2017.
  • Cuturi and Doucet [2014] Marco Cuturi and Arnaud Doucet. Fast computation of Wasserstein barycenters. In International Conference on Machine Learning, pages 685–693, 2014.
  • De Bortoli [2022] Valentin De Bortoli. Convergence of denoising diffusion models under the manifold hypothesis. Transactions on Machine Learning Research, 2022.
  • De Bortoli et al. [2021] Valentin De Bortoli, Jacob Thornton, Jeremy Heng, and Arnaud Doucet. Diffusion Schrödinger bridge with applications to score-based generative modeling. In Advances in Neural Information Processing Systems, volume 34, pages 17695–17709. Curran Associates, Inc., 2021.
  • Evans [2010] Lawrence C. Evans. Partial Differential Equations. American Mathematical Society, 2010.
  • Föllmer [1985] H Föllmer. An entropy approach to the time reversal of diffusion processes. Stochastic differential systems (Marseille-Luminy, 1984), 69:156–163, 1985.
  • Gelfand and Fomin [2000] I. M. Gelfand and S. V. Fomin. Calculus of Variations. Dover Publications, 2000.
  • Genevay et al. [2016] Aude Genevay, Marco Cuturi, Gabriel Peyré, and Francis Bach. Stochastic optimization for large-scale optimal transport. Advances in Neural Information Processing Systems, 29:3432–3440, 2016.
  • Gong et al. [2022] Shansan Gong, Mukai Li, Jiangtao Feng, Zhiyong Wu, and LingPeng Kong. DiffuSeq: Sequence to sequence text generation with diffusion models. arXiv preprint arXiv:2210.08933, 2022.
  • Ho et al. [2020] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in neural information processing systems, 33:6840–6851, 2020.
  • Hsu et al. [2021] Daniel Hsu, Clayton Sanford, Rocco A. Servedio, and Emmanouil V. Vlatakis Gkaragkounis. On the approximation power of two-layer networks of random ReLUs. In Proceedings of Machine Learning Research, volume 134, pages 1–39. 34th Annual Conference on Learning Theory, 2021.
  • Janati et al. [2020] Hicham Janati, Marco Cuturi, and Alexandre Gramfort. Debiased Sinkhorn barycenters. In International Conference on Machine Learning, pages 4692–4701, 2020.
  • Kabir et al. [2022] HM Dipu Kabir, Moloud Abdar, Abbas Khosravi, Seyed Mohammad Jafar Jalali, Amir F Atiya, Saeid Nahavandi, and Dipti Srinivasan. SpinalNet: Deep neural network with gradual input. IEEE Transactions on Artificial Intelligence, 2022.
  • Kuznetsova et al. [2020] Alina Kuznetsova, Hassan Rom, Neil Alldrin, Jasper Uijlings, Ivan Krasin, Jordi Pont-Tuset, Shahab Kamali, Stefan Popov, Matteo Malloci, Alexander Kolesnikov, Tom Duerig, and Vittorio Ferrari. The Open Images Dataset V4: Unified image classification, object detection, and visual relationship detection at scale. International Journal of Computer Vision, 2020.
  • Li et al. [2023] Puheng Li, Zhong Li, Huishuai Zhang, and Jiang Bian. On the generalization properties of diffusion models. In Advances in Neural Information Processing Systems 36 (NeurIPS 2023). NeurIPS, 2023.
  • Liu et al. [2022] Hongjun Liu, Xiang Zhang, and Qionghai Li. SB-DDPM: Schrödinger bridge diffusion denoising probabilistic model for generative tasks. IEEE Transactions on Neural Networks and Learning Systems, 2022.
  • Mokady et al. [2022] Ron Mokady, Amir Hertz, Kfir Aberman, Yael Pritch, and Daniel Cohen-Or. Null-text inversion for editing real images using guided diffusion models. arXiv preprint arXiv:2211.09794, 2022.
  • Naeem et al. [2020] Muhammad Ferjad Naeem, Seong Joon Oh, Youngjung Uh, Yunjey Choi, and Jaejun Yoo. Reliable fidelity and diversity metrics for generative models. In International Conference on Machine Learning, pages 7176–7185. PMLR, 2020.
  • Omri Avrahami [2022] Omri Avrahami, Ohad Fried, and Dani Lischinski. Blended latent diffusion. arXiv preprint arXiv:2206.02779, 2022.
  • Pan and Yang [2009] Sinno Jialin Pan and Qiang Yang. A survey on transfer learning. IEEE Transactions on Knowledge and Data Engineering, 22(10):1345–1359, 2009.
  • Peyré and Cuturi [2019] Gabriel Peyré and Marco Cuturi. Computational optimal transport. Foundations and Trends in Machine Learning, 11(5-6):355–607, 2019.
  • Peyré et al. [2016] Gabriel Peyré, Marco Cuturi, and Justin Solomon. Gromov-Wasserstein averaging of kernel and distance matrices. In International Conference on Machine Learning, pages 2664–2672, 2016.
  • Popov et al. [2021] Vadim Popov, Ivan Vovk, Vladimir Gogoryan, Tasnima Sadekova, and Mikhail Kudinov. Grad-TTS: A diffusion probabilistic model for text-to-speech. In International Conference on Learning Representations, 2021.
  • Ramesh et al. [2022] Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, and Mark Chen. Hierarchical text-conditional image generation with clip latents. arXiv preprint arXiv:2204.06125, 2022.
  • Rasul et al. [2021] Kashif Rasul, Calvin Seward, Ingmar Schuster, and Roland Vollgraf. Autoregressive denoising diffusion models for multivariate probabilistic time series forecasting. In International Conference on Learning Representations, 2021.
  • Rombach et al. [2022] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022.
  • Saharia et al. [2022a] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Raphael Gontijo-Lopes, Burcu Karagol Ayan, Tim Salimans, Jonathan Ho, David J. Fleet, and Mohammad Norouzi. Photorealistic text-to-image diffusion models with deep language understanding. In Advances in Neural Information Processing Systems 35 (NeurIPS 2022). NeurIPS, 2022a.
  • Saharia et al. [2022b] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Raphael Gontijo Lopes, Tim Salimans, Jonathan Ho, David J Fleet, and Mohammad Norouzi. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
  • Schuhmann et al. [2022] Christoph Schuhmann, Romain Beaumont, Richard Vencu, Cade Gordon, Ross Wightman, Mehdi Cherti, Theo Coombes, Aarush Katta, Clayton Mullis, Mitchell Wortsman, Patrick Schramowski, Srivatsa Kundurthy, Katherine Crowson, Ludwig Schmidt, Robert Kaczmarczyk, and Jenia Jitsev. LAION-5B: An open large-scale dataset for training next generation image-text models. In Advances in Neural Information Processing Systems 35. NeurIPS, 2022.
  • Sohl-Dickstein et al. [2015] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. Proceedings of the 32nd International Conference on Machine Learning, 37:2256–2265, 2015.
  • Solomon et al. [2015] Justin Solomon, Fernando de Goes, Gabriel Peyré, Marco Cuturi, Adrian Butscher, Andy Nguyen, and Leonidas Guibas. Convolutional Wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics, 34(4):66, 2015.
  • Song et al. [2021a] Yang Song, Conor Durkan, Iain Murray, and Stefano Ermon. Maximum likelihood training of score-based diffusion models. In Advances in Neural Information Processing Systems 34 (NeurIPS 2021). NeurIPS, 2021a.
  • Song et al. [2021b] Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations (ICLR 2021). ICLR, 2021b.
  • Staib et al. [2017] Matthew Staib, Sebastian Claici, Justin Solomon, and Stefanie Jegelka. Parallel streaming Wasserstein barycenters. In Advances in Neural Information Processing Systems, volume 30, pages 2647–2658, 2017.
  • Tan et al. [2020] Chuanqi Tan, Fuchun Sun, Tao Kong, Wenchang Zhang, Chao Yang, and Chunfang Liu. A survey on deep transfer learning. Artificial Intelligence Review, 52(2):1089–1116, 2020.
  • Theis et al. [2016] L Theis, A van den Oord, and M Bethge. A note on the evaluation of generative models. In International Conference on Learning Representations (ICLR 2016), pages 1–10, 2016.
  • Torrey and Shavlik [2010] Lisa Torrey and Jude Shavlik. Transfer learning. Handbook of Research on Machine Learning Applications and Trends: Algorithms, Methods, and Techniques, pages 242–264, 2010.
  • Tumanyan et al. [2022] Narek Tumanyan, Michal Geyer, Shai Bagon, and Tali Dekel. Plug-and-play diffusion features for text-driven image-to-image translation. arXiv preprint arXiv:2211.12572, 2022.
  • van Handel [2016] Ramon van Handel. Probability in high dimension, APC 550 lecture notes, December 2016.
  • Vargas et al. [2022] Francisco J. Vargas, James E. Taylor, and Valentin de Bortoli. Solving Schrödinger bridges via maximum likelihood: Applications to diffusion-based generative modeling. In Proceedings of the International Conference on Learning Representations, 2022.
  • Wang et al. [2023] Zhendong Wang, Yifan Jiang, Huangjie Zheng, Peihao Wang, Pengcheng He, Zhangyang Wang, Weizhu Chen, and Mingyuan Zhou. Patch diffusion: Faster and more data-efficient training of diffusion models. arXiv preprint arXiv:2304.12526, 2023.
  • Weiss et al. [2016] Karl Weiss, Taghi M Khoshgoftaar, and DingDing Wang. A survey of transfer learning. In Journal of Big Data, volume 3, pages 1–40. Springer, 2016.
  • Wu et al. [2022] Jay Zhangjie Wu, Yixiao Ge, Xintao Wang, Weixian Lei, Yuchao Gu, Wynne Hsu, Ying Shan, Xiaohu Qie, and Mike Zheng Shou. Tune-a-video: One-shot tuning of image diffusion models for text-to-video generation. arXiv preprint arXiv:2212.11565, 2022.
  • Zhang et al. [2023] Ruoyu Zhang, Yanzeng Li, Yongliang Ma, Ming Zhou, and Lei Zou. LLMaAA: Making large language models as active annotators. arXiv preprint arXiv:2310.19596, 2023.
  • Zhu et al. [2023] Jingyuan Zhu, Huimin Ma, Jiansheng Chen, and Jian Yuan. DomainStudio: Fine-tuning diffusion models for domain-driven image generation using limited data. arXiv preprint arXiv:2306.14153, 2023.
  • Zhuang et al. [2021] Fuzhen Zhuang, Ziliang Qi, Keyu Duan, Dongbo Xi, Yongchun Zhu, Hengshu Zhu, Hui Xiong, and Qing He. A comprehensive survey on transfer learning. In Proceedings of the IEEE, volume 109, pages 43–76. IEEE, 2021.

Appendix A More about basic diffusion models

A.1 About the time reversal formula

Note that Equations (3) and (4) are still written as "forward" processes. If we replace $W(t)$ by $\tilde{W}(t)$, a standard $d$-dimensional Brownian motion that flows backward from time $T$ to $0$, then Equation (3) becomes

$$d\hat{X}(t)=\Big(f\big(T-t,\hat{X}(t)\big)-g^{2}(T-t)\,\nabla\log p_{T-t}\big(\hat{X}(t)\big)\Big)\,dt+g(T-t)\,d\tilde{W}(t),\qquad \hat{X}(T)\sim p_{T},$$

which is the reverse SDE presented in Song et al. [47]. Hence, for the forward OU process, the reverse process can also be written as

$$d\hat{X}(t)=\Big(-a\hat{X}(t)-\sigma^{2}\,\nabla\log p_{T-t}\big(\hat{X}(t)\big)\Big)\,dt+\sigma\,d\tilde{W}(t),\qquad \hat{X}(T)\sim p_{T}.$$

A.2 Discretization and backward sampling

In this section, we follow the scheme in Chen et al. [12].

Given $n$ samples $X_{0}^{(1)},\ldots,X_{0}^{(n)}$ from the data distribution $p_{0}$, we train a neural network with the loss function (5). Let $h>0$ be the step size of the time discretization and $N$ the number of steps, so that $T=Nh$. We assume that for each $l=0,1,\ldots,N$, an estimate $s_{lh,\theta^{*}}$ of the score $\nabla\log p_{lh}$ has been obtained. To simulate the reverse SDE (3), we first replace the score function $\nabla\log p_{T-t}$ with its estimate $s_{T-t,\theta^{*}}$. Next, for each $t\in[lh,(l+1)h]$, we evaluate this coefficient at the value of the process at time $lh$, which yields the time-discretized SDE, for $t\in[lh,(l+1)h]$,

$$d\hat{X}(t)=\Big(-f\big(T-t,\hat{X}(t)\big)+g^{2}(T-t)\,s_{T-t,\theta^{*}}\big(\hat{X}_{lh}\big)\Big)\,dt+g(T-t)\,dW(t)\tag{11}$$

and $\hat{X}(0)\sim\Pi$, where $\Pi$ is the (theoretical) stationary distribution of the forward process (1).

Several implementation details are worth noting. In practice, when the forward process is an OU process, Equation (11) becomes

$$d\hat{X}(t)=\Big(a\hat{X}(t)+\sigma^{2}\,s_{T-t,\theta^{*}}\big(\hat{X}_{lh}\big)\Big)\,dt+\sigma\,dW(t),\qquad t\in[lh,(l+1)h],$$

with $\Pi=\pi$. Because the score term is evaluated at $\hat{X}_{lh}$, this is a linear SDE on each interval; in particular, $\hat{X}_{(l+1)h}$ conditioned on $\hat{X}_{lh}$ is Gaussian, so sampling is straightforward.
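A minimal sketch of the resulting sampler, assuming the forward OU process $dX=-aX\,dt+\sigma\,dW$ (consistent with $f(t,x)=-ax$ and $g\equiv\sigma$ above), whose stationary law is $\pi=\mathcal{N}\big(0,\tfrac{\sigma^2}{2a}I\big)$. As a simplifying assumption, the score is frozen on each interval at both the time $T-lh$ and the state $\hat{X}_{lh}$ (Equation (11) freezes only the state). The affine SDE $dY=(aY+\sigma^2c)\,dt+\sigma\,dW$ then integrates exactly: $Y_{lh+h}\mid Y_{lh}=x$ is Gaussian with mean $e^{ah}x+\sigma^2c\,(e^{ah}-1)/a$ and variance $\sigma^2(e^{2ah}-1)/(2a)$. `score_fn` is a hypothetical stand-in for the trained network $s_{t,\theta^*}$.

```python
import math
import random


def ou_reverse_step(x, score_val, a, sigma, h, rng):
    """Exact step of dY = (a*Y + sigma^2 * c) dt + sigma dW over one interval of length h.

    c = score_val is the score evaluated at the left endpoint and held constant,
    so Y(lh + h) given Y(lh) = x is Gaussian; applied coordinate-wise.
    """
    growth = math.exp(a * h)
    std = math.sqrt(sigma ** 2 * (math.exp(2.0 * a * h) - 1.0) / (2.0 * a))
    return [growth * xi + sigma ** 2 * ci * (growth - 1.0) / a + std * rng.gauss(0.0, 1.0)
            for xi, ci in zip(x, score_val)]


def sample_reverse_ou(score_fn, a, sigma, T, N, dim, rng):
    """Draw one sample: start from pi = N(0, sigma^2/(2a) I) and take N exact steps."""
    h = T / N
    x = [rng.gauss(0.0, math.sqrt(sigma ** 2 / (2.0 * a))) for _ in range(dim)]
    for l in range(N):
        # Score estimate for p_{T - lh}, evaluated at the current state X_{lh}.
        x = ou_reverse_step(x, score_fn(T - l * h, x), a, sigma, h, rng)
    return x


# Usage: if the data distribution already equals pi = N(0, 1) (a = 1, sigma^2 = 2),
# the exact score is -x at every time and the sampler should stay close to N(0, 1).
rng = random.Random(0)
draws = [sample_reverse_ou(lambda t, x: [-xi for xi in x], 1.0, math.sqrt(2.0), 1.0, 100, 1, rng)[0]
         for _ in range(4000)]
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
print(round(mean, 2), round(var, 2))
```

Sampling the Gaussian transition exactly avoids the extra Euler–Maruyama error in the linear drift; the remaining discretization error comes only from freezing the score.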

In theory, we should take $\Pi=p_{T}$, which is not available. The implementation above instead relies on $p_{T}\approx\Pi$ for sufficiently large $T$, which introduces a small initialization error.
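To gauge the size of this initialization error, here is a sketch under an extra assumption not made in the text: one data coordinate is Gaussian, $p_0=\mathcal{N}(m_0,v_0)$, and the forward process is $dX=-aX\,dt+\sigma\,dW$. Then $p_T=\mathcal{N}\big(m_0e^{-aT},\,v_0e^{-2aT}+\tfrac{\sigma^2}{2a}(1-e^{-2aT})\big)$, and $D_{\mathrm{KL}}(p_T\,\|\,\pi)$ has a closed form that decays exponentially in $T$. The helper names are hypothetical.

```python
import math


def ou_marginal(m0, v0, a, sigma, T):
    """Mean and variance at time T of dX = -aX dt + sigma dW started from N(m0, v0)."""
    decay = math.exp(-a * T)
    v_inf = sigma ** 2 / (2.0 * a)
    return m0 * decay, v0 * decay ** 2 + v_inf * (1.0 - decay ** 2)


def kl_gaussian_1d(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for scalar Gaussians given by mean and variance."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def initialization_error(m0, v0, a, sigma, T):
    """KL between p_T and the stationary law pi = N(0, sigma^2 / (2a))."""
    mT, vT = ou_marginal(m0, v0, a, sigma, T)
    return kl_gaussian_1d(mT, vT, 0.0, sigma ** 2 / (2.0 * a))


# Usage: the error shrinks rapidly as the time horizon T grows.
for T in (1.0, 3.0, 5.0):
    print(T, initialization_error(3.0, 0.5, 1.0, math.sqrt(2.0), T))
```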

A.3 About the generalization error of the basic diffusion model

In Li et al. [29], a random feature model is used as the score estimator. The basic intuition is that the generalization error with respect to the KL divergence, $D_{\text{KL}}(\mu\parallel\hat{\mu})$, decomposes into three terms: the training error, the approximation error of the underlying random feature model, and the convergence error of the stationary measures. Of these, the third is negligible because an OU process converges quickly to its stationary distribution (this also follows from the log-Sobolev inequality for Gaussian measures; see van Handel [53]). The first is also small, since training the random feature model in this setting is essentially least-squares linear regression.
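To see why training reduces to least squares, here is a minimal sketch of a random-feature model $\frac{1}{m}A\,\sigma(Wx+Ue(t))$ with frozen $W,U$ and trainable readout $A$, for one output coordinate. Since the features $\sigma(w_j^Tx+u_j^Te(t))$ are fixed once $W,U$ are drawn, fitting $A$ is an ordinary linear regression. The regression targets `y` are generic placeholders (in denoising score matching they would be the conditional score targets, which are not reproduced here), the tiny ridge term is an added assumption for numerical stability, and all names are hypothetical.

```python
import random


def relu(z):
    return z if z > 0.0 else 0.0


def random_features(x, e_t, W, U):
    """phi_j = ReLU(w_j . x + u_j . e(t)); W and U are drawn once and then frozen."""
    return [relu(sum(wi * xi for wi, xi in zip(w, x)) + sum(ui * ei for ui, ei in zip(u, e_t)))
            for w, u in zip(W, U)]


def solve_linear(M, b):
    """Solve M z = b by Gaussian elimination with partial pivoting (small dense systems)."""
    n = len(b)
    aug = [row[:] + [bi] for row, bi in zip(M, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[piv] = aug[piv], aug[col]
        for r in range(col + 1, n):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= factor * aug[col][c]
    z = [0.0] * n
    for r in range(n - 1, -1, -1):
        z[r] = (aug[r][n] - sum(aug[r][c] * z[c] for c in range(r + 1, n))) / aug[r][r]
    return z


def fit_readout(Phi, y, m, ridge=1e-10):
    """Least squares for one output coordinate: min_a sum_i (y_i - a . phi_i / m)^2 + ridge |a|^2."""
    X = [[p / m for p in phi] for phi in Phi]
    gram = [[sum(row[j] * row[k] for row in X) + (ridge if j == k else 0.0) for k in range(m)]
            for j in range(m)]
    rhs = [sum(row[j] * yi for row, yi in zip(X, y)) for j in range(m)]
    return solve_linear(gram, rhs)  # normal equations (X^T X + ridge I) a = X^T y


# Usage: targets generated by a known readout are reproduced up to numerical error.
rng = random.Random(0)
m, d, d_e, n = 5, 2, 1, 50
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
U = [[rng.gauss(0.0, 1.0) for _ in range(d_e)] for _ in range(m)]
Phi = [random_features([rng.gauss(0.0, 1.0) for _ in range(d)], [rng.random()], W, U)
       for _ in range(n)]
a_true = [1.0, -2.0, 0.5, 3.0, -1.0]
y = [sum(a * p for a, p in zip(a_true, phi)) / m for phi in Phi]
a_hat = fit_readout(Phi, y, m)
max_err = max(abs(sum(a * p for a, p in zip(a_hat, phi)) / m - t) for phi, t in zip(Phi, y))
print(max_err)
```

Each of the $d$ output coordinates of $A$ is fitted by the same regression with its own targets.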

Moreover, as shown in Hsu et al. [25], random feature models can approximate Lipschitz functions with compact support. However, if we choose $m\sim n$, the approximation error can be large and suffer from the curse of dimensionality. To illustrate this, we make a more general statement that accounts for smoothness.

To be more precise, we introduce the following setting. We use the basic diffusion model with a forward OU process. The score function $s_{t,\theta}(x)$ is parameterized by a random feature model with $m$ random features:

$$s_{t,\theta}(x)=\frac{1}{m}A\,\sigma\big(Wx+Ue(t)\big)=\frac{1}{m}\sum_{j=1}^{m}a_{j}\,\sigma\big(w_{j}^{T}x+u_{j}^{T}e(t)\big),$$

where $\sigma$ is the ReLU activation function, $A=(a_{1},\ldots,a_{m})\in\mathbb{R}^{d\times m}$ contains the trainable parameters, $W=(w_{1},\ldots,w_{m})^{T}\in\mathbb{R}^{m\times d}$ and $U=(u_{1},\ldots,u_{m})^{T}\in\mathbb{R}^{m\times d_{e}}$ are sampled once at initialization from pre-chosen distributions (the random features) and remain frozen during training, and $e:\mathbb{R}_{+}\to\mathbb{R}^{d_{e}}$ is the time embedding function. The precise description is given below.

Assume that $a_{j}$, $w_{j}$, and $u_{j}$ are drawn i.i.d. from a distribution $\rho$. Then, as $m\to\infty$, by the strong law of large numbers, with probability 1,

$$s_{t,\theta}(x)\to\bar{s}_{t,\bar{\theta}}(x)=\mathbb{E}_{(w,u)\sim\rho_{0}}\Big[a(w,u)\,\sigma\big(w^{T}x+u^{T}e(t)\big)\Big],\tag{12}$$

where $a(w,u)=\frac{1}{\rho_{0}(w,u)}\int a\,\rho(a,w,u)\,da$ and $\rho_{0}(w,u)=\int\rho(a,w,u)\,da$. By the positive homogeneity of the ReLU function, we may assume $\lVert u\rVert+\lVert w\rVert\leq 1$. We denote by $\bar{\theta}^{*}$ the optimal solution obtained when $s_{t,\theta}(x)$ in the loss objective is replaced by $\bar{s}_{t,\bar{\theta}}(x)$.

Define the kernel $K_{\rho_{0}}(x,y)=\mathbb{E}_{(w,u)\sim\rho_{0}}\big[\sigma(w^{T}x+u^{T}e(t))\,\sigma(w^{T}y+u^{T}e(t))\big]$ and denote the induced reproducing kernel Hilbert space (RKHS) by $\mathcal{H}_{K_{\rho_{0}}}$; when no confusion arises, we write $\mathcal{H}:=\mathcal{H}_{K_{\rho_{0}}}$.
It follows that $\bar{s}_{t,\bar{\theta}}\in\mathcal{H}$ if and only if $\lVert\bar{s}_{t,\bar{\theta}}\rVert_{\mathcal{H}}^2=\mathbb{E}_{(w,u)\sim\rho_0}\left[\lVert a(w,u)\rVert_2^2\right]<\infty$.
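As a rough illustration of the random feature construction above, the following sketch draws $m$ features $(w_j,u_j)$, forms a Monte Carlo estimate of $K_{\rho_0}(x,y)$, and evaluates a random-feature score $\frac{1}{m}\sum_j a_j\,\sigma(w_j^Tx+u_j^Te(t))$. The choice of $\rho_0$ (Gaussian vectors rescaled so that $\lVert w\rVert+\lVert u\rVert=1$), the $1/m$ normalization, the time-embedding values, and all function names are hypothetical assumptions made only for illustration; as Remark 1 notes, the text does not specify how to choose $\rho_0$.

```python
import math
import random

def relu(z):
    return max(z, 0.0)

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def sample_features(m, d, e_dim, rng):
    """Draw m pairs (w_j, u_j) from a hypothetical rho_0: Gaussian vectors rescaled
    so that ||w_j|| + ||u_j|| = 1 (allowed by positive homogeneity of ReLU)."""
    feats = []
    for _ in range(m):
        w = [rng.gauss(0.0, 1.0) for _ in range(d)]
        u = [rng.gauss(0.0, 1.0) for _ in range(e_dim)]
        scale = math.sqrt(dot(w, w)) + math.sqrt(dot(u, u))
        feats.append(([wi / scale for wi in w], [ui / scale for ui in u]))
    return feats

def phi(x, e_t, feats):
    """Random features relu(w_j^T x + u_j^T e(t)) for j = 1, ..., m."""
    return [relu(dot(w, x) + dot(u, e_t)) for w, u in feats]

def kernel_mc(x, y, e_t, feats):
    """Monte Carlo estimate of K_{rho_0}(x, y) from the m sampled features."""
    return dot(phi(x, e_t, feats), phi(y, e_t, feats)) / len(feats)

def rf_score(x, e_t, feats, a):
    """Random-feature score (1/m) sum_j a_j relu(w_j^T x + u_j^T e(t)), a_j in R^d."""
    f = phi(x, e_t, feats)
    d = len(a[0])
    return [sum(a_j[k] * f_j for a_j, f_j in zip(a, f)) / len(feats) for k in range(d)]

rng = random.Random(0)
feats = sample_features(m=500, d=3, e_dim=4, rng=rng)
e_t = [0.1, 0.2, 0.3, 0.4]  # hypothetical value of the time embedding e(t)
x, y = [0.5, -1.0, 0.2], [0.3, 0.0, -0.4]
print(kernel_mc(x, y, e_t, feats))
```

Because the estimate is an average of products of the same feature vector, the resulting Gram matrix is symmetric and positive semidefinite by construction, as a kernel matrix should be.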

Hsu et al. [25] define a measure of approximation quality, the minimum width of the neural network, as the minimum number of random features needed to guarantee a sufficiently accurate approximation with high probability. We restate the definition below.

Definition 1.

Let $\epsilon,\delta>0$ and let $f:\mathbb{R}^d\to\mathbb{R}$ be a function with finite norm $\lVert f\rVert_\alpha<\infty$, where $\alpha$ is the measure on $\mathbb{R}^d$ associated with the corresponding function space. Denote the random features by $g^{(i)}(x)=\sigma\left(w_i^T x+u_i^T e(t)\right)$, where the pairs $(w_i,u_i)\sim\rho_0$ are drawn independently. The minimum width $m_{f,\epsilon,\delta,\alpha,\rho_0}$ is defined as the smallest $r\in\mathbb{Z}^+$ such that, with probability at least $1-\delta$ over $g^{(1)},\ldots,g^{(r)}$,

$$\inf_{g\in\operatorname{span}\left(g^{(1)},\ldots,g^{(r)}\right)}\lVert f-g\rVert_\alpha<\epsilon.$$

Moreover, for $s\ge 0$, $p\in[1,\infty]$, and an open, bounded set $U\subset\mathbb{R}^d$, the Sobolev space $W^{s,p}(U)$ of order $(s,p)$ consists of all locally integrable functions $f$ such that, for each multi-index $\alpha$ with $|\alpha|\le s$, the weak derivative $D^\alpha f$ exists and has finite $L^p$ norm (see Evans [19]). For $p=2$, we write $H^s(U)=W^{s,2}(U)$ to reflect the fact that it is a Hilbert space. Finally, recall that the space of Lipschitz functions on $U$ coincides with $W^{1,\infty}(U)$ when $U$ is sufficiently regular (e.g., convex).

With these settings and definitions, we can state and prove the following generalization error bound for the basic diffusion model with a random feature score model.

Theorem 5.

Suppose that the target distribution $\mu$ has a continuously differentiable density with compact support, that the random feature distribution $\rho_0$ is chosen appropriately, and that there exists an RKHS $\mathcal{H}$ such that $\bar{s}_{0,\bar{\theta^*}}\in\mathcal{H}$. Assume that the initial loss, the trainable parameters, the embedding function $e(t)$, and the weighting function $\gamma(t)$ are all bounded. Suppose further that for all $t\in[0,T]$, the score function satisfies $\nabla\log p_t\in H^s(K)\cap W^{1,\infty}(K)$ and there exists $\gamma>0$ such that $\lVert\nabla\log p_t\rVert_{H^s(K)}\le\gamma$, where $K\subset\mathbb{R}^d$ is compact. Then, for fixed $0<\epsilon,\delta\ll 1$, with probability at least $1-\delta$, we have

$$\begin{aligned}D_{\text{KL}}\left(\mu\,\|\,\hat{\mu}\right)&\lesssim\left(\frac{\tau^4}{m^3 n}+\frac{\tau^2}{mn}+\frac{\tau^3}{m^2}+\frac{1}{\tau}+\frac{1}{m}\right)\\&\quad+\min\left(\left(\frac{s}{\log m}\right)^{s/2},\left(\frac{d\left(m^{1/d}-2\right)}{s\gamma^{2/s}}\right)^{-s/2}\right)+D_{\text{KL}}\left(p_T\,\|\,\pi\right),\end{aligned}$$

where $\tau$ is the training time (number of steps) in the gradient flow dynamics (see Li et al. [29]), $m$ is the number of random features, $n$ is the number of samples from the target distribution, $\pi$ is the stationary Gaussian distribution, $p_T$ is the distribution of the forward OU process at time $T$, $\mu$ is the target distribution, and $\hat{\mu}$ is the distribution of the generated samples.
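To make the trade-off in the first group of terms concrete, one can drop the constants hidden by $\lesssim$ (an assumption; the theorem does not give them) and minimize $\frac{\tau^4}{m^3n}+\frac{\tau^2}{mn}+\frac{\tau^3}{m^2}+\frac{1}{\tau}+\frac{1}{m}$ over the training time $\tau$, which mimics choosing an early stopping time. The helper names are hypothetical. The sum is convex in $\tau>0$, so a golden-section search on $\log\tau$ locates its minimizer.

```python
import math

def training_terms(tau, m, n):
    """First group of terms in the Theorem 5 bound, with hidden constants set to 1."""
    return (tau**4 / (m**3 * n) + tau**2 / (m * n) + tau**3 / m**2
            + 1.0 / tau + 1.0 / m)

def best_stopping_time(m, n, lo=1e-3, hi=1e9, iters=200):
    """Golden-section search for the minimizer over tau in [lo, hi], on a log scale."""
    f = lambda s: training_terms(math.exp(s), m, n)
    a, b = math.log(lo), math.log(hi)
    g = (math.sqrt(5.0) - 1.0) / 2.0  # inverse golden ratio
    c, d = b - g * (b - a), a + g * (b - a)
    for _ in range(iters):
        if f(c) < f(d):  # minimizer lies in [a, d]
            b, d = d, c
            c = b - g * (b - a)
        else:            # minimizer lies in [c, b]
            a, c = c, d
            d = a + g * (b - a)
    tau = math.exp((a + b) / 2.0)
    return tau, training_terms(tau, m, n)

tau_star, value = best_stopping_time(m=10**6, n=10**4)
print(f"tau* = {tau_star:.1f}, bound (up to constants) = {value:.2e}")
```

For large $m$ the balance is mainly between $\tau^3/m^2$ and $1/\tau$, which gives $\tau\approx(m^2/3)^{1/4}$, roughly 760 for $m=10^6$.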

Proof.

The proof follows exactly that of Theorem 1 in Li et al. [29]; the only additional work is to compute the universal approximation error of the random feature model for Sobolev functions on a compact domain. By the compact support assumption (Lemma 1 in Li et al. [29]), the forward process defines a random path $\left(X(t),t\right)_{t\in[0,T]}$ contained in a compact rectangular domain in $\mathbb{R}^{d+1}$.

Theorem 35 in Hsu et al. [25] establishes the existence of a random feature distribution $\rho_0$ such that for any $f\in H^s(K)$ with $\lVert f\rVert_{H^s(K)}\le\gamma$,
$$m_{f,\epsilon,\delta,\alpha,\rho_0}\lesssim\frac{s^2\gamma^{2+4/s}d^2}{\epsilon^{2+4/s}}\log\left(\frac{1}{\delta}\right)\exp\left(\min\left(d\log\left(\frac{\gamma^2}{\epsilon^2 d}+2\right),\frac{\gamma^2}{\epsilon^2}\log\left(\frac{d\epsilon^2}{\gamma^2}+2\right)\right)\right),$$
which implies the approximation error term in Theorem 5.
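The width bound quoted above can be evaluated numerically to see how quickly the number of random features must grow. The sketch below computes the natural logarithm of the right-hand side with the hidden constant set to one (an assumption), which avoids floating-point overflow for large $d$; the function name and parameter values are hypothetical.

```python
import math

def log_width_bound(eps, delta, d, s, gamma):
    """Natural log of the width bound from Theorem 35 of Hsu et al., constant set to 1."""
    log_prefactor = (2 * math.log(s) + (2 + 4 / s) * math.log(gamma)
                     + 2 * math.log(d) - (2 + 4 / s) * math.log(eps))
    expo = min(d * math.log(gamma**2 / (eps**2 * d) + 2),
               gamma**2 / eps**2 * math.log(d * eps**2 / gamma**2 + 2))
    return log_prefactor + math.log(math.log(1 / delta)) + expo

for d in (2, 10, 100, 1000):
    # base-10 logarithm of the bound on the minimum width
    print(d, round(log_width_bound(eps=0.1, delta=0.05, d=d, s=2, gamma=1.0) / math.log(10), 1))
```

Working in log-space is a design choice: the exponential factor alone exceeds the range of double-precision floats for moderately large $d$.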

Remark 1.

The random feature model faces two difficulties in practice.

First, if $m$, $T$, and $\tau$ are large enough, the generalization error is small regardless of the sample size $n$. However, choosing the random feature distribution $\rho_0$ is hard in practice, especially since neither Hsu et al. [25] nor Li et al. [29] specifies how to choose it. Therefore, the assumption that $\rho_0$ is appropriately chosen is very strong.

Second, even if $\rho_0$ is appropriately chosen, if we let $m\sim n$ and seek an optimal early stopping time as in Li et al. [29], the term $\min\left(\left(\frac{s}{\log n}\right)^{s/2},\left(\frac{d\left(n^{1/d}-2\right)}{s\gamma^{2/s}}\right)^{-s/2}\right)$ still dominates and exhibits the curse of dimensionality.
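The dominating term identified in Remark 1 can be tabulated directly. The sketch below evaluates it with $m\sim n$ and the hidden constants set to one (an assumption); when $n^{1/d}\le 2$ the second expression inside the minimum is undefined, and the sketch then falls back to the first one. The function name and parameter values are illustrative only.

```python
import math

def approx_term(n, d, s, gamma):
    """min((s/log n)^{s/2}, (d(n^{1/d}-2)/(s*gamma^{2/s}))^{-s/2}) with m ~ n, constants set to 1."""
    first = (s / math.log(n)) ** (s / 2)
    base = d * (n ** (1 / d) - 2) / (s * gamma ** (2 / s))
    if base <= 0:  # second expression is undefined when n^{1/d} <= 2
        return first
    return min(first, base ** (-s / 2))

for d in (2, 10, 100):
    print(d, [round(approx_term(n, d, s=2, gamma=1.0), 4) for n in (10**6, 10**12)])
```

With $s=2$ and $\gamma=1$, raising $n$ from $10^6$ to $10^{12}$ in dimension $d=100$ only halves the term (from about $0.145$ to about $0.072$), whereas for $d=2$ already $n=10^6$ gives about $0.001$.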

Appendix B Proofs of results in Section 3.1

Before the proofs, we establish the strict convexity of the KL barycenter problem in a simple lemma.

Lemma 1.

For any Polish space $S$ and weights $\lambda_1,\ldots,\lambda_k\ge 0$ with $\sum_{i=1}^k\lambda_i=1$, the KL barycenter problem $\min_{\mu\in\mathcal{P}(S)}\sum_{i=1}^k\lambda_i D_{\text{KL}}\left(\mu\parallel P_i\right)$ is strictly convex.

Proof.

Let $t\in(0,1)$ and let $\mu_1\neq\mu_2\in\mathcal{P}(S)$ satisfy $\mu_1\ll P_i$ and $\mu_2\ll P_i$ for each $i=1,2,\ldots,k$. Then

$$\begin{aligned}\sum_{i=1}^k\lambda_i D_{\text{KL}}\left(t\mu_1+(1-t)\mu_2\parallel P_i\right)&<\sum_{i=1}^k\lambda_i\left[tD_{\text{KL}}\left(\mu_1\parallel P_i\right)+(1-t)D_{\text{KL}}\left(\mu_2\parallel P_i\right)\right]\\&=t\sum_{i=1}^k\lambda_i D_{\text{KL}}\left(\mu_1\parallel P_i\right)+(1-t)\sum_{i=1}^k\lambda_i D_{\text{KL}}\left(\mu_2\parallel P_i\right),\end{aligned}$$

where the inequality follows from the strict convexity of the KL divergence in its first argument $\mu$ for fixed $P_i$. Therefore, the KL barycenter problem is strictly convex.
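A quick numerical sanity check of Lemma 1 on a finite space $S$ (an assumption made so that all integrals become sums): the gap between the right- and left-hand sides of the strict inequality equals $tD_{\text{KL}}(\mu_1\parallel\mu_t)+(1-t)D_{\text{KL}}(\mu_2\parallel\mu_t)$ with $\mu_t=t\mu_1+(1-t)\mu_2$, because the $\log P_i$ parts of the objective are linear in $\mu$ and $\sum_i\lambda_i=1$. The helper names are hypothetical.

```python
import math
import random

def kl(p, q):
    """KL divergence between distributions on a finite set (assumes p << q)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def barycenter_objective(mu, refs, lam):
    """sum_i lam_i * KL(mu || P_i)."""
    return sum(l * kl(mu, p) for l, p in zip(lam, refs))

def random_simplex(size, rng):
    """A random probability vector with strictly positive entries."""
    x = [rng.random() + 1e-3 for _ in range(size)]
    total = sum(x)
    return [xi / total for xi in x]

rng = random.Random(0)
refs = [random_simplex(6, rng) for _ in range(3)]
lam = [0.2, 0.5, 0.3]
mu1, mu2 = random_simplex(6, rng), random_simplex(6, rng)
for t in (0.25, 0.5, 0.75):
    mu_t = [t * a + (1 - t) * b for a, b in zip(mu1, mu2)]
    lhs = barycenter_objective(mu_t, refs, lam)
    rhs = t * barycenter_objective(mu1, refs, lam) + (1 - t) * barycenter_objective(mu2, refs, lam)
    gap = t * kl(mu1, mu_t) + (1 - t) * kl(mu2, mu_t)
    print(t, lhs < rhs, abs((rhs - lhs) - gap) < 1e-12)
```

The gap does not depend on the reference measures $P_i$, which is why strictness only requires $\mu_1\neq\mu_2$.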

B.1 Proof of Theorem 1

Proof.

Since uniqueness follows from Lemma 1, it suffices to show existence, and we may restrict attention to probability measures $\mu\in\mathcal{P}(\mathbb{R}^d)$ that are absolutely continuous with density $q(x)$ (otherwise the KL divergence is $\infty$). When there is no confusion, we use the density and the measure interchangeably. Let $\mathcal{P}_{\text{ac}}(\mathbb{R}^d)$ denote the space of all absolutely continuous distributions, and for $q\in\mathcal{P}_{\text{ac}}(\mathbb{R}^d)$ and $x\in\mathbb{R}^d$ define

$$F(q,x)=\sum_{i=1}^k\lambda_i q(x)\log\left(\frac{q(x)}{p_i(x)}\right).$$

Therefore, the barycenter problem becomes

$$\min_{\mu\in\mathcal{P}_{\text{ac}}(\mathbb{R}^d)}\int_{\mathbb{R}^d}F(q,x)\,dx\quad\text{s.t.}\quad\sum_{i=1}^k\lambda_i=1\ \text{ and }\ \int_{\mathbb{R}^d}q(x)\,dx=1,$$

which is a variational problem with a subsidiary condition [21]. Hence, by the calculus of variations, a necessary condition for $q$ to be an extremal of the variational problem is that, for some constant (Lagrange multiplier) $m$,

$$\frac{\partial}{\partial q}F(q,x)+m=0.$$

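Since $\sum_{i=1}^k\lambda_i=1$, we have $\frac{\partial}{\partial q}F(q,x)=\log q(x)+1-\sum_{i=1}^k\lambda_i\log p_i(x)$, so the condition above gives $q(x)=e^{-1-m}\prod_{i=1}^k p_i(x)^{\lambda_i}$. Normalizing so that $q$ integrates to one, the optimal solution is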

$$q^*(x)=\frac{\prod_{i=1}^k p_i(x)^{\lambda_i}}{\int\prod_{i=1}^k p_i(y)^{\lambda_i}\,dy}.$$
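On a finite state space (an assumption that replaces the integral by a sum), the closed form can be checked directly. The sketch also uses the identity $\sum_i\lambda_iD_{\text{KL}}(q\parallel p_i)=D_{\text{KL}}(q\parallel q^*)-\log Z$ with $Z=\sum_x\prod_ip_i(x)^{\lambda_i}$, which follows from expanding the logarithms and shows that $q^*$ is the unique minimizer. Function names are hypothetical.

```python
import math
import random

def kl(p, q):
    """KL divergence between distributions on a finite set (assumes p << q)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def objective(q, refs, lam):
    """Barycenter objective sum_i lam_i * KL(q || p_i)."""
    return sum(l * kl(q, p) for l, p in zip(lam, refs))

def kl_barycenter(refs, lam):
    """Discrete analogue of q*: normalized weighted geometric mean prod_i p_i^{lam_i}."""
    unnorm = [math.exp(sum(l * math.log(p[j]) for l, p in zip(lam, refs)))
              for j in range(len(refs[0]))]
    z = sum(unnorm)
    return [v / z for v in unnorm], z

def random_simplex(size, rng):
    """A random probability vector with strictly positive entries."""
    x = [rng.random() + 1e-3 for _ in range(size)]
    total = sum(x)
    return [xi / total for xi in x]

rng = random.Random(1)
refs = [random_simplex(5, rng) for _ in range(3)]
lam = [0.5, 0.3, 0.2]
q_star, z = kl_barycenter(refs, lam)
q = random_simplex(5, rng)
print(objective(q_star, refs, lam), -math.log(z))  # equal up to rounding
print(objective(q, refs, lam) - objective(q_star, refs, lam), kl(q, q_star))  # equal up to rounding
```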

   

B.2 Proof of Theorem 2

Before the proof of Theorem 2, we review a consequence of Girsanov's theorem (Theorem 8 in Chen et al. [12]). We will use a technique similar to that of Chen et al. [12] to prove Theorem 2.

Theorem 6.

Suppose $Q\in\mathcal{P}(C([0,T]:\mathbb{R}^d))$. For $t\in[0,T]$, let $\mathcal{L}(t)=\int_0^t b(s)\,dB(s)$ and define the stochastic exponential $\mathcal{E}(\mathcal{L})(t)=\exp\left(\int_0^t b(s)\,dB(s)-\frac{1}{2}\int_0^t\lVert b(s)\rVert_2^2\,ds\right)$, where $B$ is a $Q$-Brownian motion. Assume $\mathbb{E}_Q\left[\int_0^T\lVert b(s)\rVert_2^2\,ds\right]<\infty$. Then $\mathcal{L}$ is a square-integrable $Q$-martingale.
Moreover, if $\mathbb{E}_Q\left[\mathcal{E}(\mathcal{L})(T)\right]=1$, then $\mathcal{E}(\mathcal{L})$ is a true $Q$-martingale and the process $B(t)-\int_0^t b(s)\,ds$ is a $P$-Brownian motion, where $P$ is the probability measure defined by $dP=\mathcal{E}(\mathcal{L})(T)\,dQ$.

In most applications of Girsanov's theorem, one verifies a sufficient condition for $\mathbb{E}_Q\left[\mathcal{E}(\mathcal{L})(T)\right]=1$, known as Novikov's condition. In the context of Theorem 6, Novikov's condition is

$$\mathbb{E}_Q\left[\exp\left(\frac{1}{2}\int_0^T\lVert b(s)\rVert_2^2\,ds\right)\right]<\infty.\qquad(13)$$
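As an illustration of Theorem 6 (not part of the proof), the sketch below takes a hypothetical deterministic, bounded drift $b(t)=0.5\cos t$, for which Novikov's condition (13) clearly holds, simulates $Q$-Brownian increments on a grid, and checks by Monte Carlo that $\mathbb{E}_Q[\mathcal{E}(\mathcal{L})(T)]\approx 1$ and that, after reweighting by $\mathcal{E}(\mathcal{L})(T)$, $B(T)$ has mean approximately $\int_0^Tb(s)\,ds$ (computed here as a left Riemann sum on the grid), consistent with $B(t)-\int_0^tb(s)\,ds$ being a $P$-Brownian motion. The grid size, number of paths, seed, and function name are arbitrary choices.

```python
import math
import random

def girsanov_check(b, T=1.0, n_steps=20, n_paths=20000, seed=0):
    """Monte Carlo illustration of Theorem 6 for a deterministic drift b(t).

    Returns (E_Q[E(L)(T)], E_Q[E(L)(T) * B(T)], left Riemann sum of int_0^T b(s) ds)."""
    rng = random.Random(seed)
    dt = T / n_steps
    b_vals = [b(k * dt) for k in range(n_steps)]
    drift_integral = sum(bv * dt for bv in b_vals)
    half_qv = 0.5 * sum(bv * bv * dt for bv in b_vals)  # (1/2) int_0^T ||b||^2 ds
    mean_w = mean_wb = 0.0
    for _ in range(n_paths):
        B_T = L_T = 0.0
        for bv in b_vals:
            dB = rng.gauss(0.0, math.sqrt(dt))  # Q-Brownian increment
            L_T += bv * dB                      # L(T) = int_0^T b dB
            B_T += dB
        w = math.exp(L_T - half_qv)             # stochastic exponential E(L)(T)
        mean_w += w / n_paths
        mean_wb += w * B_T / n_paths
    return mean_w, mean_wb, drift_integral

mean_w, mean_wb, drift = girsanov_check(lambda t: 0.5 * math.cos(t))
print(mean_w, mean_wb, drift)
```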

Now we begin the proof of Theorem 2.

Proof.

By Lemma 1, it suffices to show existence. Let $\alpha\in\mathcal{P}(C([0,T]:\mathbb{R}^d))$ with initial distribution $\alpha_0$, and write $\alpha(0)$ for the initial state of the process whose law is $\alpha$. By the chain rule for the KL divergence, we have

$$\begin{aligned}\sum_{i=1}^k\lambda_i D_{\text{KL}}\left(\alpha\parallel P_i\right)&=\sum_{i=1}^k\lambda_i D_{\text{KL}}\left(\alpha_0\parallel\mu_i\right)\\&\quad+\mathbb{E}_{z\sim\alpha_0}\left[\sum_{i=1}^k\lambda_i D_{\text{KL}}\left(\alpha\left(\cdot\mid\alpha(0)=z\right)\parallel P_i\left(\cdot\mid P_i(0)=z\right)\right)\right],\end{aligned}$$

where the first term is a KL barycenter problem with respect to the initial distributions (handled by Theorem 1), and the second term is a KL barycenter problem in which all reference processes have the same initial distribution. Therefore, to finish the proof, we may assume that $\mu_i=\mu$ for each $i=1,\ldots,k$, i.e., all reference processes share the same initial distribution.
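The chain rule used above can be checked on a finite stand-in for path space (an assumption): take joint distributions over pairs $(z,y)$, where $z$ plays the role of the initial value $Z(0)$ and $y$ the rest of the path. The sketch verifies $D_{\text{KL}}(\alpha\parallel P)=D_{\text{KL}}(\alpha_0\parallel P_0)+\mathbb{E}_{z\sim\alpha_0}[D_{\text{KL}}(\alpha(\cdot\mid z)\parallel P(\cdot\mid z))]$ for a single reference measure; the weighted version follows by multiplying by $\lambda_i$ and summing over $i$. Function names are hypothetical.

```python
import math
import random

def kl(p, q):
    """KL divergence between distributions on a finite set (assumes p << q)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def random_joint(n0, n1, rng):
    """Random joint distribution over pairs (z, y) with positive entries."""
    w = [[rng.random() + 1e-3 for _ in range(n1)] for _ in range(n0)]
    total = sum(sum(row) for row in w)
    return [[v / total for v in row] for row in w]

def chain_rule_terms(alpha, P):
    """Return KL(alpha || P) and KL(alpha_0 || P_0) + E_{z ~ alpha_0} KL(alpha(.|z) || P(.|z))."""
    total = kl([v for row in alpha for v in row], [v for row in P for v in row])
    a0 = [sum(row) for row in alpha]  # marginal of the "initial value" z under alpha
    p0 = [sum(row) for row in P]      # marginal of z under P
    cond = sum(a0[z] * kl([v / a0[z] for v in alpha[z]], [v / p0[z] for v in P[z]])
               for z in range(len(alpha)))
    return total, kl(a0, p0) + cond

rng = random.Random(2)
alpha, P = random_joint(4, 5, rng), random_joint(4, 5, rng)
print(chain_rule_terms(alpha, P))  # the two numbers agree up to rounding
```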

Since we seek the minimizer of a weighted sum of KL divergences, it suffices to assume that $\alpha$ is the law of a diffusion process that is a strong solution of an SDE with the same diffusion (volatility) coefficient as all reference processes:

$$dX(t)=\left[c\left(t,X(t)\right)+\sigma(t)^2a\left(t,X(t)\right)\right]dt+\sigma(t)\,dB(t),\quad X(0)\sim\mu,$$

where $B$ is a standard Brownian motion; otherwise, the KL divergence would be $\infty$. For now, we assume that $a(t,x)$ is uniformly bounded.

When applying Girsanov's theorem, it is more convenient to view different path measures in $\mathcal{P}(C([0,T]:\mathbb{R}^d))$ as different laws of a single stochastic process, which we denote by $\{Z(t)\}_{t\in[0,T]}$.

For each $i=1,\ldots,k$, we apply Girsanov's theorem with $Q=\alpha$ and

$$b(t)=\sigma(t)\left(a_i(t,Z(t))-a(t,Z(t))\right)$$

in the setting of Theorem 6. Therefore, under the measure $P$ defined by $dP=\mathcal{E}(\mathcal{L})(T)\,d\alpha$, there exists a Brownian motion $\{\beta(t)\}_{t\in[0,T]}$ such that

$$dB(t)=\sigma(t)\left(a_i(t,Z(t))-a(t,Z(t))\right)dt+d\beta(t).$$

Since under the measure α𝛼\alphaitalic_α, with probability 1,

$$dZ(t)=\left[c\left(t,Z(t)\right)+\sigma(t)^2a\left(t,Z(t)\right)\right]dt+\sigma(t)\,dB(t),\quad Z(0)\sim\mu,$$

this also holds $P$-almost surely because $P\ll\alpha$. Hence, under $P$, we have $Z(0)\sim\mu$ and

$$\begin{aligned}dZ(t)&=\left[c\left(t,Z(t)\right)+\sigma(t)^2a\left(t,Z(t)\right)\right]dt+\sigma(t)^2\left[a_i(t,Z(t))-a(t,Z(t))\right]dt+\sigma(t)\,d\beta(t)\\&=\left[c\left(t,Z(t)\right)+\sigma(t)^2a_i\left(t,Z(t)\right)\right]dt+\sigma(t)\,d\beta(t).\end{aligned}$$

In other words, $P\sim P_i$ in law.

Therefore,

$$
\begin{aligned}
D_{\mathrm{KL}}\left(\alpha\parallel P_i\right)&=\mathbb{E}_\alpha\left[\log\left(\frac{d\alpha}{dP_i}\right)\right]\\
&=\mathbb{E}_\alpha\left[\log\left(\frac{1}{\mathcal{E}(\mathcal{L})(T)}\right)\right]\\
&=\frac{1}{2}\mathbb{E}_\alpha\left[\int_0^T\sigma(t)^2\left\|a_i(t,Z(t))-a(t,Z(t))\right\|_2^2\,dt\right]\\
&\quad+\mathbb{E}_\alpha\left[\int_0^T\sigma(t)\left(a(t,Z(t))-a_i(t,Z(t))\right)^{\top}dB(t)\right]\\
&=\frac{1}{2}\mathbb{E}_\alpha\left[\int_0^T\sigma(t)^2\left\|a_i(t,Z(t))-a(t,Z(t))\right\|_2^2\,dt\right],
\end{aligned}
$$

since the Itô integral with a regular integrand is a true martingale under $\alpha$ and therefore has zero expectation.

Therefore, the objective function of the process-level KL barycenter problem becomes

$$\frac{1}{2}\sum_{i=1}^k\lambda_i\,\mathbb{E}_\alpha\left[\int_0^T\sigma(t)^2\left\|a(t,Z(t))-a_i(t,Z(t))\right\|_2^2\,dt\right],$$

given that we assume all reference laws have the same initial distribution. Therefore, viewed as a functional optimization problem, the objective can be minimized pointwise in $(t,x)$; since $\sum_{i=1}^k\lambda_i=1$, the weighted sum of squares $\sum_{i=1}^k\lambda_i\|a-a_i(t,x)\|_2^2$ is minimized by the weighted average, so the minimizer is $a^*(t,x)=\sum_{i=1}^k\lambda_ia_i(t,x)$, which finishes the proof.
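The last step uses the fact that, for weights $\boldsymbol{\lambda}$ in the simplex, the weighted average minimizes a weighted sum of squared distances. A minimal numerical sketch of this pointwise fact, with randomly drawn vectors standing in (hypothetically) for the values $a_i(t,x)$ at one fixed $(t,x)$:

```python
import numpy as np

def weighted_sq_loss(a, drifts, lam):
    """Return sum_i lam_i * ||a - a_i||_2^2, where row i of `drifts` is a_i(t, x)."""
    return float(np.sum(lam * np.sum((a - drifts) ** 2, axis=1)))

rng = np.random.default_rng(0)
k, d = 3, 4
drifts = rng.normal(size=(k, d))   # hypothetical values a_1(t,x), ..., a_k(t,x)
lam = np.array([0.2, 0.5, 0.3])    # weights in the simplex Delta_k

a_star = lam @ drifts              # weighted average sum_i lam_i a_i(t, x)
best = weighted_sq_loss(a_star, drifts, lam)

# Since sum_i lam_i = 1, loss(a_star + delta) = loss(a_star) + ||delta||^2 exactly.
deltas = 0.1 * rng.normal(size=(1000, d))
excess = [weighted_sq_loss(a_star + dlt, drifts, lam) - best for dlt in deltas]
print(min(excess) > 0)  # True: no perturbation improves on the weighted average
```

The excess loss equals $\|\delta\|_2^2$ because the cross term $2\delta^\top\sum_i\lambda_i(a^*-a_i)$ vanishes at the weighted average.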

Appendix C Proof of results in Section 4

C.1 Preliminaries and basic tools

C.1.1 Preliminaries

This subsection collects the basic definitions and notation used in our proofs.

Definition 2.

$S$ is a Polish space equipped with its Borel $\sigma$-algebra $\mathcal{B}(S)$, and $\{P_n\}_{n\in\mathbb{N}}\subset\mathcal{P}(S)$ is a sequence of probability measures. We say that $P_n$ converges weakly to $P\in\mathcal{P}(S)$ if and only if, for each bounded continuous function $f:S\to\mathbb{R}$, as $n\to\infty$,

$$\int_Sf(x)\,dP_n(x)\to\int_Sf(x)\,dP(x).$$
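As a concrete instance of the definition, take the Gaussian laws $P_n=\mathcal{N}(1/n,(1+1/n)^2)$ and $P=\mathcal{N}(0,1)$ on $S=\mathbb{R}$ (an illustrative choice, not one used in the paper). The sketch below evaluates $\int f\,dP_n$ for the single bounded continuous test function $f=\cos$, using the closed form $\mathbb{E}[\cos X]=\cos(m)e^{-s^2/2}$ for $X\sim\mathcal{N}(m,s^2)$. One test function only illustrates the definition; weak convergence in this example holds because the means and variances converge.

```python
import math

def expect_cos_gaussian(m, s):
    """E[cos X] for X ~ N(m, s^2), via the closed form cos(m) * exp(-s^2 / 2)."""
    return math.cos(m) * math.exp(-0.5 * s**2)

target = expect_cos_gaussian(0.0, 1.0)  # integral of cos against P = N(0, 1)
gaps = [abs(expect_cos_gaussian(1.0 / n, 1.0 + 1.0 / n) - target)
        for n in (1, 10, 100, 1000)]
print(gaps)  # shrinks toward 0 as n grows
```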
Definition 3.

Let $(X,\mathcal{F})$ and $(Y,\mathcal{G})$ be measurable spaces, let $f:X\to Y$ be a measurable function, and let $(X,\mathcal{F},\mu)$ be a (positive) measure space. The pushforward of $\mu$ under $f$ is the measure $f_\#\mu$ on $(Y,\mathcal{G})$ such that for any $B\in\mathcal{G}$,

$$f_\#\mu(B)=\mu\left(f^{-1}(B)\right).$$
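For a finitely supported measure, the pushforward moves the mass at each point $x$ to $f(x)$ and adds up mass that lands on the same point, so that $f_\#\mu(\{y\})=\mu(f^{-1}(\{y\}))$. A minimal sketch, with a dictionary representation of the measure chosen purely for illustration:

```python
from collections import defaultdict

def pushforward(mu, f):
    """Pushforward f_# mu of a finitely supported measure mu given as {point: mass}."""
    out = defaultdict(float)
    for x, mass in mu.items():
        out[f(x)] += mass  # mass at x is transported to f(x)
    return dict(out)

mu = {-2: 0.1, -1: 0.2, 0: 0.3, 1: 0.25, 2: 0.15}
nu = pushforward(mu, lambda x: x * x)
print(nu)  # mass at 0, 1 and 4; e.g. nu[1] = mu({-1, 1})
```

Total mass is preserved, as it must be, since $f_\#\mu(Y)=\mu(f^{-1}(Y))=\mu(X)$.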
Definition 4.

A differentiable function $F:\mathbb{R}^d\to\mathbb{R}$ is called $L$-smooth if for any $x,y\in\mathbb{R}^d$,

$$\left|F(x)-F(y)-F'(y)(x-y)\right|\le\frac{L}{2}\|y-x\|_2^2.$$
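A standard example: for a symmetric positive semidefinite matrix $A$, the quadratic $F(x)=\tfrac12x^\top Ax$ satisfies $F(x)-F(y)-F'(y)(x-y)=\tfrac12(x-y)^\top A(x-y)$, so it is $L$-smooth with $L$ equal to the largest eigenvalue of $A$. The sketch below checks the defining inequality on random pairs; the matrix is randomly generated for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 5
B = rng.normal(size=(d, d))
A = B @ B.T                        # symmetric positive semidefinite
L = np.linalg.eigvalsh(A).max()    # smoothness constant: largest eigenvalue of A

def F(x):
    return 0.5 * x @ A @ x

def grad_F(x):
    return A @ x

def smoothness_ratio(x, y):
    """|F(x) - F(y) - F'(y)(x - y)| divided by (L/2)||y - x||_2^2; at most 1 if F is L-smooth."""
    lhs = abs(F(x) - F(y) - grad_F(y) @ (x - y))
    return lhs / (0.5 * L * np.sum((y - x) ** 2))

ratios = [smoothness_ratio(rng.normal(size=d), rng.normal(size=d)) for _ in range(1000)]
print(max(ratios))  # never exceeds 1 (up to rounding)
```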
Definition 5.

A stochastic process $\{X_t\}_{t\in[0,T]}$ is called a local martingale if there exists a nondecreasing sequence of stopping times $\{T_n\}_{n\in\mathbb{N}}$ such that $T_n\to T$ and, for each $n$, the stopped process $\{X_{t\wedge T_n}\}_{t\in[0,T]}$ is a true martingale.

Next, we define some notation and stochastic processes used in the proofs that follow.

Recall that process (6) is a backward SDE in which the score terms are replaced by their estimates. For each $i=1,2,\ldots,k$, let $\bar{X}_i$ denote the theoretical backward process with exact score terms:

$$d\bar{X}_i(t)=\left(a\bar{X}_i(t)+\sigma^2\nabla\log p^i_{T-t}\left(\bar{X}_i(t)\right)\right)dt+\sigma\,dW_i(t),\quad\bar{X}_i(0)\sim p^i_T.\tag{14}$$

The corresponding forward process is denoted by $X_i$:

$$dX_i(t)=-aX_i(t)\,dt+\sigma\,dW(t),\quad X_i(0)\sim p_i\sim\mu_i.\tag{15}$$

We denote the marginal density of $X_i(t)$ by $p^i_t$; when $t=0$, we write $p_i\sim\mu_i$. Process (8) is the time-discretized SDE implemented in practice. It can be viewed as an approximation of the theoretical barycenter process (denoted by $\tilde{Y}$) of the backward SDEs of the form (14):

$$d\tilde{Y}(t)=\left(a\tilde{Y}(t)+\sigma^2\sum_{i=1}^k\lambda_i\nabla\log p^i_{T-t}\left(\tilde{Y}(t)\right)\right)dt+\sigma\,dW(t),\quad\tilde{Y}(0)\sim\gamma^d_T,\tag{16}$$

where $\gamma^d_T$ is the distribution-level KL barycenter at time $T$ with respect to the reference measures $\{p^1_T,\ldots,p^k_T\}$. When $T$ is large, $\gamma^d_T$ is approximated by $\pi$ in Equation (8). In theory, process (16) has a corresponding forward process:

$$dY(t)=-aY(t)\,dt+\sigma\,dW(t),\quad Y(0)\sim\tilde{Y}(T).\tag{17}$$

For a fixed $\boldsymbol{\lambda}$, we denote by $p_{\boldsymbol{\lambda},t}$ the marginal distribution of process (17) at time $t$; when $t=0$, we drop the time subscript.

C.1.2 Basic algorithms

In this subsection, we recall the Frank-Wolfe method [9], which solves optimization problems with an $L$-smooth convex objective $f:\mathcal{X}\to\mathbb{R}$ over a compact convex domain $\mathcal{X}$:

$$\min_{x\in\mathcal{X}}f(x)\tag{18}$$
Algorithm 2 (Vanilla) Frank-Wolfe with function-agnostic step-size rule [9]
1: Input: start atom $x_0\in\mathcal{X}$, objective function $f$, smoothness $L$
2: Output: iterates $x_1,\ldots,x_\tau\in\mathcal{X}$
3: for $\tau=0$ to $\ldots$ do
4:     $v_\tau\leftarrow\arg\min_{v\in\mathcal{X}}\langle\nabla f(x_\tau),v\rangle$
5:     $\gamma_\tau\leftarrow1$ if $\tau=1$, and $\gamma_\tau\leftarrow\frac{2}{\tau+3}$ if $\tau>1$
6:     $x_{\tau+1}\leftarrow x_\tau+\gamma_\tau(v_\tau-x_\tau)$
7: end for

To measure the error of the algorithm, we define, for each $\tau\ge1$, the primal gap

$$h_\tau=h(x_\tau)=f(x_\tau)-f(x^*),$$

where $x^*$ is a minimizer of problem (18).
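The following sketch runs Algorithm 2 over the probability simplex, where the linear minimization in step 4 is attained at the vertex $e_j$ with $j=\arg\min_j[\nabla f(x_\tau)]_j$, and records the primal gap $h_\tau$. The quadratic objective $f(x)=\|x-c\|_2^2$ with $c$ in the simplex (so $x^*=c$) is a hypothetical test problem. Because Algorithm 2 starts at $\tau=0$ but only specifies $\gamma_\tau$ for $\tau\ge1$, the sketch assumes $\gamma_0=1$.

```python
import numpy as np

def step_size(tau):
    # Step rule of Algorithm 2; gamma_0 is not specified there, so we assume gamma_0 = 1.
    return 1.0 if tau <= 1 else 2.0 / (tau + 3)

def frank_wolfe_simplex(grad, x0, num_iters):
    """Vanilla Frank-Wolfe over the probability simplex; returns the iterates x_1, ..., x_n."""
    x = np.asarray(x0, dtype=float)
    iterates = []
    for tau in range(num_iters):
        g = grad(x)
        v = np.zeros_like(x)
        v[np.argmin(g)] = 1.0          # linear minimizer over the simplex is a vertex
        x = x + step_size(tau) * (v - x)
        iterates.append(x.copy())
    return iterates

c = np.array([0.1, 0.6, 0.3])          # hypothetical minimizer x* inside the simplex
f = lambda x: float(np.sum((x - c) ** 2))
grad = lambda x: 2.0 * (x - c)

xs = frank_wolfe_simplex(grad, x0=np.array([1.0, 0.0, 0.0]), num_iters=2000)
gaps = [f(x) - f(c) for x in xs]       # primal gaps h_1, ..., h_2000
print(gaps[0], gaps[-1])
```

For this problem $L=2$ and the squared diameter of the simplex is 2, so the usual descent-lemma argument with this step rule gives $h_\tau\le 8/(\tau+2)$; every iterate stays in the simplex because each update is a convex combination.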

C.1.3 Basic lemmas

In this subsection, we first list some basic lemmas (Lemmas 2 to 5) that serve as essential tools in our proofs. All proofs can be found in [12].

Lemma 2.

Suppose that Assumptions 1 and 2 hold. For each $i=1,2,\ldots,k$, let $Z_i(t)$ denote the forward auxiliary process (15). Then for all $t\ge0$,

$$\mathbb{E}\left[\|Z_i(t)\|_2^2\right]\le d\vee M\quad\text{and}\quad\mathbb{E}\left[\left\|\nabla\log p^i_t\left(Z_i(t)\right)\right\|_2^2\right]\le Ld.$$
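For the OU forward process (15), the second moment in Lemma 2 has a closed form: since $X_i(t)=e^{-at}X_i(0)+\xi_t$ with $\xi_t\sim\mathcal{N}\big(0,\tfrac{\sigma^2(1-e^{-2at})}{2a}I_d\big)$ independent of $X_i(0)$, we get $\mathbb{E}\|X_i(t)\|_2^2=e^{-2at}\,\mathbb{E}\|X_i(0)\|_2^2+(1-e^{-2at})\,\tfrac{d\sigma^2}{2a}$, a convex combination of the initial second moment and $d\sigma^2/(2a)$. With $\sigma^2=2a$ it never exceeds $d\vee\mathbb{E}\|X_i(0)\|_2^2$, in line with the first bound of Lemma 2 if $M$ bounds the initial second moment (an assumption here, since Assumptions 1 and 2 are not restated in this section). The Monte Carlo sketch below, with the illustrative choices $a=1$, $\sigma=\sqrt2$ and $X_i(0)\sim\mathcal{N}(2\cdot\mathbf{1},I_d)$, samples the exact Gaussian transition and compares it with the closed form.

```python
import numpy as np

def ou_sample(x0, t, a, sigma, rng):
    """Exact draw of X(t) for dX = -a X dt + sigma dW, started from the rows of x0."""
    mean = np.exp(-a * t) * x0
    var = sigma**2 * (1.0 - np.exp(-2.0 * a * t)) / (2.0 * a)
    return mean + np.sqrt(var) * rng.standard_normal(x0.shape)

def second_moment(m0, t, a, sigma, d):
    """Closed form of E||X(t)||^2 when E||X(0)||^2 = m0."""
    w = np.exp(-2.0 * a * t)
    return w * m0 + (1.0 - w) * d * sigma**2 / (2.0 * a)

rng = np.random.default_rng(2)
n, d, a, sigma, t = 200_000, 3, 1.0, np.sqrt(2.0), 0.5
x0 = 2.0 + rng.standard_normal((n, d))  # illustrative initial law N(2 * 1, I_d)
m0 = d * (2.0**2 + 1.0)                 # E||X(0)||^2 = d * (mean^2 + variance) = 15
xt = ou_sample(x0, t, a, sigma, rng)

mc = np.mean(np.sum(xt**2, axis=1))
exact = second_moment(m0, t, a, sigma, d)
print(mc, exact)
```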
Lemma 3.

Suppose that Assumption 1 holds. For each $i=1,2,\ldots,k$, let $Z_i(t)$ denote the forward auxiliary process (15). For $0\le s<t$, let $\delta=t-s$. If $\delta\le1$, then

$$\mathbb{E}\left[\|Z_i(t)-Z_i(s)\|_2^2\right]\lesssim\delta^2M+\delta d.$$
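For intuition, here is a short derivation of a bound of this form in the OU case (15), assuming constant $a>0$ and $\sigma$ and additionally invoking the first bound of Lemma 2; it is a sketch and not necessarily the argument of [12]. Writing $Z_i(t)=e^{-a\delta}Z_i(s)+\xi$ with mean-zero Gaussian $\xi$ independent of $Z_i(s)$ and using $1-e^{-u}\le u$ for $u\ge0$:

$$
\begin{aligned}
\mathbb{E}\|Z_i(t)-Z_i(s)\|_2^2
&=\left(1-e^{-a\delta}\right)^2\mathbb{E}\|Z_i(s)\|_2^2+\frac{d\,\sigma^2\left(1-e^{-2a\delta}\right)}{2a}\\
&\le a^2\delta^2\,\mathbb{E}\|Z_i(s)\|_2^2+d\,\sigma^2\delta\\
&\le a^2\delta^2(d+M)+\sigma^2\delta d\;\le\;a^2\delta^2M+\left(a^2+\sigma^2\right)\delta d,
\end{aligned}
$$

where the last line uses $d\vee M\le d+M$ and $\delta^2\le\delta$ for $\delta\le1$.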
Lemma 4.

Consider a sequence of functions $f_n:[0,T]\to\mathbb{R}^d$ and a function $f:[0,T]\to\mathbb{R}^d$, and suppose there exists a nondecreasing sequence $\{T_n\}_{n\in\mathbb{N}}\subset[0,T]$ with $T_n\to T$ as $n\to\infty$ such that $f_n(t)=f(t)$ for each $t\le T_n$. Then for each $\epsilon>0$, $f_n\to f$ uniformly on $[0,T-\epsilon]$.

Lemma 5.

Let $f:[0,T]\to\mathbb{R}^d$ be a continuous function and, for each $\epsilon>0$, define $f_\epsilon:[0,T]\to\mathbb{R}^d$ by $f_\epsilon(t)=f\left(t\wedge(T-\epsilon)\right)$. Then $f_\epsilon\to f$ uniformly on $[0,T]$ as $\epsilon\to0$.
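Lemma 5 is a consequence of uniform continuity: $\sup_{t\in[0,T]}|f_\epsilon(t)-f(t)|=\sup_{t\in[T-\epsilon,T]}|f(T-\epsilon)-f(t)|$, which is at most the modulus of continuity of $f$ at scale $\epsilon$. The sketch below estimates this supremum on a grid for an illustrative function with $|f'|\le7$ on $[0,1]$, so the error should stay below $7\epsilon$.

```python
import numpy as np

def truncation_error(f, T, eps, n=10_001):
    """Grid estimate of sup_{t in [0, T]} |f(min(t, T - eps)) - f(t)|."""
    t = np.linspace(0.0, T, n)
    return float(np.max(np.abs(f(np.minimum(t, T - eps)) - f(t))))

f = lambda t: np.sin(5.0 * t) + t**2   # illustrative continuous function, |f'| <= 7 on [0, 1]
epsilons = (0.1, 0.01, 0.001)
errors = [truncation_error(f, 1.0, eps) for eps in epsilons]
print(errors)  # shrinks as eps decreases
```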

Next, we present two results related to the fusion algorithm.

Lemma 6.

For any fixed $\boldsymbol{\lambda}\in\Delta_k$, $\tilde{Y}(T)\sim\mu_{\boldsymbol{\lambda}}$, the KL barycenter of $\{\mu_1,\ldots,\mu_k\}$.

Proof.

In this proof, we use the following notation: for $x,y\in\mathbb{R}^d$ and $0\le s\le t\le T$, we denote by $p^i(x,t|y,s)$ the transition density of the $i$th auxiliary process from time $s$ to $t$. Similarly, we denote by $p^{\boldsymbol{\lambda}}(x,t|y,s)$ the transition density of the barycenter process from time $s$ to $t$.

Let $\boldsymbol{\lambda}$ be fixed. Then at each time $t\in[0,T]$,

$$\nabla\log\left(p_{\boldsymbol{\lambda},t}(x)\right)=\nabla\sum_{i=1}^k\lambda_i\log\left(p^i_t(x)\right).$$

Expanding both sides in terms of the transition densities, we get

$$\nabla\log\left(\int p^{\boldsymbol{\lambda}}(x,t|y,0)\,p_{\boldsymbol{\lambda}}(y)\,dy\right)=\nabla\sum_{i=1}^k\lambda_i\log\left(\int p^i(x,t|y,0)\,p_i(y)\,dy\right).$$

Note that as $t\to0$, $p^i(x,t|y,0)\to\delta(x-y)$ and $p^{\boldsymbol{\lambda}}(x,t|y,0)\to\delta(x-y)$, where $\delta$ denotes the Dirac delta function. Therefore, by the compactness assumption and the dominated convergence theorem,

$$
\begin{aligned}
\nabla\log p_{\boldsymbol{\lambda}}(x)&=\lim_{t\to0}\nabla\log\left(\int p^{\boldsymbol{\lambda}}(x,t|y,0)\,p_{\boldsymbol{\lambda}}(y)\,dy\right)\\
&=\lim_{t\to0}\nabla\sum_{i=1}^k\lambda_i\log\left(\int p^i(x,t|y,0)\,p_i(y)\,dy\right)\\
&=\nabla\sum_{i=1}^k\lambda_i\log p_i(x).
\end{aligned}
$$

Therefore,

$$\log p_{\boldsymbol{\lambda}}(x)=\sum_{i=1}^k\lambda_i\log p_i(x)+C=\log\left(\prod_{i=1}^kp_i(x)^{\lambda_i}\right)+C$$

for some constant $C$ that does not depend on $x$.

Since $p_{\boldsymbol{\lambda}}$ is a density function, normalizing gives

$$p_{\boldsymbol{\lambda}}(x)=\frac{\prod_{i=1}^kp_i(x)^{\lambda_i}}{\int\prod_{i=1}^kp_i(y)^{\lambda_i}\,dy},$$

which is the solution of the KL barycenter problem with reference measures $p_1,\ldots,p_k$.
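For Gaussian references the normalized product $\prod_ip_i^{\lambda_i}$ has a closed form: if $p_i=\mathcal{N}(m_i,s_i^2)$ in one dimension, then $p_{\boldsymbol{\lambda}}=\mathcal{N}(\bar m,\bar s^2)$ with precision $\bar s^{-2}=\sum_i\lambda_is_i^{-2}$ and mean $\bar m=\bar s^2\sum_i\lambda_im_is_i^{-2}$. The sketch below (with illustrative parameters) normalizes the product numerically on a grid, working in log space, and compares it with this closed form.

```python
import numpy as np

def gaussian_logpdf(x, m, s):
    return -0.5 * ((x - m) / s) ** 2 - np.log(s * np.sqrt(2.0 * np.pi))

x = np.linspace(-15.0, 15.0, 30_001)
dx = x[1] - x[0]
means = np.array([-2.0, 1.0, 3.0])   # illustrative reference Gaussians N(m_i, s_i^2)
stds = np.array([1.0, 0.5, 2.0])
lam = np.array([0.5, 0.3, 0.2])      # weights in the simplex

# log prod_i p_i(x)^{lam_i}, then exponentiate (shifted for stability) and normalize on the grid.
log_prod = sum(l * gaussian_logpdf(x, m, s) for l, m, s in zip(lam, means, stds))
p_lam = np.exp(log_prod - log_prod.max())
p_lam /= p_lam.sum() * dx

# Closed form: Gaussian with precision-weighted mean and summed weighted precision.
prec = np.sum(lam / stds**2)
mean = np.sum(lam * means / stds**2) / prec
closed = np.exp(gaussian_logpdf(x, mean, 1.0 / np.sqrt(prec)))
print(mean, prec, np.max(np.abs(p_lam - closed)))
```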

Next we give the proof of Proposition 1.

Proof.

Recall that the objective function for 𝝀𝝀\boldsymbol{\lambda}bold_italic_λ is

$$F(\boldsymbol{\lambda})=\mathbb{E}_\nu\left[\log\nu(X)-\sum_{i=1}^k\lambda_i\log p_i(X)\right]+\log\left(\int\prod_{i=1}^kp_i(y)^{\lambda_i}\,dy\right).\tag{19}$$

We note that the first term is affine in $\boldsymbol{\lambda}$, so to show convexity it suffices to show that the second term is convex in $\boldsymbol{\lambda}$. Denote $h_i(x)=\log\left(p_i(x)\right)$ for each $i=1,2,\ldots,k$, and let $X$ be uniformly distributed on $\mathbb{K}$. Then

$$
\begin{aligned}
\log\left(\int\prod_{i=1}^kp_i(y)^{\lambda_i}\,dy\right)&=\log\left(\int_{\mathbb{K}}\prod_{i=1}^kp_i(y)^{\lambda_i}\,dy\right)\\
&=\log\left(\frac{1}{|\mathbb{K}|}\int_{\mathbb{K}}\exp\left(\sum_{i=1}^kh_i(y)\lambda_i\right)dy\right)+\log\left(|\mathbb{K}|\right)\\
&=\log\left(\mathbb{E}\left[\exp\left(\boldsymbol{\lambda}^{\top}Z\right)\right]\right)+\log\left(|\mathbb{K}|\right),
\end{aligned}
$$

where $Z=\left(h_{1}(X),\ldots,h_{k}(X)\right)$ and $|\mathbb{K}|$ is the Lebesgue measure of $\mathbb{K}$. Since the logarithm of a moment generating function is convex, the second term in Equation (19) is convex in $\boldsymbol{\lambda}$.
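As a numerical sanity check on this argument (not part of the original proof), the sketch below estimates $\log\mathbb{E}[\exp(\boldsymbol{\lambda}^{T}Z)]$ from uniform samples on $\mathbb{K}$ and tests midpoint convexity at random pairs of weights. The choice $\mathbb{K}=[-3,3]$ and the unit-variance Gaussian log-densities used for $h_i$ are illustrative assumptions. The empirical estimate is itself a log-sum-exp of affine functions of $\boldsymbol{\lambda}$, so it is convex as well and no violations should appear beyond rounding.

```python
import numpy as np


def log_mgf(lam, Z):
    """Empirical log E[exp(lam^T Z)], computed with a stable log-sum-exp.

    Z has shape (n, k); row j holds (h_1(X_j), ..., h_k(X_j)) for a sample X_j
    drawn uniformly from K.
    """
    s = Z @ lam                       # lam^T Z for each sample, shape (n,)
    m = s.max()
    return m + np.log(np.mean(np.exp(s - m)))


rng = np.random.default_rng(0)
n, k = 5000, 3
X = rng.uniform(-3.0, 3.0, size=n)    # uniform samples on the assumed K = [-3, 3]
centers = np.array([-1.0, 0.5, 2.0])  # hypothetical unit-variance Gaussian auxiliaries
Z = -0.5 * (X[:, None] - centers[None, :]) ** 2 - 0.5 * np.log(2 * np.pi)

# Midpoint convexity: g((a + b) / 2) <= (g(a) + g(b)) / 2 for random weight pairs.
violations = 0
for _ in range(1000):
    a, b = rng.uniform(0.0, 1.0, size=(2, k))
    mid = log_mgf(0.5 * (a + b), Z)
    if mid > 0.5 * (log_mgf(a, Z) + log_mgf(b, Z)) + 1e-12:
        violations += 1
print("midpoint-convexity violations:", violations)
```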

Remark 2.

In theory, the first-order condition of the convex optimization problem (9) can be written in terms of the partial derivatives

$$\begin{aligned}
\frac{\partial F}{\partial\lambda_{i}}(\boldsymbol{\lambda})&=-\int\nu(x)h_{i}(x)\,dx+\frac{\partial}{\partial\lambda_{i}}\log\left(\int\prod_{l=1}^{k}p_{l}(y)^{\lambda_{l}}\,dy\right)\\
&=-\mathbb{E}_{\nu}\left[h_{i}(X)\right]+\frac{\int\prod_{l=1}^{k}p_{l}(y)^{\lambda_{l}}\log p_{i}(y)\,dy}{\int\prod_{l=1}^{k}p_{l}(y)^{\lambda_{l}}\,dy}\\
&=-\mathbb{E}_{\nu}\left[h_{i}(X)\right]+\frac{\int\exp\left(\sum_{l=1}^{k}\lambda_{l}h_{l}(y)\right)h_{i}(y)\,dy}{\int\exp\left(\sum_{l=1}^{k}\lambda_{l}h_{l}(y)\right)dy}.
\end{aligned}$$

In practice, each $h_{i}$ is replaced by the logarithm of the corresponding estimated auxiliary density, and the second term is computed independently of the target data $\nu$. However, this is very hard to implement, since the numerical integration in the second term may incur large errors that are difficult to control.
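To make the second term concrete, the sketch below evaluates this gradient in one dimension by grid quadrature and checks it against central finite differences of $F$ (dropping the $\boldsymbol{\lambda}$-independent term $\mathbb{E}_{\nu}[\log\nu(X)]$). It assumes two unit-variance Gaussian auxiliary densities and a Gaussian target $\nu$, so that $\mathbb{E}_{\nu}[h_i(X)]$ has a closed form; the grid range and resolution are arbitrary choices. In one dimension the quadrature is cheap and accurate, but a grid with $n$ points per axis needs $n^{d}$ evaluations in $d$ dimensions.

```python
import numpy as np

# Hypothetical 1-D setup: two unit-variance Gaussian auxiliaries, Gaussian target.
mus = np.array([-1.0, 2.0])   # means of p_1 and p_2
m_nu, s_nu = 0.5, 1.0         # target nu = N(m_nu, s_nu^2)

# h_i(y) = log p_i(y) on a quadrature grid; range and spacing are arbitrary choices.
grid = np.linspace(-15.0, 15.0, 20001)
dy = grid[1] - grid[0]
H = -0.5 * (grid[None, :] - mus[:, None]) ** 2 - 0.5 * np.log(2 * np.pi)


def e_nu_h(i):
    """Closed form of E_nu[h_i(X)] for this Gaussian setup."""
    return -0.5 * np.log(2 * np.pi) - 0.5 * ((m_nu - mus[i]) ** 2 + s_nu ** 2)


def f_lam_part(lam):
    """lambda-dependent part of F: -sum_i lam_i E_nu[h_i] + log of the integral."""
    s = lam @ H
    m = s.max()
    log_int = m + np.log(np.sum(np.exp(s - m)) * dy)   # Riemann sum in log space
    return -sum(lam[i] * e_nu_h(i) for i in range(len(lam))) + log_int


def grad_f(lam):
    """dF/dlam_i = -E_nu[h_i] + average of h_i under exp(sum_l lam_l h_l), normalized."""
    s = lam @ H
    w = np.exp(s - s.max())
    w /= w.sum()              # grid weights of the normalized geometric mixture
    return np.array([-e_nu_h(i) + w @ H[i] for i in range(len(lam))])


lam = np.array([0.3, 0.7])
eps = 1e-5
g = grad_f(lam)
g_fd = np.array([(f_lam_part(lam + eps * e) - f_lam_part(lam - eps * e)) / (2 * eps)
                 for e in np.eye(len(lam))])
print(g, g_fd)
```

For $\boldsymbol{\lambda}=(0.3,0.7)$ the normalized geometric mixture of these two Gaussians is $N(1.1,1)$, so the exact gradient is $(-1.08,\,0.72)$, which both estimates should reproduce closely.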

C.2 Proof of Theorem 3

Before proving the sample complexity of the whole algorithm, we first prove a lemma that controls the auxiliary score estimation errors. The proof is adapted from Chen et al. [12].

Lemma 7.

Suppose that Assumption 2 holds, $\boldsymbol{\lambda}$ is fixed, and the step size $h=T/N$ satisfies $h\lesssim 1/L$, where $L\geq 1$. Let $p_{\boldsymbol{\lambda}}$ and $\hat{p}_{\boldsymbol{\lambda}}$ denote the distributions of processes (16) and (8) at time $T$, respectively. Then

$$\text{TV}\left(p_{\boldsymbol{\lambda}},\hat{p}_{\boldsymbol{\lambda}}\right)\lesssim\exp(-T)\max_{i=1,2,\ldots,k}\sqrt{D_{\text{KL}}\left(p^{i}_{T}\parallel\pi\right)}+\sigma\sqrt{kT}\left(\epsilon_{\text{score}}+L\sqrt{dh}+Lh\sqrt{M}\right).$$
Remark 3.

To interpret the result, suppose that $\max_{i=1,2,\ldots,k}\sqrt{D_{\text{KL}}\left(p^{i}_{T}\parallel\pi\right)}\lesssim\text{poly}(d)$ and $M\leq d$. For a fixed $\epsilon$, choose $T\sim\log\left(\max_{i=1,2,\ldots,k}\sqrt{D_{\text{KL}}\left(p^{i}_{T}\parallel\pi\right)}/\epsilon\right)$ and $h\sim\frac{\epsilon^{2}}{L^{2}\sigma^{2}kd}$. Then, hiding logarithmic factors, with $N\sim\frac{L^{2}\sigma^{2}kd}{\epsilon^{2}}$ steps the sampling error satisfies $\text{SE}\lesssim\epsilon+\epsilon_{\text{score}}$. In particular, to guarantee $\text{SE}\lesssim\epsilon$, it suffices to have $\epsilon_{\text{score}}\lesssim\epsilon$.
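The schedule in this remark can be turned into concrete numbers. The sketch below does so under the loud assumption that every constant hidden by $\sim$ and $\lesssim$ equals 1; the argument `kl_sqrt_max` stands for $\max_i\sqrt{D_{\text{KL}}(p^{i}_{T}\parallel\pi)}$ and all function names are hypothetical. It also evaluates the right-hand side of Lemma 7 so the individual terms can be compared.

```python
import math


def remark3_schedule(eps, L, sigma, k, d, kl_sqrt_max):
    """Step-size schedule suggested by Remark 3, with every hidden constant set to 1.

    kl_sqrt_max plays the role of max_i sqrt(D_KL(p_T^i || pi)).
    """
    T = math.log(kl_sqrt_max / eps)
    h = eps ** 2 / (L ** 2 * sigma ** 2 * k * d)
    N = math.ceil(T / h)      # h = T / N, rounded up to an integer number of steps
    return T, h, N


def lemma7_rhs(T, h, eps_score, L, sigma, k, d, M, kl_sqrt_max):
    """Right-hand side of Lemma 7, again with the implicit constant set to 1."""
    return (math.exp(-T) * kl_sqrt_max
            + sigma * math.sqrt(k * T)
            * (eps_score + L * math.sqrt(d * h) + L * h * math.sqrt(M)))


T, h, N = remark3_schedule(eps=0.1, L=1.0, sigma=1.0, k=2, d=100, kl_sqrt_max=10.0)
print(T, h, N, lemma7_rhs(T, h, 0.0, 1.0, 1.0, 2, 100, 100, 10.0))
```

With these inputs the initialization term equals $\epsilon=0.1$, while the discretization term is about $0.22$, roughly $\epsilon\sqrt{T}$; this $\sqrt{T}$ factor is the logarithmic factor the remark hides.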

Proof.

We denote the laws of processes (16) and (8), as probability measures on the path space $C([0,T];\mathbb{R}^{d})$, by $\alpha$ and $\beta$, respectively. For simplicity of the proof, we define a fictitious diffusion with $\hat{Y}(0)\sim\gamma^{d}_{T}$ satisfying the SDE

$$d\hat{Y}(t)=\left(a\hat{Y}(t)+\sigma^{2}\sum_{i=1}^{k}\lambda_{i}s^{i}_{T-lh,\theta^{*}}\left(\hat{Y}(lh)\right)\right)dt+\sigma\,dW(t),\quad t\in[lh,(l+1)h]. \tag{20}$$

We initialize at a Gaussian since, in practice, it is always convenient to use a Gaussian prior $\pi$. We denote the law of process (20) by $\beta_{T}$, also a probability measure on $C([0,T];\mathbb{R}^{d})$.
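Because the fused score in (20) is frozen at $\hat{Y}(lh)$ on each subinterval, the SDE is linear with constant forcing there, and its one-step Gaussian transition can be sampled exactly. The sketch below uses this exponential-integrator update as an illustrative choice, not necessarily the paper's implementation; the score networks are replaced by user-supplied callables `s_i(t, y)`, a hypothetical interface, and the constant $a$ is treated as a given scalar.

```python
import numpy as np


def fused_sampler(y0, scores, lam, a, sigma, T, N, rng):
    """Simulate dY = (a Y + sigma^2 sum_i lam_i s^i_{T-lh}(Y(lh))) dt + sigma dW on [0, T].

    On each step [lh, (l+1)h] the fused score is frozen, so the update below
    samples the exact Gaussian transition of the resulting linear SDE.
    `scores` is a list of callables s_i(t, y) -> array shaped like y.
    """
    h = T / N
    y = np.array(y0, dtype=float)
    for l in range(N):
        # Fused drift term, evaluated once at the left endpoint of the step.
        b = sigma ** 2 * sum(w * s(T - l * h, y) for w, s in zip(lam, scores))
        if a == 0.0:
            gain, forcing, var = 1.0, h, sigma ** 2 * h
        else:
            gain = np.exp(a * h)                              # e^{a h}
            forcing = np.expm1(a * h) / a                     # (e^{a h} - 1) / a
            var = sigma ** 2 * np.expm1(2 * a * h) / (2 * a)  # exact step variance
        y = gain * y + forcing * b + np.sqrt(var) * rng.standard_normal(y.shape)
    return y


def zero_score(t, y):
    """Placeholder score used only for the check below."""
    return np.zeros_like(y)


# Check: with zero scores, a = -1 and sigma = sqrt(2), N(0, 1) is stationary.
rng = np.random.default_rng(0)
y0 = rng.standard_normal(200_000)
yT = fused_sampler(y0, [zero_score, zero_score], [0.5, 0.5],
                   a=-1.0, sigma=np.sqrt(2.0), T=1.0, N=10, rng=rng)
print(yT.mean(), yT.var())
```

Sampling each step exactly avoids an extra Euler–Maruyama error in the linear part, so the only remaining discretization comes from freezing the score, which is the error the proof goes on to bound.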

We also denote the score estimators for process (17) by $s^{\boldsymbol{\lambda}}_{lh,\theta^{*}}$. As before, to apply Girsanov's theorem we work with a single stochastic process $(Z(t))_{t\in[0,T]}$.

For $t\in[lh,(l+1)h]$, the discretization error $\mathcal{L}$ satisfies

$$\begin{aligned}
\mathcal{L}&=\sigma^{2}\mathbb{E}_{\alpha}\left[\left\lVert s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\right]\\
&=\sigma^{2}\mathbb{E}_{\alpha}\left[\left\lVert\sum_{i=1}^{k}\lambda_{i}\left[s^{i}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p^{i}_{T-t}\left(Z(t)\right)\right]\right\rVert_{2}^{2}\right]\\
&\lesssim\sigma^{2}\sum_{i=1}^{k}\lambda_{i}^{2}\mathbb{E}_{\alpha}\left[\left\lVert s^{i}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p^{i}_{T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\right]\\
&\lesssim\sigma^{2}\sum_{i=1}^{k}\lambda_{i}^{2}\mathbb{E}_{\alpha}\left[\left\lVert s^{i}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p^{i}_{T-lh}\left(Z(lh)\right)\right\rVert_{2}^{2}\right]\\
&\quad+\sigma^{2}\sum_{i=1}^{k}\lambda_{i}^{2}\mathbb{E}_{\alpha}\left[\left\lVert\nabla\log p^{i}_{T-lh}\left(Z(lh)\right)-\nabla\log p^{i}_{T-t}\left(Z(lh)\right)\right\rVert_{2}^{2}\right]\\
&\quad+\sigma^{2}\sum_{i=1}^{k}\lambda_{i}^{2}\mathbb{E}_{\alpha}\left[\left\lVert\nabla\log p^{i}_{T-t}\left(Z(lh)\right)-\nabla\log p^{i}_{T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\right]\\
&\lesssim k\sigma^{2}\left(\epsilon_{\text{score}}^{2}+\mathbb{E}_{\alpha}\left[\left\lVert\nabla\log\left(\frac{p^{i}_{T-lh}}{p^{i}_{T-t}}\right)\left(Z(lh)\right)\right\rVert_{2}^{2}\right]+L^{2}\mathbb{E}_{\alpha}\left[\left\lVert Z(lh)-Z(t)\right\rVert_{2}^{2}\right]\right).
\end{aligned}$$
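The third and fourth lines above rely on elementary norm inequalities. For vectors $a_1,\ldots,a_k$ and $u,v,w$ in $\mathbb{R}^{d}$, the triangle inequality followed by Cauchy–Schwarz gives

$$\left\lVert\sum_{i=1}^{k}\lambda_{i}a_{i}\right\rVert_{2}^{2}\le\left(\sum_{i=1}^{k}|\lambda_{i}|\,\lVert a_{i}\rVert_{2}\right)^{2}\le k\sum_{i=1}^{k}\lambda_{i}^{2}\lVert a_{i}\rVert_{2}^{2},\qquad\lVert u+v+w\rVert_{2}^{2}\le 3\left(\lVert u\rVert_{2}^{2}+\lVert v\rVert_{2}^{2}+\lVert w\rVert_{2}^{2}\right),$$

so these two steps hold with constants $k$ and $3$, respectively.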

From Lemma 16 in Chen et al. [12], we obtain the following bound for the second term, where the last step uses $L\geq 1$ to absorb $1+L^{2}$ into $L^{2}$:

$$\begin{aligned}
\mathbb{E}_{\alpha}\left[\left\lVert\nabla\log\left(\frac{p^{i}_{T-lh}}{p^{i}_{T-t}}\right)\left(Z(lh)\right)\right\rVert_{2}^{2}\right]&\lesssim L^{2}dh+L^{2}h^{2}\mathbb{E}_{\alpha}\left[\left\lVert Z(lh)\right\rVert_{2}^{2}\right]\\
&\quad+(1+L^{2})h^{2}\mathbb{E}_{\alpha}\left[\left\lVert\nabla\log p^{i}_{T-t}\left(Z(lh)\right)\right\rVert_{2}^{2}\right]\\
&\lesssim L^{2}dh+L^{2}h^{2}\mathbb{E}_{\alpha}\left[\left\lVert Z(lh)\right\rVert_{2}^{2}\right]\\
&\quad+L^{2}h^{2}\mathbb{E}_{\alpha}\left[\left\lVert\nabla\log p^{i}_{T-t}\left(Z(lh)\right)\right\rVert_{2}^{2}\right].
\end{aligned}$$

Moreover, since $\nabla\log p^{i}_{T-t}$ is $L$-Lipschitz,

$$\begin{aligned}
\left\lVert\nabla\log p^{i}_{T-t}\left(Z(lh)\right)\right\rVert_{2}^{2}&\lesssim\left\lVert\nabla\log p^{i}_{T-t}\left(Z(t)\right)\right\rVert_{2}^{2}+\left\lVert\nabla\log p^{i}_{T-t}\left(Z(lh)\right)-\nabla\log p^{i}_{T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\\
&\lesssim\left\lVert\nabla\log p^{i}_{T-t}\left(Z(t)\right)\right\rVert_{2}^{2}+L^{2}\left\lVert Z(lh)-Z(t)\right\rVert_{2}^{2}.
\end{aligned}$$

Hence,

$$\begin{aligned}
\mathcal{L}&=\sigma^{2}\mathbb{E}_{\alpha}\left[\left\lVert s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\right]\\
&\lesssim k\sigma^{2}\epsilon_{\text{score}}^{2}+k\sigma^{2}L^{2}dh+k\sigma^{2}L^{2}h^{2}\mathbb{E}_{\alpha}\left[\left\lVert Z(lh)\right\rVert_{2}^{2}\right]\\
&\quad+k\sigma^{2}L^{2}h^{2}\mathbb{E}_{\alpha}\left[\left\lVert\nabla\log p^{i}_{T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\right]+k\sigma^{2}L^{2}\mathbb{E}_{\alpha}\left[\left\lVert Z(lh)-Z(t)\right\rVert_{2}^{2}\right].
\end{aligned}$$

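From Lemma 2 and Lemma 3, we have the following, where the last step uses $h\lesssim 1/L\leq 1$ to absorb $L^{2}h^{2}d$ and $L^{3}dh^{2}$ into $L^{2}dh$: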

$$\begin{aligned}
\mathcal{L}&=\sigma^{2}\mathbb{E}_{\alpha}\left[\left\lVert s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\right]\\
&\lesssim k\sigma^{2}\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}\left(d+M\right)+L^{3}dh^{2}+L^{2}\left(dh+Mh^{2}\right)\right)\\
&\lesssim k\sigma^{2}\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}M\right).
\end{aligned}$$
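The last simplification absorbs $L^{2}h^{2}d$ and $L^{3}dh^{2}$ into $L^{2}dh$ using $h\lesssim 1/L\leq 1$; term by term, the middle expression is at most $\epsilon_{\text{score}}^{2}+4L^{2}dh+2L^{2}h^{2}M$. The quick check below, which assumes the hidden constant in $h\lesssim 1/L$ equals 1, samples parameters with $L\geq 1$, $0<h\leq 1/L$ and $M<d$, and confirms that the middle expression stays within a factor of four of the final one.

```python
import numpy as np


def middle(eps, L, d, M, h):
    """Middle line of the bound, without the common factor k * sigma^2."""
    return (eps**2 + L**2 * d * h + L**2 * h**2 * (d + M)
            + L**3 * d * h**2 + L**2 * (d * h + M * h**2))


def final(eps, L, d, M, h):
    """Final line of the bound, without the common factor k * sigma^2."""
    return eps**2 + L**2 * d * h + L**2 * h**2 * M


rng = np.random.default_rng(1)
worst = 0.0
for _ in range(10_000):
    L = 1.0 + 99.0 * rng.random()        # L >= 1
    h = (1.0 - rng.random()) / L         # 0 < h <= 1/L (hidden constant taken as 1)
    d = int(rng.integers(1, 1000))
    M = d * rng.random()                 # 0 <= M < d
    eps = rng.random()
    worst = max(worst, middle(eps, L, d, M, h) / final(eps, L, d, M, h))
print(f"largest ratio observed: {worst:.3f}")
```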

Therefore,

$$
\begin{aligned}
\mathcal{L} &= \sigma^{2}\sum_{l=0}^{N-1}\mathbb{E}_{\alpha}\left[\int_{lh}^{(l+1)h}\left\lVert s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right\rVert_{2}^{2}\,dt\right]\\
&\lesssim \sigma^{2}kT\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}M\right).
\end{aligned}
$$

Next, we claim that

$$
D_{\text{KL}}\left(\alpha\parallel\beta_{T}\right)\lesssim k\sigma^{2}T\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}M\right). \tag{21}
$$

Then, by the triangle inequality, Pinsker's inequality, and the data processing inequality,

$$
\begin{aligned}
\text{TV}\left(p_{\boldsymbol{\lambda}},\hat{p}_{\boldsymbol{\lambda}}\right) &\leq \text{TV}\left(\alpha,\beta\right)\\
&\leq \text{TV}\left(\beta,\beta_{T}\right)+\text{TV}\left(\alpha,\beta_{T}\right)\\
&\leq \text{TV}\left(\pi,\gamma^{d}_{T}\right)+\text{TV}\left(\alpha,\beta_{T}\right)\\
&\lesssim \exp(-T)\max_{i=1,2,\ldots,k}\sqrt{D_{\text{KL}}\left(p^{i}_{T}\parallel\pi\right)}+\sigma\sqrt{kT}\left(\epsilon_{\text{score}}+L\sqrt{dh}+Lh\sqrt{M}\right).
\end{aligned}
$$
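Pinsker's inequality, $\text{TV}(P,Q)\leq\sqrt{D_{\text{KL}}(P\parallel Q)/2}$, is what turns the KL bound (21) into the second term of the last line. As a hedged sanity check in a toy setting (random distributions on a five-point space, not the path-space measures of the proof), the sketch below measures the slack in Pinsker's inequality; the helper names are illustrative.

```python
import numpy as np

def total_variation(p, q):
    """TV(P, Q) = sup_A |P(A) - Q(A)| = 0.5 * ||p - q||_1 on a finite space."""
    return 0.5 * float(np.abs(p - q).sum())

def kl_divergence(p, q):
    """D_KL(p || q) for strictly positive probability vectors."""
    return float(np.sum(p * np.log(p / q)))

rng = np.random.default_rng(0)
slacks = []
for _ in range(1000):
    # Dirichlet(1, ..., 1) draws are strictly positive probability vectors.
    p = rng.dirichlet(np.ones(5))
    q = rng.dirichlet(np.ones(5))
    # Pinsker: TV(P, Q) <= sqrt(D_KL(P || Q) / 2), so each slack should be >= 0.
    slacks.append(np.sqrt(kl_divergence(p, q) / 2.0) - total_variation(p, q))

print(f"smallest Pinsker slack over 1000 random pairs: {min(slacks):.4f}")
```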

Hence it suffices to prove Equation (21). We use a localization argument together with Girsanov's theorem; the notation is the same as in Theorem 6.

For $t\in[0,T]$, let $\mathcal{L}(t)=\int_{0}^{t}b(s)\,dB(s)$, where $B$ is an $\alpha$-Brownian motion and, for $t\in[lh,(l+1)h]$,

$$
b(t)=\sigma\left(s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right).
$$

Recall that

$$
\mathbb{E}_{\alpha}\left[\int_{0}^{T}\left\|b(s)\right\|_{2}^{2}\,ds\right]\lesssim kT\sigma^{2}\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}M\right).
$$

Since $\{\mathcal{E}(\mathcal{L})(t)\}_{t\in[0,T]}$ is a local martingale, there exists a non-decreasing sequence of stopping times $T_n\to T$ such that $\{\mathcal{E}(\mathcal{L})(t\wedge T_n)\}_{t\in[0,T]}$ is a true martingale. Note that $\mathcal{E}(\mathcal{L})(t\wedge T_n)=\mathcal{E}(\mathcal{L}^n)(t)$, where $\mathcal{L}^n(t)=\mathcal{L}(t\wedge T_n)$. Therefore,

$$
\mathbb{E}_{\alpha}\left[\mathcal{E}\left(\mathcal{L}^{n}\right)(T)\right]=\mathbb{E}_{\alpha}\left[\mathcal{E}\left(\mathcal{L}^{n}\right)(0)\right]=1.
$$
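The identity $\mathbb{E}_\alpha[\mathcal{E}(\mathcal{L}^n)(T)]=1$ is what makes $P^n$ a probability measure. As a hedged illustration in a toy setting (not the score-based drift of the proof), the sketch below takes a deterministic, bounded integrand $b(s)=\sin s$, for which $\mathcal{E}(\mathcal{L})$ is already a true martingale and no localization is needed, and checks by Monte Carlo that the discretized stochastic exponential has mean one. The grid sizes and the helper name `stochastic_exponential` are illustrative choices.

```python
import numpy as np

def stochastic_exponential(b_grid, dB, dt):
    """Discretized E(L)(T) = exp(L(T) - 0.5 * int_0^T b(s)^2 ds), L(t) = int_0^t b dB.

    b_grid: deterministic integrand at the left endpoint of each step, shape (n_steps,)
    dB:     Brownian increments, shape (n_paths, n_steps)
    """
    ito_integral = dB @ b_grid                       # sum_i b(t_i) dB_i, one value per path
    quadratic_variation = np.sum(b_grid**2) * dt     # Riemann sum for int_0^T b(s)^2 ds
    return np.exp(ito_integral - 0.5 * quadratic_variation)

T, n_steps, n_paths = 1.0, 50, 100_000
dt = T / n_steps
b_grid = np.sin(np.arange(n_steps) * dt)             # toy bounded, deterministic integrand

rng = np.random.default_rng(1)
dB = rng.normal(0.0, np.sqrt(dt), size=(n_paths, n_steps))
weights = stochastic_exponential(b_grid, dB, dt)
print(f"Monte Carlo estimate of E[E(L)(T)]: {weights.mean():.4f}")
```

Because the integrand is deterministic, the discrete Itô sum is exactly Gaussian with variance equal to the Riemann sum, so the discretized exponential also has mean exactly one; the printed value differs from 1 only by Monte Carlo error.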

Applying Theorem 6 to $\mathcal{L}^n(t)=\int_{0}^{t}b(s)\mathbf{1}_{[0,T_n]}(s)\,dB(s)$, we have that under the measure $P^n=\mathcal{E}(\mathcal{L}^n)(T)\,\alpha$ there exists a Brownian motion $\beta^n$ such that, for all $t\in[0,T]$,

$$
dB(t)=\sigma\left(s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)-\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right)\mathbf{1}_{[0,T_{n}]}(t)\,dt+d\beta^{n}(t).
$$

Since, under $\alpha$, we have almost surely

$$
dZ(t)=\left(aZ(t)+\sigma^{2}\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right)dt+\sigma\,dB(t),\quad Z(0)\sim\gamma^{d},
$$

which also holds $P^n$-almost surely since $P^n\ll\alpha$. Therefore, $P^n$-almost surely, $Z(0)\sim\gamma^d$ and

$$
\begin{aligned}
dZ(t) &= \left[aZ(t)+\sigma^{2}s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)\right]\mathbf{1}_{[0,T_{n}]}(t)\,dt\\
&\quad+\left[aZ(t)+\sigma^{2}\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z(t)\right)\right]\mathbf{1}_{[T_{n},T]}(t)\,dt+\sigma\,d\beta^{n}(t).
\end{aligned}
$$

In other words, $P^n$ is the law of the solution of the above SDE. Plugging in the Radon–Nikodym derivative, we get

$$
\begin{aligned}
D_{\text{KL}}\left(\alpha\parallel P^{n}\right) &= \mathbb{E}_{\alpha}\left[\log\left(\frac{d\alpha}{dP^{n}}\right)\right]\\
&= \mathbb{E}_{\alpha}\left[\log\left(\frac{1}{\mathcal{E}\left(\mathcal{L}\right)(T_{n})}\right)\right]\\
&= \mathbb{E}_{\alpha}\left[-\mathcal{L}(T_{n})+\frac{1}{2}\int_{0}^{T_{n}}\left\|b(s)\right\|_{2}^{2}\,ds\right]\\
&= \mathbb{E}_{\alpha}\left[\frac{1}{2}\int_{0}^{T_{n}}\left\|b(s)\right\|_{2}^{2}\,ds\right]\\
&\leq \mathbb{E}_{\alpha}\left[\frac{1}{2}\int_{0}^{T}\left\|b(s)\right\|_{2}^{2}\,ds\right]\\
&\lesssim kT\sigma^{2}\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}M\right),
\end{aligned}
$$

where the fourth equality holds because $\mathcal{L}$ is a martingale and $T_n$ is a bounded stopping time, so that $\mathbb{E}_\alpha[\mathcal{L}(T_n)]=0$ by the optional sampling theorem.
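The computation above says that, once $\mathbb{E}_\alpha[\mathcal{L}(T_n)]=0$, the KL divergence reduces to half the expected energy of $b$. A minimal Monte Carlo sketch of that cancellation, assuming a toy one-dimensional, bounded, adapted drift difference $b(s)=\tanh(B(s))$ (a stand-in for the score error, not the paper's $b$): the Girsanov form and the energy form should agree up to sampling error, because the Itô sum evaluates $b$ at the left endpoint of each step.

```python
import numpy as np

T, n_steps, n_paths = 1.0, 50, 50_000
dt = T / n_steps
rng = np.random.default_rng(2)

dB = rng.normal(0.0, np.sqrt(dt), size=(n_paths, n_steps))
# B at the left endpoint of each step (B(0) = 0), so the drift is adapted (Ito convention).
B_left = np.concatenate([np.zeros((n_paths, 1)), np.cumsum(dB, axis=1)[:, :-1]], axis=1)
b = np.tanh(B_left)                                   # toy bounded, adapted drift difference

ito_integral = np.sum(b * dB, axis=1)                 # L(T) = int_0^T b dB, per path
half_energy = 0.5 * np.sum(b**2, axis=1) * dt         # 0.5 * int_0^T |b(s)|^2 ds, per path

# Girsanov form: E_alpha[log(d alpha / dP)] = E_alpha[-L(T) + 0.5 int |b|^2]
kl_girsanov = float(np.mean(-ito_integral + half_energy))
# Energy form: since E_alpha[L(T)] = 0, only the quadratic term survives.
kl_energy = float(np.mean(half_energy))
print(f"Girsanov form: {kl_girsanov:.4f}, energy form: {kl_energy:.4f}")
```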

Now consider a coupling of $(P^n)_{n\in\mathbb{N}}$ and $\beta_T$: a sequence of stochastic processes $(Z^n)_{n\in\mathbb{N}}$, a stochastic process $Z$, and a single Brownian motion $W$, all defined on the same probability space, such that $Z(0)=Z^n(0)$ almost surely, $Z(0)\sim\gamma^d$,

$$
\begin{aligned}
dZ^{n}(t) &= \left[aZ^{n}(t)+\sigma^{2}s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z^{n}(lh)\right)\right]\mathbf{1}_{[0,T_{n}]}(t)\,dt\\
&\quad+\left[aZ^{n}(t)+\sigma^{2}\nabla\log p_{\boldsymbol{\lambda},T-t}\left(Z^{n}(t)\right)\right]\mathbf{1}_{[T_{n},T]}(t)\,dt+\sigma\,dW(t),
\end{aligned}
$$

and

$$
dZ(t)=\left[aZ(t)+\sigma^{2}s^{\boldsymbol{\lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right)\right]dt+\sigma\,dW(t).
$$

Hence the law of $Z^n$ is $P^n$ and the law of $Z$ is $\beta_T$. The existence of such a coupling is shown in Chen et al. [12].
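The coupling can be mimicked numerically. The sketch below is a hedged toy, not the paper's construction: a one-dimensional Ornstein–Uhlenbeck forward process with $a=1$, $\sigma=\sqrt{2}$ and Gaussian data $N(m,1)$, so that $p_t=N(me^{-t},1)$ has a closed-form score (the hypothetical `true_score` stands in for both $s^{\boldsymbol\lambda}_{T-lh,\theta^*}$ and $\nabla\log p_{\boldsymbol\lambda,T-t}$), a deterministic switching time in place of $T_n$, and a fine Euler grid in place of exact SDE solutions. Both processes start from the same $Z(0)$ and are driven by the same increments of $W$, so they coincide up to the switching time.

```python
import numpy as np

# Toy 1-D setting (assumption): forward OU dX = -a X dt + sigma dB with a = 1,
# sigma = sqrt(2), data N(m, 1), so p_t = N(m e^{-t}, 1) has a closed-form score.
a, sigma, m = 1.0, np.sqrt(2.0), 3.0
T, N, K = 2.0, 20, 10          # horizon, number of score steps, Euler substeps per step
h = T / N
dt = h / K

def true_score(t, x):
    """grad log p_t(x) for p_t = N(m e^{-t}, 1); hypothetical stand-in for the score."""
    return -(x - m * np.exp(-t))

def simulate(z0, dW, switch_index):
    """Euler scheme for the reverse SDE driven by the fixed increments dW.

    Before fine step `switch_index` the score is frozen at the last grid point lh
    (the discretized sampler); from `switch_index` on, the exact score is used.
    """
    n_fine = dW.shape[0]
    z = np.empty(n_fine + 1)
    z[0] = z0
    for j in range(n_fine):
        if j < switch_index:
            l = j // K
            score = true_score(T - l * h, z[l * K])      # frozen at Z(lh)
        else:
            score = true_score(T - j * dt, z[j])         # exact score after the switch
        z[j + 1] = z[j] + (a * z[j] + sigma**2 * score) * dt + sigma * dW[j]
    return z

rng = np.random.default_rng(3)
n_fine = N * K
dW = rng.normal(0.0, np.sqrt(dt), size=n_fine)   # one Brownian path W shared by both
z0 = rng.normal()                                # common initial value, Z(0) ~ N(0, 1)

switch = n_fine - 15                             # deterministic stand-in for T_n < T
z_disc = simulate(z0, dW, switch_index=n_fine)   # analogue of Z: frozen score on [0, T]
z_n = simulate(z0, dW, switch_index=switch)      # analogue of Z^n: switches at T_n
```

The two simulated paths are identical on the grid up to the switching index and generally differ afterwards, which is the discrete counterpart of $Z^n(t)=Z(t)$ for $t\in[0,T_n]$.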

Fix $\epsilon>0$ and define the map $\pi_{\epsilon}:C([0,T]:\mathbb{R}^{d})\to C([0,T]:\mathbb{R}^{d})$ by

$$
\pi_{\epsilon}(\omega)(t)=\omega\left(t\wedge(T-\epsilon)\right).
$$
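On a discretized path, $\pi_\epsilon$ simply freezes the path at its value at time $T-\epsilon$ and leaves it unchanged before that time. A small sketch on a uniform time grid; the grid, the example path, and the helper name `freeze_after` are illustrative assumptions.

```python
import numpy as np

def freeze_after(path, times, T, eps):
    """Discrete analogue of pi_eps(omega)(t) = omega(min(t, T - eps)).

    path:  samples omega(t_j) on the increasing grid `times` (leading axis = time)
    Returns a copy that equals the value at the last grid time not exceeding
    T - eps from that time onward, and is unchanged before it.
    """
    # Last grid index with t_j <= T - eps (small tolerance absorbs rounding in the grid).
    k = np.searchsorted(times, (T - eps) + 1e-12, side="right") - 1
    frozen = np.array(path, copy=True)
    frozen[k + 1:] = path[k]
    return frozen

times = np.linspace(0.0, 1.0, 11)        # t_j = 0.0, 0.1, ..., 1.0
path = times**2                          # any continuous path on [0, 1]
frozen = freeze_after(path, times, T=1.0, eps=0.3)
print(frozen)
```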

Since $Z^n(t)=Z(t)$ for each $t\in[0,T_n]$, Lemma 4 gives $\pi_\epsilon(Z^n)\to\pi_\epsilon(Z)$ almost surely, uniformly over $[0,T]$, which implies that $\pi_{\epsilon\#}P^n\to\pi_{\epsilon\#}\beta_T$ weakly.

By the lower semicontinuity of the KL divergence under weak convergence and the data processing inequality, we have

$$
\begin{aligned}
D_{\text{KL}}\left(\pi_{\epsilon\#}\alpha\parallel\pi_{\epsilon\#}\beta_{T}\right) &\leq \liminf_{n\to\infty}D_{\text{KL}}\left(\pi_{\epsilon\#}\alpha\parallel\pi_{\epsilon\#}P^{n}\right)\\
&\leq \liminf_{n\to\infty}D_{\text{KL}}\left(\alpha\parallel P^{n}\right)\\
&\lesssim kT\sigma^{2}\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}M\right).
\end{aligned}
$$

By Lemma 5, $\pi_\epsilon(\omega)\to\omega$ uniformly over $[0,T]$ as $\epsilon\to0$. Hence, by Corollary 9.4.6 in Ambrosio et al. [2], $D_{\text{KL}}(\pi_{\epsilon\#}\alpha\parallel\pi_{\epsilon\#}\beta_T)\to D_{\text{KL}}(\alpha\parallel\beta_T)$ as $\epsilon\to0$. Letting $\epsilon\to0$ in the previous display therefore yields Equation (21):

$$
D_{\text{KL}}\left(\alpha\parallel\beta_{T}\right)\lesssim kT\sigma^{2}\left(\epsilon_{\text{score}}^{2}+L^{2}dh+L^{2}h^{2}M\right).
$$

   

Before the proof, we introduce some notation that will be used only in the proof of Theorem 3. Recall that the vanilla fusion method requires two layers of approximation before running the Frank–Wolfe method: we use target samples to estimate an expectation, and we also estimate the densities of the auxiliaries. We denote by $\hat{\bar{p}}_{\hat{\boldsymbol{\lambda}}}$ the distribution of the sample generated by vanilla fusion, which is $\hat{\nu}_D$ in Section 4. The weight $\hat{\boldsymbol{\lambda}}$ is computed with $n$ target samples; $p_{\hat{\boldsymbol{\lambda}}}$ denotes the barycenter of $\{\mu_1,\ldots,\mu_k\}$ with weight $\hat{\boldsymbol{\lambda}}$, and $\bar{p}_{\hat{\boldsymbol{\lambda}}}$ denotes the barycenter of $\{\bar{p}_1,\ldots,\bar{p}_k\}$ with weight $\hat{\boldsymbol{\lambda}}$, where $\{\bar{p}_1,\ldots,\bar{p}_k\}$ is the collection of estimated auxiliary densities. Note that $\bar{p}_{\hat{\boldsymbol{\lambda}}}$ corresponds to $\hat{\mu}_{\boldsymbol{\lambda}}$ in Section 4.

Proof.

By the triangle inequality,

$$
\begin{aligned}
\text{TV}\left(\nu,\hat{\bar{p}}_{\hat{\boldsymbol{\lambda}}}\right) &\leq \text{TV}\left(\nu,p_{\hat{\boldsymbol{\lambda}}}\right)+\text{TV}\left(p_{\hat{\boldsymbol{\lambda}}},\bar{p}_{\hat{\boldsymbol{\lambda}}}\right)+\text{TV}\left(\bar{p}_{\hat{\boldsymbol{\lambda}}},\hat{\bar{p}}_{\hat{\boldsymbol{\lambda}}}\right)\\
&:= I_{1}+I_{2}+I_{3},
\end{aligned}
$$

where $I_1$ is the error incurred when computing the weights with the Frank–Wolfe method, $I_2\leq\epsilon_2$ by assumption, and $I_3$, the error from the auxiliary score estimates, is bounded by Lemma 7.

It therefore remains to bound $I_1$. By Pinsker's inequality,

$$
I_{1}=\text{TV}\left(\nu,p_{\hat{\boldsymbol{\lambda}}}\right)\lesssim\sqrt{D_{\text{KL}}\left(\nu\parallel p_{\hat{\boldsymbol{\lambda}}}\right)},
$$

hence it suffices to bound $D_{\text{KL}}(\nu\parallel p_{\hat{\boldsymbol{\lambda}}})$. By the compactness assumption, the objective function $F$ of problem (9) is $\tilde{L}$-smooth for some constant $\tilde{L}$. The constraint set, the probability simplex in $\mathbb{R}^k$, is compact and convex; we denote its diameter by $D$.

Let $\hat{\boldsymbol{\lambda}}(\tau)$ denote the weight computed after $\tau$ iterations with $n$ target samples. We claim that for $\tau \geq 1$ and $\delta > 0$, with probability at least $1-\delta$,

$$
\begin{aligned}
h_{\tau} &= F\left(\hat{\boldsymbol{\lambda}}(\tau)\right) - F(\boldsymbol{\lambda}^{*}) = D_{\text{KL}}\left(\nu \parallel p_{\hat{\boldsymbol{\lambda}}(\tau)}\right) - D_{\text{KL}}\left(\nu \parallel p_{\boldsymbol{\lambda}^{*}}\right)\\
&\lesssim \frac{2\tilde{L}D^{2}}{\tau+3} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right).
\end{aligned}
\tag{22}
$$

We use induction to prove Equation (22). The main estimate relies on the smoothness of $F$ and the compactness of the constraint set. Let $\delta > 0$. By Hoeffding's inequality, with probability at least $1-\delta$,

$$
\begin{aligned}
F\left(\hat{\boldsymbol{\lambda}}(\tau+1)\right) - F\left(\hat{\boldsymbol{\lambda}}(\tau)\right) &\leq \left\langle \nabla F\left(\hat{\boldsymbol{\lambda}}(\tau)\right), \hat{\boldsymbol{\lambda}}(\tau+1) - \hat{\boldsymbol{\lambda}}(\tau) \right\rangle + \frac{\tilde{L}}{2}\left\|\hat{\boldsymbol{\lambda}}(\tau+1) - \hat{\boldsymbol{\lambda}}(\tau)\right\|_{2}^{2}\\
&= \gamma_{\tau}\left\langle \nabla F\left(\hat{\boldsymbol{\lambda}}(\tau)\right), v_{\tau} - \hat{\boldsymbol{\lambda}}(\tau) \right\rangle + \frac{\tilde{L}\gamma_{\tau}^{2}}{2}\left\|v_{\tau} - \hat{\boldsymbol{\lambda}}(\tau)\right\|_{2}^{2}\\
&\lesssim \gamma_{\tau}\left\langle \nabla F\left(\hat{\boldsymbol{\lambda}}(\tau)\right), \boldsymbol{\lambda}^{*} - \hat{\boldsymbol{\lambda}}(\tau) \right\rangle + \frac{\tilde{L}\gamma_{\tau}^{2}D^{2}}{2}\\
&\quad + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right)\gamma_{\tau}\left\|\boldsymbol{\lambda}^{*} - \hat{\boldsymbol{\lambda}}(\tau)\right\|_{2}\\
&\leq \gamma_{\tau}\left(F\left(\boldsymbol{\lambda}^{*}\right) - F\left(\hat{\boldsymbol{\lambda}}(\tau)\right)\right) + \frac{\tilde{L}\gamma_{\tau}^{2}D^{2}}{2}\\
&\quad + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right).
\end{aligned}
$$

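Since $F\left(\hat{\boldsymbol{\lambda}}(\tau+1)\right) - F\left(\hat{\boldsymbol{\lambda}}(\tau)\right) = h_{\tau+1} - h_{\tau}$ and $F\left(\boldsymbol{\lambda}^{*}\right) - F\left(\hat{\boldsymbol{\lambda}}(\tau)\right) = -h_{\tau}$, rearranging the terms gives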

$$
h_{\tau+1} \lesssim \left(1 - \gamma_{\tau}\right)h_{\tau} + \gamma_{\tau}^{2}\frac{\tilde{L}D^{2}}{2} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right). \tag{23}
$$

Now we begin the induction argument. For $\tau = 1$, Equation (23) becomes

$$
h_{1} \lesssim \frac{\tilde{L}D^{2}}{2} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right),
$$

which is Equation (22) with $\tau = 1$; hence the base case holds. Now suppose that Equation (22) holds for some $\tau \geq 1$. Then, by Equation (23) with $\gamma_{\tau} = 2/(\tau+3)$,

$$
\begin{aligned}
h_{\tau+1} &\lesssim \left(1 - \frac{2}{\tau+3}\right)\frac{2\tilde{L}D^{2}}{\tau+3} + \frac{4}{(\tau+3)^{2}}\frac{\tilde{L}D^{2}}{2} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right)\\
&= \frac{2\tilde{L}D^{2}(\tau+2)}{(\tau+3)^{2}} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right)\\
&\leq \frac{2\tilde{L}D^{2}}{\tau+4} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right)
\end{aligned}
$$

since $(\tau+2)(\tau+4) = (\tau+3)^{2} - 1 \leq (\tau+3)^{2}$. This proves Equation (22). Letting $\tau \to \infty$, so that the optimization term $2\tilde{L}D^{2}/(\tau+3)$ vanishes, we obtain

$$
\begin{aligned}
D_{\text{KL}}\left(\nu \parallel p_{\hat{\boldsymbol{\lambda}}(\tau)}\right) &\leq D_{\text{KL}}\left(\nu \parallel p_{\boldsymbol{\lambda}^{*}}\right) + \frac{2\tilde{L}D^{2}}{\tau+3} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right)\\
&\lesssim \epsilon_{0}^{2} + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/2} n^{-1/2}\right).
\end{aligned}
$$

Therefore, by Pinsker's inequality, with probability at least $1-\delta$,

$$
\begin{aligned}
\text{TV}\left(\nu, \hat{\bar{p}}_{\hat{\boldsymbol{\lambda}}}\right) &\lesssim \epsilon_{0} + \epsilon_{2} + \exp(-T)\max_{i=1,2,\ldots,k}\sqrt{D_{\text{KL}}\left(p^{i}_{T} \parallel \pi\right)} + \sigma\sqrt{kT}\left(\epsilon_{\text{score}} + L\sqrt{dh} + Lh\sqrt{M}\right)\\
&\quad + \mathcal{O}\left(\left(\log\left(\frac{1}{\delta}\right)\right)^{1/4} n^{-1/4}\right).
\end{aligned}
$$

   

C.3 Proof of Theorem 4

Before the proof, we introduce the notation used in it. $\hat{p}_{\hat{\boldsymbol{\Lambda}}}$ denotes the output distribution of Algorithm 1, which is $\hat{\nu}_{P}$ in Section 4. For a fixed small $\tilde{T} \ll 1$, the training phase of ScoreFusion uses the forward process, for $t \in [0, \tilde{T}]$,

$$
dZ(t) = -aZ(t)\,dt + \sigma\,dW(t), \quad Z(0) \sim \nu. \tag{24}
$$
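Process (24) is an Ornstein-Uhlenbeck process, so its transition from $Z(0)$ to $Z(t)$ is Gaussian and can be sampled exactly, without time discretization. The sketch below does this and checks it against the closed-form marginal when $\nu$ is Gaussian; the Gaussian $\nu$ and the parameter values are hypothetical test choices, and $a > 0$ is assumed.

```python
import numpy as np

def ou_forward_sample(z0, t, a, sigma, rng):
    """Sample Z(t) for dZ = -a Z dt + sigma dW given Z(0) = z0, using the exact Gaussian transition
    Z(t) = exp(-a t) Z(0) + sigma * sqrt((1 - exp(-2 a t)) / (2 a)) * xi, xi ~ N(0, I). Requires a > 0."""
    mean = np.exp(-a * t) * z0
    std = sigma * np.sqrt((1.0 - np.exp(-2.0 * a * t)) / (2.0 * a))
    return mean + std * rng.standard_normal(np.shape(z0))

# Hypothetical target nu = N(1.5, 0.5^2) in one dimension.
rng = np.random.default_rng(0)
a, sigma, t = 1.0, np.sqrt(2.0), 0.1
z0 = 1.5 + 0.5 * rng.standard_normal(200_000)
zt = ou_forward_sample(z0, t, a, sigma, rng)
# For Gaussian nu the marginal is N(exp(-a t) * 1.5, exp(-2 a t) * 0.25 + sigma^2 (1 - exp(-2 a t)) / (2 a)).
mean_t = np.exp(-a * t) * 1.5
var_t = np.exp(-2 * a * t) * 0.25 + sigma ** 2 * (1 - np.exp(-2 * a * t)) / (2 * a)
print(zt.mean(), mean_t, zt.var(), var_t)
```

For small $t$ the marginal stays close to $\nu$, which is the continuity property used later when $\tilde{T}$ is small.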

We learn an optimal weight by solving problem (10). We denote the marginal distribution of process (24) at time $t$ for fixed $\boldsymbol{\Lambda}$ by $p^{\nu}_{t}$. Although the backward process of (24) is not used in practice, the following two backward processes will be useful in the proof of Theorem 4: for $t \in [0, \tilde{T}]$, with $\tilde{Z}(0) \sim \gamma^{d}_{\tilde{T}}$, $\hat{Z}(0) \sim \gamma^{d}_{\tilde{T}}$, and fixed $\boldsymbol{\Lambda}$,

$$
d\tilde{Z}(t) = \left(a\tilde{Z}(t) + \sigma^{2}\nabla\log p^{\nu}_{T-t}\left(\tilde{Z}(t)\right)\right)dt + \sigma\,dW(t), \quad \tilde{Z}(\tilde{T}) \sim \nu, \tag{25}
$$

and, for $l = 0, 1, \ldots, N_{\tilde{T}} - 1$,

$$
d\hat{Z}(t) = \left(a\hat{Z}(t) + \sigma^{2}\sum_{i=1}^{k}\Lambda_{i}s^{i}_{T-lh,\theta^{*}}\left(\hat{Z}(lh)\right)\right)dt + \sigma\,dW(t), \quad t \in [lh, (l+1)h], \tag{26}
$$

where $hN_{\tilde{T}} = \tilde{T}$. Process (26) is the time discretization of process (25), without initialization error since $\tilde{Z}(0)$ and $\hat{Z}(0)$ have the same law. We denote the laws of processes (25) and (26) by $\alpha_{\tilde{T}}$ and $\beta_{\tilde{T}} \in \mathcal{P}\left(C([0,T]:\mathbb{R}^{d})\right)$, respectively. For fixed $\boldsymbol{\Lambda}$, we denote the laws of $\tilde{Z}(\tilde{T})$ and $\hat{Z}(\tilde{T})$ by $p^{\tilde{T}}_{\boldsymbol{\Lambda}}$ and $\hat{p}^{\tilde{T}}_{\boldsymbol{\Lambda}}$, respectively.
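Within each interval $[lh, (l+1)h]$, scheme (26) freezes the fused score $\sum_{i}\Lambda_i s^{i}_{T-lh,\theta^{*}}(\hat{Z}(lh))$ at the left endpoint, so the remaining dynamics are linear with a constant forcing term and can be integrated exactly. The sketch below implements such a sampler under our own assumptions: auxiliary scores are plain Python callables `s(t, z)` (hypothetical stand-ins for the trained networks), `n_steps` plays the role of $N_{\tilde{T}}$, and the toy check uses $a = 1$, $\sigma^2 = 2$, for which $N(0,1)$ is stationary for the forward process and its score is $-z$, so the output should stay approximately standard normal.

```python
import numpy as np

def fused_reverse_sampler(score_fns, weights, z0, a, sigma, T, h, n_steps, rng):
    """Simulate scheme (26): on each [lh, (l+1)h] the fused score
    sum_i weights[i] * score_fns[i](T - l h, Z(lh)) is held fixed, and the linear
    SDE dZ = (a Z + sigma^2 * fused) dt + sigma dW is solved exactly. Requires a > 0."""
    z = np.array(z0, dtype=float)
    eah = np.exp(a * h)
    noise_std = sigma * np.sqrt((np.exp(2.0 * a * h) - 1.0) / (2.0 * a))
    for l in range(n_steps):
        fused = sum(w * s(T - l * h, z) for w, s in zip(weights, score_fns))
        forcing = sigma ** 2 * fused  # constant drift contribution on this interval
        z = eah * z + (eah - 1.0) / a * forcing + noise_std * rng.standard_normal(z.shape)
    return z

# Toy check (hypothetical): both auxiliary scores equal the N(0, 1) score -z.
rng = np.random.default_rng(0)
scores = [lambda t, z: -z, lambda t, z: -z]
z0 = rng.standard_normal(100_000)
z_out = fused_reverse_sampler(scores, [0.3, 0.7], z0, a=1.0, sigma=np.sqrt(2.0),
                              T=1.0, h=0.01, n_steps=100, rng=rng)
print(z_out.mean(), z_out.var())
```

The small residual bias in the variance (about $10^{-2}$ here) is the discretization error that the terms in $h$ account for in the analysis.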

Proof.

By the triangle inequality,

$$
\begin{aligned}
\text{TV}\left(\nu, \hat{p}_{\hat{\boldsymbol{\Lambda}}}\right) &\leq \text{TV}\left(\nu, \hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right) + \text{TV}\left(\hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}, p^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right) + \text{TV}\left(p^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}, \hat{p}_{\hat{\boldsymbol{\Lambda}}}\right)\\
&= \text{TV}\left(\nu, \hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right) + \text{TV}\left(\hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}, p^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right) + \text{TV}\left(p_{\hat{\boldsymbol{\Lambda}}}, \hat{p}_{\hat{\boldsymbol{\Lambda}}}\right)\\
&\lesssim \text{TV}\left(\nu, \hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right) + \text{TV}\left(p_{\hat{\boldsymbol{\Lambda}}}, \hat{p}_{\hat{\boldsymbol{\Lambda}}}\right).
\end{aligned}
$$

By Lemma 7, the last term satisfies

$$
\text{TV}\left(p_{\hat{\boldsymbol{\Lambda}}}, \hat{p}_{\hat{\boldsymbol{\Lambda}}}\right) \lesssim \exp(-T)\max_{i=1,2,\ldots,k}\sqrt{D_{\text{KL}}\left(p^{i}_{T} \parallel \pi\right)} + \sqrt{kT}\,\sigma\left(\epsilon_{\text{score}} + L\sqrt{dh} + Lh\sqrt{M}\right).
$$

To bound the first term, we use the chain rule for the KL divergence, Girsanov's theorem, and an approximation argument similar to the one in Section C.2 to obtain

$$
\begin{aligned}
D_{\text{KL}}\left(\nu \parallel \hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right) &\lesssim D_{\text{KL}}\left(\alpha_{\tilde{T}} \parallel \beta_{\tilde{T}}\right)\\
&\lesssim \frac{1}{\tilde{T}}\sum_{l=0}^{N_{\tilde{T}}-1}\mathbb{E}_{\alpha_{\tilde{T}}}\left[\int_{lh}^{(l+1)h}\sigma^{2}\left\lVert s^{\boldsymbol{\Lambda}}_{T-lh,\theta^{*}}\left(Z(lh)\right) - \nabla\log p^{\nu}_{T-t}\left(Z(t)\right)\right\rVert_{2}^{2}dt\right]\\
&\lesssim \frac{1}{\tilde{T}}\int_{0}^{\tilde{T}}\left[\sigma^{2}\mathbb{E}_{Z(t)\sim p^{\nu}_{t}}\left[\left\lVert\sum_{i=1}^{k}\left(\Lambda_{i}s^{i}_{t,\theta^{*}}\left(Z(t)\right)\right) - \nabla\log p^{\nu}_{t}(Z(t))\right\rVert_{2}^{2}\right]\right]dt\\
&\lesssim \tilde{\mathcal{L}}\left(\hat{\boldsymbol{\Lambda}}; \theta^{*}, \sigma^{2}\right) = \tilde{\mathcal{L}}\left(\boldsymbol{\Lambda}^{*}; \theta^{*}, \sigma^{2}\right) + \left[\tilde{\mathcal{L}}\left(\hat{\boldsymbol{\Lambda}}; \theta^{*}, \sigma^{2}\right) - \tilde{\mathcal{L}}\left(\boldsymbol{\Lambda}^{*}; \theta^{*}, \sigma^{2}\right)\right]\\
&:= I_{1} + I_{2},
\end{aligned}
$$

where $I_{1}$ is the approximation error and $I_{2}$ is the excess risk. By McDiarmid's inequality, for $\delta > 0$, with probability at least $1-\delta$,

I2𝒪(σ2(log(1δ))1/2(NT~n)1/2)𝒪(σ2(log(1δ))1/2n1/2)less-than-or-similar-tosubscript𝐼2𝒪superscript𝜎2superscript1𝛿12superscriptsubscript𝑁~𝑇𝑛12less-than-or-similar-to𝒪superscript𝜎2superscript1𝛿12superscript𝑛12\displaystyle I_{2}\lesssim\mathcal{O}\left(\sigma^{2}\left(\log\left(\frac{1}% {\delta}\right)\right)^{-1/2}\left(N_{\tilde{T}}n\right)^{-1/2}\right)\lesssim% \mathcal{O}\left(\sigma^{2}\left(\log\left(\frac{1}{\delta}\right)\right)^{-1/% 2}n^{-1/2}\right)italic_I start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ≲ caligraphic_O ( italic_σ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ( roman_log ( divide start_ARG 1 end_ARG start_ARG italic_δ end_ARG ) ) start_POSTSUPERSCRIPT - 1 / 2 end_POSTSUPERSCRIPT ( italic_N start_POSTSUBSCRIPT over~ start_ARG italic_T end_ARG end_POSTSUBSCRIPT italic_n ) start_POSTSUPERSCRIPT - 1 / 2 end_POSTSUPERSCRIPT ) ≲ caligraphic_O ( italic_σ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ( roman_log ( divide start_ARG 1 end_ARG start_ARG italic_δ end_ARG ) ) start_POSTSUPERSCRIPT - 1 / 2 end_POSTSUPERSCRIPT italic_n start_POSTSUPERSCRIPT - 1 / 2 end_POSTSUPERSCRIPT )

since $\tilde{T}\lesssim T$ and $N_{\tilde{T}}$ is small.

Finally, we bound $I_{1}$. The intuition is that, by continuity of the diffusion process, $p^{\nu}_{\tilde{T}}$ and $p^{\nu}_{0}$ are close when $\tilde{T}$ is small. Hence, under Assumption 3, the approximation error of the linear regression should be small.
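Viewed this way, fitting the weights is a least-squares problem with only $k$ unknowns. The NumPy sketch below illustrates that view under stated assumptions: the auxiliary scores are assumed to be pre-evaluated at $n$ noisy samples and stacked into an array of shape $(n, k, d)$, the regression targets have shape $(n, d)$, and the closed-form least-squares solve is a swapped-in illustration rather than the paper's procedure (the experiments fit the weights with stochastic gradient descent). The function name fit_fusion_weights is hypothetical.

```python
import numpy as np


def fit_fusion_weights(aux_scores: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Least-squares weights so that sum_i lam_i * s_i(x) approximates target(x).

    aux_scores: (n, k, d) auxiliary score evaluations at n noisy samples (assumed given).
    targets:    (n, d) regression targets, e.g. denoising score-matching targets.
    Returns the unconstrained least-squares solution lam of shape (k,).
    """
    n, k, d = aux_scores.shape
    # Each (sample, coordinate) pair is one regression row; the k scores are the features.
    design = aux_scores.transpose(0, 2, 1).reshape(n * d, k)
    response = targets.reshape(n * d)
    lam, *_ = np.linalg.lstsq(design, response, rcond=None)
    return lam


# Synthetic check: targets generated exactly as 0.3 * s_1 + 0.7 * s_2.
rng = np.random.default_rng(0)
aux = rng.normal(size=(500, 2, 3))
true_lam = np.array([0.3, 0.7])
tgt = np.einsum("nkd,k->nd", aux, true_lam)
lam_hat = fit_fusion_weights(aux, tgt)
print(np.round(lam_hat, 3))
```

The weights reported in Tables 4 and 8 sum to one, so a simplex-constrained fit may be a closer match; this sketch leaves the weights unconstrained.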

Fix $t\in[0,\tilde{T}]$. Since $\nu$ has a Lipschitz density, and by the compactness assumption, the loss $\mathcal{L}$ satisfies

$$
\begin{aligned}
\mathcal{L}&=\tilde{\mathcal{L}}\left(\boldsymbol{\Lambda}^{*};\theta^{*},\sigma^{2}\right)\\
&=\frac{\sigma^{2}}{\tilde{T}}\int_{0}^{\tilde{T}}\mathbb{E}_{Z(t)\sim p^{\nu}_{t}}\left[\left\lVert\sum_{i=1}^{k}\Lambda^{*}_{i}s^{i}_{t,\theta^{*}}(Z(t))-\nabla\log p^{\nu}_{t}(Z(t))\right\rVert_{2}^{2}\right]dt\\
&\lesssim\frac{\sigma^{2}}{\tilde{T}}\int_{0}^{\tilde{T}}\mathbb{E}_{Z(t)\sim p^{\nu}_{t}}\left[\left\lVert\sum_{i=1}^{k}\Lambda^{*}_{i}s^{i}_{t,\theta^{*}}(Z(t))-\sum_{i=1}^{k}\Lambda^{*}_{i}\nabla\log p^{i}_{t}(Z(t))\right\rVert_{2}^{2}\right]dt\\
&\quad+\frac{\sigma^{2}}{\tilde{T}}\int_{0}^{\tilde{T}}\mathbb{E}_{Z(t)\sim p^{\nu}_{t}}\left[\left\lVert\sum_{i=1}^{k}\Lambda^{*}_{i}\nabla\log p^{i}_{t}(Z(t))-\nabla\log p^{\nu}_{t}(Z(t))\right\rVert_{2}^{2}\right]dt\\
&\lesssim\sigma^{2}k\epsilon_{\text{score}}^{2}+\sigma^{2}\mathbb{E}_{Z(0)\sim\nu}\left[\left\lVert\sum_{i=1}^{k}\Lambda^{*}_{i}\nabla\log p^{i}_{0}(Z(0))-\nabla\log p^{\nu}_{0}(Z(0))\right\rVert_{2}^{2}\right]\\
&\quad+\frac{\sigma^{2}}{\tilde{T}}\int_{0}^{\tilde{T}}\mathbb{E}_{Z(t)\sim p^{\nu}_{t}}\left[\left\lVert\sum_{i=1}^{k}\Lambda^{*}_{i}\nabla\log p^{i}_{t}(Z(t))-\sum_{i=1}^{k}\Lambda^{*}_{i}\nabla\log p^{i}_{0}(Z(t))\right\rVert_{2}^{2}\right]dt\\
&\quad+\frac{\sigma^{2}}{\tilde{T}}\int_{0}^{\tilde{T}}\mathbb{E}_{Z(t)\sim p^{\nu}_{t}}\left[\left\lVert\nabla\log p^{\nu}_{t}(Z(t))-\nabla\log p^{\nu}_{0}(Z(t))\right\rVert_{2}^{2}\right]dt\\
&\lesssim\sigma^{2}k\epsilon_{\text{score}}^{2}+\sigma^{2}\mathbb{E}_{Z(0)\sim\nu}\left[\left\lVert p^{\nu}_{0}(Z(0))-p_{\boldsymbol{\Lambda}^{*}}(Z(0))\right\rVert_{2}^{2}\right]\\
&\quad+\sigma^{2}\mathbb{E}_{Z(\tilde{T})\sim\gamma^{d}_{\tilde{T}}}\left[\left\lVert p^{\nu}_{\tilde{T}}(Z(\tilde{T}))-p^{\nu}_{0}(Z(\tilde{T}))\right\rVert_{2}^{2}\right]+\sigma^{2}k\,\mathbb{E}_{Z(\tilde{T})\sim\gamma^{d}_{\tilde{T}}}\left[\left\lVert p^{j\in[1,k]}_{\tilde{T}}(Z(\tilde{T}))-p^{j\in[1,k]}_{0}(Z(\tilde{T}))\right\rVert_{2}^{2}\right]\\
&\lesssim\sigma^{2}k\epsilon_{\text{score}}^{2}+\sigma^{2}D_{\text{KL}}\left(\nu\parallel p_{\boldsymbol{\Lambda}^{*}}\right)+\sigma^{2}k\,D_{\text{KL}}\left(p^{\nu}_{\tilde{T}}\parallel p^{\nu}_{0}\right)\\
&\lesssim\sigma^{2}k\epsilon_{\text{score}}^{2}+\sigma^{2}\epsilon_{1}^{2}+\sigma^{2}k\,\mathcal{O}\left(\tilde{T}^{1/2}\right).
\end{aligned}
$$

Therefore, by Pinsker's inequality, with probability at least $1-\delta$,

$$
\begin{aligned}
\text{TV}\left(\nu,\hat{p}_{\hat{\boldsymbol{\Lambda}}}\right)&\lesssim\text{TV}\left(\nu,\hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right)+\text{TV}\left(p_{\hat{\boldsymbol{\Lambda}}},\hat{p}_{\hat{\boldsymbol{\Lambda}}}\right)\\
&\lesssim\sqrt{D_{\text{KL}}\left(\nu\parallel\hat{p}^{\tilde{T}}_{\hat{\boldsymbol{\Lambda}}}\right)}+\text{TV}\left(p_{\hat{\boldsymbol{\Lambda}}},\hat{p}_{\hat{\boldsymbol{\Lambda}}}\right)\\
&\lesssim\sigma\epsilon_{1}+\sigma\sqrt{k}\,\mathcal{O}\left(\tilde{T}^{1/4}\right)+\mathcal{O}\left(\sigma\left(\log\left(\frac{1}{\delta}\right)\right)^{-1/4}n^{-1/4}\right)\\
&\quad+\exp(-T)\max_{i=1,2,\ldots,k}\sqrt{D_{\text{KL}}\left(p^{i}_{T}\parallel\pi\right)}+\sigma\sqrt{kT}\left(\epsilon_{\text{score}}+L\sqrt{dh}+Lh\sqrt{M}\right),
\end{aligned}
$$

which finishes the proof.    
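The last step uses Pinsker's inequality, $\mathrm{TV}(P,Q)\le\sqrt{D_{\mathrm{KL}}(P\parallel Q)/2}$; taking the square root is why the $\epsilon_1^2$, $\tilde{T}^{1/2}$ and $n^{-1/2}$ terms of the loss bound become $\epsilon_1$, $\tilde{T}^{1/4}$ and $n^{-1/4}$ in the TV bound. A quick numerical sanity check of the inequality on random discrete distributions, using the natural logarithm (the helper names are illustrative):

```python
import numpy as np


def tv_distance(p: np.ndarray, q: np.ndarray) -> float:
    """Total variation distance between two discrete distributions."""
    return 0.5 * float(np.abs(p - q).sum())


def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """D_KL(p || q) in nats; assumes q > 0 wherever p > 0."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


rng = np.random.default_rng(1)
for _ in range(1000):
    p = rng.dirichlet(np.ones(5))
    q = rng.dirichlet(np.ones(5))
    # Pinsker's inequality: TV(p, q) <= sqrt(KL(p || q) / 2).
    assert tv_distance(p, q) <= np.sqrt(kl_divergence(p, q) / 2) + 1e-12
print("Pinsker's inequality held on all random pairs")
```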

Appendix D Experiment details

D.1 Training and architecture details

To standardize comparison, the baseline and the auxiliary score models are parametrized by exactly the same UNet architecture; the only difference between a baseline and an auxiliary model is the amount of training data it has access to. The Python classes in our supplementary codebase, model_1D.ScoreNet and model_EMNIST.ScoreNet, are both modified from the ScoreNet class in the GitHub repository of Song et al. [47]. One caveat is that, to accommodate the one-dimensional data in Section 5.1, we set the stride and kernel size of the convolutional layers in model_1D.ScoreNet to 1. The one-dimensional UNet has 344k trainable parameters; the EMNIST UNet triples the trainable parameter count to 1.1 million. By contrast, a ScoreFusion model has only $k$ trainable parameters, where $k$ is the number of auxiliary scores.
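As a minimal sketch of this parametrization, assume each frozen auxiliary score is a callable s_i(x, t) returning an array shaped like x. The fused score is then $\sum_{i=1}^{k}\lambda_i s_i(x,t)$, matching the weighted sum of auxiliary scores in the proof above, and the weight vector is the only trainable state. The class FusedScore and the closed-form Gaussian scores are hypothetical stand-ins for the pre-trained ScoreNet models, written in NumPy for self-containment.

```python
from typing import Callable, Sequence

import numpy as np

ScoreFn = Callable[[np.ndarray, float], np.ndarray]


class FusedScore:
    """Hypothetical fused score s(x, t) = sum_i lam_i * s_i(x, t).

    The auxiliary score functions are treated as frozen; the only
    trainable state is the weight vector lam of length k.
    """

    def __init__(self, aux_scores: Sequence[ScoreFn], lam: np.ndarray):
        if len(aux_scores) != len(lam):
            raise ValueError("need exactly one weight per auxiliary score")
        self.aux_scores = list(aux_scores)
        self.lam = np.asarray(lam, dtype=float)

    def num_trainable_params(self) -> int:
        return int(self.lam.size)

    def __call__(self, x: np.ndarray, t: float) -> np.ndarray:
        return sum(w * s(x, t) for w, s in zip(self.lam, self.aux_scores))


def gaussian_score(mu: float) -> ScoreFn:
    # Assumption for illustration: data N(mu, 1) under the OU process
    # dX = -X dt + sqrt(2) dW, so X_t ~ N(mu * exp(-t), 1) and the exact score is linear.
    return lambda x, t: -(x - mu * np.exp(-t))


fused = FusedScore([gaussian_score(-2.0), gaussian_score(3.0)], lam=np.array([0.6, 0.4]))
print(fused.num_trainable_params())
print(fused(np.array([0.0, 1.0]), t=0.0))
```

With weights summing to one, the weighted sum of these equal-variance Gaussian scores is itself the score of a unit-variance Gaussian whose mean is the weighted average of the component means. Here $0.6\cdot(-2)+0.4\cdot 3=0$, so the fused score reduces to $-x$, which gives an easy sanity check.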

We follow the standard machine learning convention of splitting each dataset into training, validation, and test sets, using stratified sampling to ensure class balance. The ratio of training data to validation data is 4:1. We use the ground-truth digit labels only for data splitting and hide them from the model during training. Training jobs taking more than an hour were run on two NVIDIA A40 GPUs in a computing cluster; lightweight tasks were run on Google Colab with an A100 GPU.
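A stratified 4:1 train/validation split of this kind can be sketched in NumPy as follows. The function name stratified_split and the per-class rounding rule are assumptions, not the supplementary code; the labels only decide which split each example lands in, consistent with keeping them hidden from the model.

```python
import numpy as np


def stratified_split(labels: np.ndarray, val_fraction: float = 0.2, seed: int = 0):
    """Return (train_idx, val_idx), splitting every class in the same proportion.

    val_fraction = 0.2 gives a 4:1 train:validation ratio.
    """
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # indices of this class
        rng.shuffle(idx)
        n_val = int(round(val_fraction * idx.size))
        val_idx.append(idx[:n_val])
        train_idx.append(idx[n_val:])
    return np.concatenate(train_idx), np.concatenate(val_idx)


# Toy labels: 600 sevens and 400 nines (a 60/40 mix).
labels = np.array([7] * 600 + [9] * 400)
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))
print(np.mean(labels[val_idx] == 7))
```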

Model checkpoints for all our experiments, covering both the pre-trained auxiliary score models and the baseline models, are in the ckpt subdirectory of the .zip file.

D.2 Section 5.1 supplementary data

Due to space limits, Table 1 does not include all data columns. Table 4 gives the complete 1-Wasserstein distances from each learned distribution to the ground-truth distribution as the training size varies. Standard errors are computed from the 1-Wasserstein distances of 10 batch-pairs, each consisting of 8096 random samples drawn independently from the ground truth and from a trained generative model. Note that fitting $\boldsymbol{\lambda}^{*}$ involves randomness due to stochastic gradient descent.
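For one-dimensional data such as the Gaussian mixture of Section 5.1, the 1-Wasserstein distance between two equal-size empirical samples equals the mean absolute difference of their sorted values. The sketch below uses that fact to compute a mean and standard error over independent batch-pairs. The sampler functions are hypothetical, and taking the standard error as the sample standard deviation divided by the square root of the number of pairs is an assumption, since the text does not state the exact estimator.

```python
import numpy as np


def w1_empirical_1d(x: np.ndarray, y: np.ndarray) -> float:
    """1-Wasserstein distance between two 1-D empirical samples of equal size."""
    if x.shape != y.shape:
        raise ValueError("samples must have equal size")
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))


def w1_mean_and_se(sample_truth, sample_model, n_pairs=10, batch=8096, seed=0):
    """Mean and (assumed) standard error of W1 over independent batch-pairs.

    sample_truth, sample_model: hypothetical callables (rng, size) -> 1-D array.
    """
    rng = np.random.default_rng(seed)
    dists = np.array([
        w1_empirical_1d(sample_truth(rng, batch), sample_model(rng, batch))
        for _ in range(n_pairs)
    ])
    return dists.mean(), dists.std(ddof=1) / np.sqrt(n_pairs)


def sample_truth(rng, n):
    return rng.normal(0.0, 1.0, size=n)


def sample_model(rng, n):
    # Toy "model": the ground truth shifted by 0.5, so the exact W1 is 0.5.
    return rng.normal(0.5, 1.0, size=n)


mean, se = w1_mean_and_se(sample_truth, sample_model)
print(f"W1 = {mean:.3f} +/- {se:.3f}")
```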

Table 4: 1-Wasserstein distance from the ground-truth Gaussian mixture, by training sample size
Model | $2^{5}$ | $2^{6}$ | $2^{7}$ | $2^{8}$ | $2^{9}$ | $2^{10}$
Baseline | $106.93\pm 1.43$ | $13.46\pm 0.28$ | $16.74\pm 0.27$ | $2.13\pm 0.12$ | $0.55\pm 0.04$ | $\mathbf{0.15\pm 0.02}$
ScoreFusion | $\mathbf{0.39\pm 0.02}$ | $\mathbf{0.51\pm 0.03}$ | $\mathbf{0.36\pm 0.02}$ | $\mathbf{0.58\pm 0.03}$ | $\mathbf{0.38\pm 0.02}$ | $0.30\pm 0.02$
$\boldsymbol{\lambda}^{*}$ of ScoreFusion | $[0.62, 0.38]$ | $[0.65, 0.35]$ | $[0.46, 0.54]$ | $[0.68, 0.32]$ | $[0.61, 0.39]$ | $[0.58, 0.42]$
 

Figures 4, 5, and 6 show additional histograms of the distributions learned by ScoreFusion and by the baseline.

Figure 4: Left: models trained on 32 samples. Right: models trained on 64 samples.
Figure 5: Left: models trained on 128 samples. Right: models trained on 256 samples.
Figure 6: Left: models trained on 512 samples. Right: models trained on 1024 samples.

D.3 Section 5.2 supplementary data

We also provide supplementary data for the experiments on handwritten EMNIST digits. Table 5 gives the empirical distribution of the digits sampled unconditionally from the auxiliary scores.

Table 5: Percentage of each digit among 1024 images sampled from each auxiliary score model without fusion, as classified by SpinalNet.
Auxiliary score | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
1 | 0.1% | 0.1% | 0.6% | 0.6% | 1.1% | 0.3% | 0.0% | 18.7% | 0.2% | 78.2%
2 | 0.1% | 0.1% | 0.3% | 0.8% | 1.1% | 0.5% | 0.0% | 41.1% | 0.2% | 55.8%
3 | 0.0% | 0.2% | 0.7% | 0.7% | 1.2% | 0.8% | 0.0% | 72.1% | 0.6% | 23.7%
4 | 0.1% | 0.5% | 0.7% | 0.5% | 0.9% | 0.4% | 0.1% | 87.9% | 0.3% | 8.6%
Target distribution | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 60% | 0% | 40%
 
Table 6: Full version of Table 3, Part I. Digit distribution estimated by SpinalNet, grouped by training sample size; the Fusion columns give the breakdown for ScoreFusion. The “Others” category is the fraction of samples that more closely resemble a digit other than 7 or 9.
Digit | True | $2^{6}$ Baseline | $2^{6}$ Fusion | $2^{7}$ Baseline | $2^{7}$ Fusion | $2^{8}$ Baseline | $2^{8}$ Fusion | $2^{9}$ Baseline | $2^{9}$ Fusion
7 | 60% | 47.9% | 55.6% | 57.9% | 55.5% | 66.8% | 57.5% | 64.8% | 58.2%
9 | 40% | 10.3% | 39.4% | 12.8% | 41.7% | 23.8% | 38.0% | 28.3% | 38.7%
Others | 0% | 41.8% | 5.0% | 29.3% | 2.8% | 9.4% | 4.5% | 6.9% | 3.1%
 
Table 7: Full version of Table 3, Part II. Digit distribution estimated by SpinalNet, grouped by training sample size; the Fusion columns give the breakdown for ScoreFusion. The “Others” category is the fraction of samples that more closely resemble a digit other than 7 or 9.
Digit | True | $2^{10}$ Baseline | $2^{10}$ Fusion | $2^{12}$ Baseline | $2^{12}$ Fusion | $2^{14}$ Baseline | $2^{14}$ Fusion
7 | 60% | 65.5% | 55.6% | 66.7% | 59.8% | 67.4% | 59.7%
9 | 40% | 26.7% | 39.8% | 27.9% | 36.7% | 29.0% | 37.4%
Others | 0% | 7.8% | 3.6% | 5.4% | 3.5% | 3.6% | 2.9%
 
Table 8: Optimal weights $\boldsymbol{\lambda}^{*}$ of the ScoreFusion models whose test NLL losses are reported in Table 2. Each column is the weight vector parameterizing the ScoreFusion model trained on $2^{j}$ samples.
$\boldsymbol{\lambda}_{i}$ | $2^{6}$ | $2^{7}$ | $2^{8}$ | $2^{9}$ | $2^{10}$ | $2^{12}$ | $2^{14}$
$i=1$ | 0.199 | 0.187 | 0.182 | 0.181 | 0.167 | 0.183 | 0.176
$i=2$ | 0.305 | 0.326 | 0.328 | 0.319 | 0.345 | 0.311 | 0.310
$i=3$ | 0.279 | 0.267 | 0.284 | 0.285 | 0.319 | 0.294 | 0.295
$i=4$ | 0.217 | 0.220 | 0.206 | 0.216 | 0.170 | 0.213 | 0.220