Generative Modeling with Phase Stochastic Bridges
Abstract
We introduce a novel generative modeling framework grounded in phase space dynamics, taking inspiration from the principles underlying Critically damped Langevin Dynamics and Bridge Matching. Leveraging insights from Stochastic Optimal Control, we construct a more favorable path measure in the phase space that is highly advantageous for efficient sampling. A distinctive feature of our approach is the early-stage data prediction capability within the context of propagating generative Ordinary Differential Equations or Stochastic Differential Equations. This early prediction, enabled by the model’s unique structural characteristics, sets the stage for more efficient data generation, leveraging additional velocity information along the trajectory. This innovation has spurred the exploration of a novel avenue for mitigating sampling complexity by quickly converging to realistic data samples. Our model yields comparable results in image generation and notably outperforms baseline methods, particularly when faced with a limited Number of Function Evaluations. Furthermore, our approach rivals the performance of diffusion models equipped with efficient sampling techniques, underscoring its potential in the realm of generative modeling. Code is available at https://github.com/apple/ml-agm.
1 Introduction
Diffusion Models (DMs; Song et al., 2020a; Ho et al., 2020) are an instrumental technique in generative modeling. They formulate a particular Stochastic Differential Equation (SDE) that links the data distribution with a tractable prior distribution. A DM first diffuses data towards the prior distribution via a predetermined linear SDE. To reverse the process, a neural network approximates the score function, whose conditional counterpart is analytically available as a regression target. The approximated score is then used to perform the time reversal (Anderson, 1982; Haussmann & Pardoux, 1986) of this diffusion process, which generates data. Recently, Critically-damped Langevin Dynamics (CLD; Dockhorn et al., 2021) extended the SDE framework of DM into phase space (whereas DMs operate in position space) by introducing an auxiliary velocity variable, which follows tractable Gaussian distributions at the initial and terminal time steps. This augmentation produces smoother trajectories in position space, because stochasticity is injected only into the velocity space. The distinctive structure of CLD has been shown to improve empirical performance and sample efficiency. Despite this success, CLD still samples inefficiently because of unnecessary curvature in its dynamics (Fig. 1): the process has to converge to equilibrium in order to sample from the tractable prior.
The success of DM has also catalyzed recent advances in generative modeling, leading to Bridge Matching (BM; Peluchetti, 2021; Liu et al., 2022; 2023) and Flow Matching (FM; Lipman et al., 2022). These models build dynamic transport maps using SDEs or ODEs. Unlike DM, Bridge and Flow Matching do not rely on a forward diffusion process that reaches the prior distribution only asymptotically over an infinite time horizon. They are also more versatile: they can construct transport maps between two arbitrary distributions, drawing on insights from optimal transport (Pooladian et al., 2023), normalizing flows (Tong et al., 2023b), and optimal control (Liu et al., 2023).
In this paper, we focus on improving the sample efficiency of velocity-based generative models (e.g., CLD) using Stochastic Optimal Control (SOC) theory. Specifically, we leverage results on stochastic bridges of linear momentum systems (Chen & Georgiou, 2015) to construct a path measure that bridges the data and prior distributions. The resulting position and velocity trajectories are straighter than those of CLD (Fig. 1), which makes them more amenable to efficient sampling. In dynamical generative modeling more broadly (i.e., ODE/SDE-based generative models), a data point can often be written as a linear combination of scaled intermediate states of the dynamics and Gaussian noise. We re-establish this property for our model, which lets us estimate target data points from both position and velocity. DM and FM estimate the target data from position information alone, whereas our method also exploits velocity, which improves the accuracy of the estimate. Notably, our model can generate high-fidelity images at early time steps (Fig. 2). We also propose a sampling technique that achieves competitive results with a small Number of Function Evaluations (NFE), e.g., 5 to 10. Table 1 summarizes the design differences among these models. In summary, our paper makes the following contributions:
1. We propose Acceleration Generative Modeling (AGM), which is built on SOC theory and yields trajectories that are more favorable for efficient sampling than those of second-order momentum generative models such as CLD.
2. Owing to the structure of AGM, a realistic data point can be estimated at an early time step, a technique we call sampling-hop. It significantly reduces sampling complexity and offers a new way to accelerate sampling in generative modeling by exploiting additional information from the dynamics.
3. On image datasets, we achieve results competitive with DM approaches that use specifically designed fast sampling techniques, particularly in small-NFE settings.
2 Preliminary
Notation: Let $x_t \in \mathbb{R}^d$ and $v_t \in \mathbb{R}^d$ denote the $d$-dimensional position and velocity of a particle at time $t$. We denote the discretized time series as $t_0 < t_1 < \dots < t_N$. The Wiener process is denoted as $w_t$. The identity matrix is denoted as $I$. We define $\Sigma_t$ as the covariance matrix of $x_t$ and $v_t$ at time step $t$.
2.1 Dynamical Generative Modeling
The generative modeling approaches rooted in dynamical systems, including ODE and SDE, have garnered significant attention. Here, we present three noteworthy dynamical generative models: Diffusion Model (DM), Flow Matching (FM) and Bridge Matching (BM).
Models | DM/FM | CLD | AGM(ours) |
---|---|---|---|
Diffusion Model: In the framework of DM, given $x_0$ drawn from a data distribution $p_{\text{data}}$, the model constructs an SDE,
$$\mathrm{d}x_t = f(x_t, t)\,\mathrm{d}t + g(t)\,\mathrm{d}w_t, \qquad x_0 \sim p_{\text{data}}, \tag{1}$$
whose terminal distribution at $t = 1$ approximately reaches a Gaussian, i.e. $x_1 \sim \mathcal{N}(0, I)$. This is achieved by carefully choosing the diffusion coefficient $g(t)$ and the base drift $f(x_t, t)$. The time reversal (Anderson, 1982) of (1) is another SDE:
$$\mathrm{d}x_t = \left[f(x_t, t) - g^2(t)\,\nabla_x \log p_t(x_t)\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{w}_t, \tag{2}$$
where $\bar{w}_t$ is a Wiener process running backward in time, $p_t$ is the marginal density of (1) at time $t$, and $\nabla_x \log p_t$ is known as the score function. SDE (2) is the time reversal of (1) in the sense that it induces the same path measure, so the two SDEs share identical marginals over time. In practice, $x_t$ can be sampled analytically given $x_0$ and $t$. We can therefore train a neural network to learn the score function by regressing the scaled conditional (Stein) score $\nabla_{x_t} \log p(x_t \mid x_0)$. We then plug the learned score into (2) and integrate from the prior to generate data that follows the target distribution. Moreover, (2) has a corresponding ODE whose marginals coincide with those of (1) and (2), although its path measure differs:
$$\mathrm{d}x_t = \left[f(x_t, t) - \tfrac{1}{2}\,g^2(t)\,\nabla_x \log p_t(x_t)\right]\mathrm{d}t, \tag{3}$$
which motivates the popular samplers introduced in (Zhang & Chen, 2022; Zhang et al., 2022; Bao et al., 2022) for solving the ODE (3) efficiently.
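The claim that the ODE (3) shares the marginals of (1) can be checked on a one-dimensional Gaussian data distribution, because the score of every marginal is then available in closed form. The sketch below makes several assumptions of its own. It uses a variance-preserving forward SDE with $f(x, t) = -\tfrac{1}{2}\beta x$ and $g(t) = \sqrt{\beta}$ for a constant $\beta$, which is an illustrative choice and not a setting used in this paper. It integrates (3) backward from $t = 1$ to $t = 0$ with a fixed-step fourth-order Runge-Kutta scheme, then checks that samples from the $t = 1$ marginal are mapped back to the data mean and standard deviation. In practice a learned network replaces the exact score.

```python
import numpy as np

BETA = 8.0           # constant noise rate of the forward SDE (illustrative choice)
MU, SIG = 2.0, 0.5   # 1-D Gaussian data distribution N(MU, SIG**2)

def marginal(t):
    """Mean and variance of x_t for dx = -0.5*BETA*x dt + sqrt(BETA) dw, x_0 ~ N(MU, SIG**2)."""
    decay = np.exp(-BETA * t)
    return MU * np.sqrt(decay), SIG**2 * decay + (1.0 - decay)

def score(x, t):
    """Exact score of the Gaussian marginal p_t."""
    mean, var = marginal(t)
    return -(x - mean) / var

def pf_drift(x, t):
    """Drift of the probability-flow ODE (3): f - 0.5 * g^2 * score."""
    return -0.5 * BETA * x - 0.5 * BETA * score(x, t)

def integrate_backward(x_start, n_steps=200):
    """Classical RK4 integration of (3) from t = 1 down to t = 0."""
    h = -1.0 / n_steps
    x, t = x_start.copy(), 1.0
    for _ in range(n_steps):
        k1 = pf_drift(x, t)
        k2 = pf_drift(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = pf_drift(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = pf_drift(x + h * k3, t + h)
        x = x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        t += h
    return x

rng = np.random.default_rng(0)
mean_1, var_1 = marginal(1.0)
x_noise = mean_1 + np.sqrt(var_1) * rng.standard_normal(100_000)
x_data = integrate_backward(x_noise)
print(x_data.mean(), x_data.std())  # approximately MU and SIG
```

Because the ODE is deterministic, this check only relies on the marginals matching. It says nothing about path measures, which is the distinction between (2) and (3) noted above.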
Bridge Matching and Flow Matching: An alternative to time-reversing a forward noising process is to "build bridges" between two distinct distributions $p_0$ and $p_1$. This approach learns a mimicking diffusion process, commonly referred to as bridge matching (Peluchetti, 2021; Shi et al., 2022). Here we consider the SDE
$$\mathrm{d}x_t = u_t(x_t \mid x_1)\,\mathrm{d}t + g(t)\,\mathrm{d}w_t, \qquad x_0 \sim p_0,\ x_1 \sim p_1, \tag{4}$$
which is pinned down at an initial point $x_0$ and a terminal point $x_1$, sampled independently from $p_0$ and $p_1$. In the literature this construction is commonly known as the reciprocal projection (Shi et al., 2023; Peluchetti, 2023; Liu et al., 2022; Léonard et al., 2014). The drift $u_t$ must be designed carefully to obtain such an SDE. A widely adopted choice is $u_t(x_t \mid x_1) = \frac{x_1 - x_t}{1 - t}$, which induces the well-known Brownian bridge (Liu et al., 2023; Somnath et al., 2023). As in DM, the dynamics are linear, so this drift can be estimated efficiently. A neural network $u_t^\theta$ with weights $\theta$ is regressed onto $u_t(x_t \mid x_1)$ given $x_0$ and $x_1$. As extensively discussed in previous studies (Liu et al., 2023; Shi et al., 2022), this bridge matching framework reduces to FM (Lipman et al., 2022) as the diffusion coefficient $g$ tends to zero.
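For a constant diffusion coefficient $g(t) = \sigma$, the Brownian bridge pinned at $x_0$ and $x_1$ has the closed-form marginal $x_t = (1-t)\,x_0 + t\,x_1 + \sigma\sqrt{t(1-t)}\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The NumPy sketch below samples $x_t$ and forms the drift target $(x_1 - x_t)/(1-t)$ used for regression. The function names and the constant-$\sigma$ assumption are ours. With $\sigma = 0$ the target collapses to $x_1 - x_0$, the straight-line velocity regressed in FM.

```python
import numpy as np

def sample_bridge(x0, x1, t, sigma, rng):
    """Sample x_t from a Brownian bridge between x0 (t=0) and x1 (t=1) with constant diffusion sigma."""
    eps = rng.standard_normal(x0.shape)
    mean = (1.0 - t) * x0 + t * x1
    return mean + sigma * np.sqrt(t * (1.0 - t)) * eps

def bridge_drift_target(xt, x1, t):
    """Regression target u_t(x_t | x_1) = (x_1 - x_t) / (1 - t), defined for t < 1."""
    return (x1 - xt) / (1.0 - t)

rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)              # e.g. a sample from p_0
x1 = np.array([1.0, -1.0, 0.5, 2.0])     # e.g. a sample from p_1
t = 0.3
xt = sample_bridge(x0, x1, t, sigma=0.0, rng=rng)
target = bridge_drift_target(xt, x1, t)
# With sigma = 0 the target equals the straight-line (flow matching) velocity.
print(np.allclose(target, x1 - x0))  # True
```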
Remark 1.
Constraining a stochastic process to specific initial and terminal conditions is a well-established setup in SOC. For a gentle introduction to its connection with the Brownian bridge and the Schrödinger bridge, see Appendix C. From this perspective, one can derive the Brownian bridge, as elaborated in Appendix D.1. The SOC framework is the foundation on which we develop our algorithm.
3 Acceleration Generative Model
We apply SOC to address the curved trajectories of the momentum dynamics induced by CLD (Dockhorn et al., 2021). Flow matching, diffusion modeling, and bridge matching all construct an estimate of the target data point, denoted $x_1$, from the intermediate state of the dynamics, $x_t$. Our additional objective is to reach a plausible estimate of $x_1$ sooner by incorporating further dynamics-related information, such as velocity, which shortens the required time integration.
In this section, we introduce the proposed method, the Acceleration Generative Model (AGM), which is rooted in SOC theory. Building upon Chen & Georgiou (2015), we extend the framework with a time-varying diffusion coefficient and arbitrary boundary conditions, and arrive at an analytical solution suited to generative modeling. We show that it straightens the trajectories of CLD and that it can accurately estimate the target data at an early time step, which enables fast sampling.
Following the BM approach, we need a trajectory that bridges two data points sampled from the prior and the data distribution, respectively. Ideally, the intermediate trajectory should be smooth and close to linear, so that the dynamical system is easy to simulate. To achieve this, and to improve the estimate of the data point by incorporating velocity, we cast the problem as an SOC problem in phase space:
Definition 2 (Stochastic Bridge problem of linear momentum system (Chen & Georgiou, 2015)).
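$$
\begin{aligned}
\min_{a}\quad & \mathbb{E}\left[\int_{t_0}^{1} \|a(m_t, t)\|_2^2\,\mathrm{d}t + (m_1 - \bar{m}_1)^\top R\,(m_1 - \bar{m}_1)\right] \\
\text{s.t.}\quad & \mathrm{d}m_t = \begin{bmatrix} v_t \\ a(m_t, t) \end{bmatrix}\mathrm{d}t + \begin{bmatrix} 0 \\ g(t) \end{bmatrix}\mathrm{d}w_t, \qquad m_{t_0} = \begin{bmatrix} x_{t_0} \\ v_{t_0} \end{bmatrix}, \quad \bar{m}_1 = \begin{bmatrix} x_1 \\ v_1 \end{bmatrix},
\end{aligned}
\tag{5}
$$
where $m_t := [x_t, v_t]^\top$ denotes the phase-space state.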
Here, the matrix $R$ is the terminal cost matrix. It measures how close the propagated state $m_1$ is to the ground truth $\bar{m}_1$ at the terminal time $t = 1$. As $R$ approaches positive infinity, the trajectory converges to $\bar{m}_1$, and the problem becomes a constrained one in which the system is pinned at two predetermined boundaries, $m_{t_0}$ and $\bar{m}_1$. This matches the principle of constructing a feasible bridge advocated by BM. This interpolation is a natural extension (Chen & Georgiou, 2015) of the well-established Brownian bridge (Revuz & Yor, 2013), which has been used in trajectory inference (Somnath et al., 2023; Tong et al., 2023a) and image inpainting (Liu et al., 2023); its connection with diffusion models is discussed in Liu et al. (2023). The problem does not prescribe the target velocity $v_1$, which leaves flexibility in the design space. As the terminal velocity, we choose the linear interpolation between the intermediate point and the target point, $v_1 = \frac{x_1 - x_t}{1 - t}$, which is also the optimal control in the original (position) space (see Appendix D.1). We make this choice because it produces straight trajectories. Conceptually, the acceleration keeps steering the dynamics towards the linear interpolation of the two data points, which mitigates the impact of the injected stochasticity. Unlike previous bridge matching frameworks, our velocity boundary condition varies over time, since it depends on the state $x_t$ and on $x_1$. The velocity variable serves only as an auxiliary component for straightening the trajectories. The solution of this SOC problem is as follows.
Proposition 3 (Phase Space Brownian Bridge).
When $R \to \infty$, the solution of optimization problem (5) is,
(6) |
Proof.
See Appendix D.2. ∎
Remark 4.
The scalar coefficient in (6) is the second diagonal component of the matrix that solves the Lyapunov equation (see Lemma 9), and it implicitly encodes the optimality of the control. Its value depends on the uncontrolled dynamics, i.e., (5) with the control $a$ set to the zero vector, and it changes whenever the uncontrolled dynamics change.
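The noise-free special case of (5) ($g \equiv 0$, hard terminal constraint) gives some intuition for the structure of (6). For the double integrator $\ddot{x} = a$, the minimum-energy control that moves $(x_t, v_t)$ to $(x_1, v_1)$ over the remaining horizon $\tau = 1 - t$ starts at $a = 6(x_1 - x_t - \tau v_t)/\tau^2 - 2(v_1 - v_t)/\tau$. Substituting the terminal velocity $v_1 = (x_1 - x_t)/\tau$ chosen above gives $a = \frac{4}{\tau}\left(\frac{x_1 - x_t}{\tau} - v_t\right)$. This is our own simplified derivation, and it is not claimed to reproduce the exact form or notation of (6). The sketch below applies this feedback law with forward Euler and compares the result against the closed-form closed-loop solution $x_1 - x(t) = A\tau + B\tau^4$. The comparison shows that the position is driven to $x_1$ as $t \to 1$.

```python
def optimal_acceleration(x, v, x1, t):
    """Noise-free min-energy acceleration toward x1 with terminal velocity (x1 - x) / (1 - t)."""
    tau = 1.0 - t
    return 4.0 / tau * ((x1 - x) / tau - v)

def simulate(x0, v0, x1, t_end=0.999, n_steps=100_000):
    """Forward-Euler rollout of dx = v dt, dv = a dt under the feedback law above."""
    h = t_end / n_steps
    x, v, t = x0, v0, 0.0
    for _ in range(n_steps):
        a = optimal_acceleration(x, v, x1, t)
        x, v = x + h * v, v + h * a
        t += h
    return x, v

def closed_form_gap(x0, v0, x1, t):
    """Exact x1 - x(t) of the closed loop: A*tau + B*tau**4 with tau = 1 - t."""
    e0 = x1 - x0
    A, B = (4.0 * e0 - v0) / 3.0, (v0 - e0) / 3.0
    tau = 1.0 - t
    return A * tau + B * tau**4

x_end, v_end = simulate(x0=-1.0, v0=0.5, x1=2.0)
gap = 2.0 - x_end
print(gap, closed_form_gap(-1.0, 0.5, 2.0, 0.999))  # the two values agree closely
```

The remaining gap shrinks linearly in $\tau$, so the position reaches $x_1$ at $t = 1$ while the velocity stays finite. For a single integrator, the analogous minimum-energy control is the Brownian bridge drift $(x_1 - x_t)/(1 - t)$.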
3.1 Training
Plugging the optimal control (6) back into the dynamics (5) gives the desired SDE. As suggested by Song et al. (2020b) and Dockhorn et al. (2021), this SDE has a corresponding probability flow ODE with the same marginals over time. Noise enters only through the velocity, so the ODE drift contains an additional score term $-\tfrac{1}{2}\,g^2(t)\,\nabla_{v_t} \log p_t(m_t)$ in the velocity channel. We summarize the force terms of the SDE and the ODE as:
(7) | ||||
Henceforth, we call the dynamics of the bridge matching SDE AGM-SDE, and its ODE counterpart AGM-ODE. Because the system is linear, both the intermediate state and the closed-form score term are analytically available. In particular, once the boundary conditions are fixed, the mean $\mu_t$ and covariance $\Sigma_t$ of the intermediate marginal can be computed analytically, as outlined in Särkkä & Solin (2019); see Appendix D.3 for details. To sample from this multivariate Gaussian, we take the Cholesky decomposition of the covariance matrix and reparameterize $m_t$ as:
$$m_t = \mu_t + L_t\,\epsilon, \tag{8}$$
where $m_t = [x_t, v_t]^\top$, $\Sigma_t = L_t L_t^\top$, and $\epsilon \sim \mathcal{N}(0, I_{2d})$.
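Suppose the phase-space covariance acts identically and independently on each of the $d$ coordinates, as is common for such momentum systems; this is an assumption of the sketch. Then $L_t$ in (8) reduces to a $2\times 2$ lower-triangular factor per coordinate, which has a closed form. The sketch below uses hypothetical covariance entries `sxx`, `sxv`, `svv` in place of the expressions from Appendix D.3. It checks the factor against `numpy.linalg.cholesky` and then draws reparameterized samples.

```python
import numpy as np

def cholesky_2x2(sxx, sxv, svv):
    """Entries of lower-triangular L with L @ L.T == [[sxx, sxv], [sxv, svv]] (positive definite input)."""
    lxx = np.sqrt(sxx)
    lxv = sxv / lxx
    lvv = np.sqrt(svv - lxv**2)
    return lxx, lxv, lvv

def sample_phase_state(mu_x, mu_v, sxx, sxv, svv, rng):
    """Reparameterized sample m_t = mu_t + L_t @ eps, applied coordinate-wise."""
    lxx, lxv, lvv = cholesky_2x2(sxx, sxv, svv)
    eps_x = rng.standard_normal(mu_x.shape)
    eps_v = rng.standard_normal(mu_v.shape)
    x = mu_x + lxx * eps_x
    v = mu_v + lxv * eps_x + lvv * eps_v
    return x, v

sxx, sxv, svv = 0.8, -0.3, 1.5   # hypothetical covariance entries at some time t
lxx, lxv, lvv = cholesky_2x2(sxx, sxv, svv)
L_ref = np.linalg.cholesky(np.array([[sxx, sxv], [sxv, svv]]))
print(np.allclose(L_ref, [[lxx, 0.0], [lxv, lvv]]))  # True

rng = np.random.default_rng(0)
x, v = sample_phase_state(np.zeros(200_000), np.zeros(200_000), sxx, sxv, svv, rng)
print(np.cov(np.stack([x, v])))  # close to [[0.8, -0.3], [-0.3, 1.5]]
```

Because $x_t$ and $v_t$ share the noise `eps_x`, the reparameterization keeps the sample differentiable with respect to $\mu_t$ and $L_t$. The position-velocity correlation is also preserved.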
Parameterization: The force term can be written as a combination of the data point and Gaussian noise. Specifically,
(9) |
We express the force term as the network output scaled by a time-dependent normalizer. The normalizer rescales the output of the network so that its variance is normalized to unity. For the detailed formulation of the normalizer, please refer to Appendix D.8. As in the BM approach, the objective for regressing the force term is:
(10) |
where the time-dependent weight reweights the objective across the time horizon. We defer the derivation of the normalizer and the remaining coefficients to Appendix D.
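Objective (10) is a time-weighted mean-squared error between the normalized network output and the target force. A minimal NumPy sketch follows. `network`, `normalizer` and `weight` are hypothetical callables that stand in for the trained model, the normalizer of Appendix D.8 and the reweighting, and the toy target force is made up for illustration.

```python
import numpy as np

def force_matching_loss(network, normalizer, weight, x, v, t, target_force):
    """Monte Carlo estimate of E[weight(t) * ||normalizer(t) * network(x, v, t) - F_t||^2].

    x, v, target_force: arrays of shape (batch, d); t: array of shape (batch,).
    network, normalizer, weight: hypothetical stand-ins for the model, the variance
    normalizer and the time reweighting.
    """
    pred = normalizer(t)[:, None] * network(x, v, t)
    sq_err = np.sum((pred - target_force) ** 2, axis=1)
    return float(np.mean(weight(t) * sq_err))

# Toy usage: an "oracle" network that already undoes the normalizer reaches zero loss.
rng = np.random.default_rng(0)
x, v = rng.standard_normal((8, 3)), rng.standard_normal((8, 3))
t = rng.uniform(0.0, 0.9, size=8)
target = -x - v                                     # hypothetical target force
norm = lambda s: 1.0 + s                            # hypothetical normalizer
oracle = lambda x, v, s: (-x - v) / norm(s)[:, None]
loss_oracle = force_matching_loss(oracle, norm, np.ones_like, x, v, t, target)
print(loss_oracle)  # essentially zero (floating-point round-off only)
```

In an actual training loop, the target would come from (7) and (9), and an automatic-differentiation framework would compute the gradient with respect to the network parameters.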
3.2 Sampling from AGM
Once the parameterized force term is trained, we generate samples by plugging it back into the dynamics (7) and simulating them. Any SDE or ODE sampler can propagate the learned system. Our choices of sampler for AGM-SDE and AGM-ODE are listed below.
Stochastic Sampler: Prior works mainly simulate the SDE with Euler-Maruyama (EM) (Kloeden et al., 1992) and related methods. For AGM-SDE we adopt the Symmetric Splitting Sampler (SSS) from Dockhorn et al. (2021), because it performs well on momentum systems.
Deterministic Sampler: This system is inherently underactuated. The force term is injected only into the velocity component, and the velocity in turn drives the position, which is the variable of primary interest in generative modeling. After the time horizon is discretized, the force applied at time step $t_n$ does not affect the position immediately; it only takes effect at the next time step $t_{n+1}$. Consequently, propagating the position at $t_n$ over a long interval $[t_n, t_{n+1}]$ with a velocity that has not yet been controlled is undesirable. This delay grows more harmful with the step size, which limits how far the NFE can be reduced during sampling. We therefore adopt an Exponential Integrator (EI) approach, as elaborated in Zhang & Chen (2022), which works well with our model empirically. The following example shows how AGM-ODE, combined with EI, injects the learned network into both the velocity and the position channels simultaneously:
(11) | ||||
Eq. 11 involves the transition matrix of our system and the multistep coefficients of the chosen order (Hochbruck & Ostermann, 2010). For a full derivation of these terms, please refer to Appendix D.9. Mapping the network output into both the position and velocity channels significantly mitigates the errors introduced by the discretization delay.
Sampling-hop: CLD (Dockhorn et al., 2021) estimates the score function w.r.t. velocity, which essentially corresponds to estimating a scaled noise term in our notation. This information alone is not sufficient to estimate the data point $x_1$, and additional knowledge is required. In our case, the training objective implicitly contains both the data point and the Gaussian noise (see Eq. 9), so $x_1$ can be recovered by Proposition 5. Remarkably, we observe that a network equipped with additional velocity information can estimate the target data point early in the trajectory, as illustrated in Fig. 2. This estimate integrates seamlessly into AGM-SDE and AGM-ODE, and we call it sampling-hop. Specifically,
Proposition 5 (Sampling-Hop).
Given the state, velocity and trained force term at time step $t$ in the sampling phase, the estimated data point $\hat{x}_1$ can be represented as
(12) |
for the AGM-SDE and AGM-ODE dynamics, respectively.
Proof.
See Appendix D.10. ∎
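To illustrate the AGM-SDE case, assume the force equals the noise-free acceleration $a = \frac{4}{\tau}\left(\frac{x_1 - x_t}{\tau} - v_t\right)$ with $\tau = 1 - t$. This is the double-integrator minimum-energy control with terminal velocity $(x_1 - x_t)/\tau$. The form is our assumption, and (12) remains the authoritative statement. Solving this relation for $x_1$ gives $\hat{x}_1 = x_t + \tau v_t + \tfrac{1}{4}\tau^2 F_t$, so a single force evaluation at an intermediate time already yields a data estimate. The hypothetical sketch below covers only this SDE-style form. The ODE variant also involves the score term and is not reproduced here.

```python
import numpy as np

def assumed_force(x, v, x1, t):
    """Assumed acceleration law a = 4/tau * ((x1 - x)/tau - v) with tau = 1 - t."""
    tau = 1.0 - t
    return 4.0 / tau * ((x1 - x) / tau - v)

def sampling_hop(x, v, force, t):
    """Invert the assumed law: x1_hat = x + tau * v + tau**2 * force / 4."""
    tau = 1.0 - t
    return x + tau * v + tau**2 * force / 4.0

rng = np.random.default_rng(0)
x1 = rng.standard_normal(5)
x, v, t = rng.standard_normal(5), rng.standard_normal(5), 0.4
x1_hat = sampling_hop(x, v, assumed_force(x, v, x1, t), t)
print(np.allclose(x1_hat, x1))  # True
```

With a learned force, the estimate inherits the network error scaled by $\tau^2/4$, and this factor shrinks as $t$ approaches 1.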
This property lets us spend the NFE budget within a truncated time interval that ends before the terminal time $t = 1$. Doing so reduces the discretization error while maintaining sample quality, and it underlies the efficient low-NFE sampling strategies presented later. Algorithm 1 and Algorithm 2 summarize the training and sampling procedures of our method, respectively.
Algorithm 1 (Training)
1: Input: data distribution
2: while not converged do
3:   Sample a time step $t$ and a data point $x_1$
4:   Compute the mean $\mu_t$ and covariance $\Sigma_t$ (Appendix D.3)
5:   Sample $m_t \sim \mathcal{N}(\mu_t, \Sigma_t)$ (Eq. 8)
6:   Compute the target force (Eq. 7) using the optimal acceleration (Eq. 9)
7:   Compute the loss (Eq. 10)
8:   Take a gradient descent step with respect to the network parameters
9: end while
Algorithm 2 (Sampling)
1: Input: trained network, discretized time steps $[t_0, \dots, t_N]$; choose the sampler from [SSS (SDE), EI (ODE)]; choose the prior mean $\mu_0$ and covariance $\Sigma_0$
2: Sample the initial state $m_{t_0} \sim \mathcal{N}(\mu_0, \Sigma_0)$
3: for each time step $t_n$ do
4:   Estimate the force term at $(m_{t_n}, t_n)$
5:   Propagate the state to $t_{n+1}$ with the chosen sampler
6:   Reconstruct $\hat{x}_1$ using Proposition 5
7: end for
8: Return $\hat{x}_1$
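The exponential-integrator option of Algorithm 2 can be illustrated on the simplest case. Assume the uncontrolled part of the dynamics is the double integrator $\mathrm{d}x = v\,\mathrm{d}t$, $\mathrm{d}v = 0$, with transition matrix $\Phi(t+h, t) = \begin{bmatrix}1 & h\\ 0 & 1\end{bmatrix}$. Also freeze the force over each step, a zeroth-order simplification of the multistep scheme in (11). Exact integration then adds $\tfrac{1}{2}h^2 F$ to the position and $hF$ to the velocity. An explicit Euler step, by contrast, lets the current force reach the position only one step later, which is the delay discussed for the deterministic sampler.

```python
import numpy as np

def ei_step(x, v, force, h):
    """Zeroth-order exponential-integrator step for dx = v dt, dv = force dt.

    x, v, force: 1-D arrays of equal shape (one entry per coordinate).
    """
    phi = np.array([[1.0, h], [0.0, 1.0]])      # transition matrix Phi(t + h, t)
    b = np.array([0.5 * h**2, h])               # integral of Phi(t + h, s) @ [0, 1]^T over the step
    state = phi @ np.stack([x, v]) + np.outer(b, force)
    return state[0], state[1]

def euler_step(x, v, force, h):
    """Explicit Euler: the current force reaches the position only at the next step."""
    return x + h * v, v + h * force

# With a constant force, one EI step of any size matches the exact solution.
x0, v0, F = np.array([0.0, 1.0]), np.array([1.0, -2.0]), np.array([2.0, 0.5])
x_ei, v_ei = ei_step(x0, v0, F, 1.0)
x_exact = x0 + v0 + 0.5 * F
print(np.allclose(x_ei, x_exact))                       # True
print(np.allclose(euler_step(x0, v0, F, 1.0)[0], x_exact))  # False
```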
4 Experimental Results
Architectures and Hyperparameters: We parameterize the force network with a modified NCSN++ model as provided in Karras et al. (2022). The network takes six input channels, covering both the position and velocity variables, instead of the standard three channels used for CIFAR-10 (Krizhevsky et al., 2009), AFHQv2 (Choi et al., 2020) and ImageNet (Deng et al., 2009). This increases the number of network parameters only negligibly. For the comparison with CLD on the toy datasets, we adopt the same ResNet-based architecture used in CLD. Throughout all experiments, we use a monotonically decreasing diffusion coefficient $g(t)$. For the detailed experimental setup, please refer to Appendix E.
Evaluation: We measure performance and sampling speed with the Fréchet Inception Distance (FID; Heusel et al., 2017) and the Number of Function Evaluations (NFE). For FID, we use the reference statistics of all datasets provided by EDM (Karras et al., 2022) and evaluate on 50k generated samples. We also re-evaluate the FID of CLD and EDM with the same reference statistics so that the comparisons are consistent. All other reported values are taken directly from the respective papers.
Selection of $\Sigma_0$: The initial covariance $\Sigma_0$ directly influences the path measure of the trajectory. In our case, $\Sigma_0$ is controlled by a single scalar hyperparameter. We observe that trajectories tend to exhibit pronounced curvature under specific conditions, namely when the absolute value of the position becomes large. This behavior is particularly noticeable for images, whose values range from -1 to 1. Favorable uncontrolled dynamics can lead to better controlled dynamics, so we design $\Sigma_0$ such that the marginal distribution of the uncontrolled dynamics at $t = 1$ covers the range of image data values while avoiding the curvature-inducing condition above. The marginal of the uncontrolled dynamics can be expressed in closed form through the transition matrix. Figure 3 shows the standard deviation of the position at $t = 1$ for various values of the hyperparameter. Based on these empirical observations, we use a single value of this hyperparameter for all experiments, chosen because it covers the data range. The controlled dynamics (Eq. 7) are then constructed on top of these uncontrolled dynamics.
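The closed-form marginal of the uncontrolled dynamics can be sketched under two assumptions of this sketch: the dynamics are $\mathrm{d}x = v\,\mathrm{d}t$, $\mathrm{d}v = g(t)\,\mathrm{d}w$, and they apply per coordinate. Then $\Phi(1, s) = \begin{bmatrix}1 & 1-s\\ 0 & 1\end{bmatrix}$, the mean is $\Phi(1,0)\,\mu_0$, and $\Sigma_1 = \Phi(1,0)\,\Sigma_0\,\Phi(1,0)^\top + \int_0^1 \Phi(1,s)\,b(s)\,b(s)^\top\,\Phi(1,s)^\top\,\mathrm{d}s$ with $b(s) = [0, g(s)]^\top$. The sketch below evaluates the integral with the trapezoidal rule. The initial covariance entries and the schedule $g(s) = 3(1-s)$ are illustrative assumptions, not the settings used in the experiments.

```python
import numpy as np

def uncontrolled_cov_at_one(sigma0, g, n_quad=2001):
    """Covariance of (x_1, v_1) for dx = v dt, dv = g(t) dw, started from covariance sigma0 at t = 0.

    g must map an array of times to an array of diffusion values of the same shape.
    """
    phi_10 = np.array([[1.0, 1.0], [0.0, 1.0]])
    cov = phi_10 @ sigma0 @ phi_10.T
    s = np.linspace(0.0, 1.0, n_quad)
    col = np.stack([(1.0 - s) * g(s), g(s)])        # Phi(1, s) @ b(s), shape (2, n_quad)
    integrand = col[:, None, :] * col[None, :, :]    # outer products, shape (2, 2, n_quad)
    w = np.full(n_quad, s[1] - s[0])                 # trapezoidal-rule weights
    w[0] *= 0.5
    w[-1] *= 0.5
    return cov + (integrand * w).sum(axis=-1)

sigma0 = np.array([[1.0, -0.5], [-0.5, 1.0]])        # hypothetical initial covariance
g = lambda s: 3.0 * (1.0 - s)                        # hypothetical decreasing schedule
cov1 = uncontrolled_cov_at_one(sigma0, g)
print(np.sqrt(cov1[0, 0]))                           # position std at t = 1
```

Sweeping the off-diagonal entry of `sigma0` reproduces the kind of sensitivity study described for Figure 3. A negative position-velocity correlation reduces the spread of the position at $t = 1$.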
NFE | CLD-SDE (FID) | AGM-SDE (FID) |
---|---|---|
20 | 100 | 7.9 |
50 | 19.93 | 3.21 |
150 | 2.99 | 2.68 |
1000 | 2.44 | 2.46 |
Stochastic Sampling: We first highlight the advantages of AGM-SDE over CLD. On toy datasets, our model outperforms CLD while using significantly fewer NFE. We evaluate on the multi-modal Mixture of Gaussians and Multi-Swiss-Roll datasets. The toy-data results in Fig. 8 show that AGM-SDE generates samples that closely match the ground truth with roughly one order of magnitude fewer NFE than CLD. AGM-SDE also outperforms CLD on CIFAR-10 image generation, especially with limited NFE, as shown in Table 2.
Deterministic Sampling: We validate our algorithm on high-dimensional image generation with a deterministic sampler. Appendix H provides uncurated samples from CIFAR-10, AFHQv2 and ImageNet-64 at varying NFE. For the quantitative evaluation, Tables 3 and 4 summarize the FID and the NFE used for sampling on CIFAR-10 and ImageNet-64. Notably, AGM-ODE achieves an FID of 2.46 with 50 NFE on CIFAR-10 and 10.55 with 20 NFE on unconditional ImageNet-64, which is comparable to existing dynamical generative models.
We highlight the effectiveness of sampling-hop relative to the baselines, especially under a constrained NFE budget, and validate it on CIFAR-10 and AFHQv2. Fig. 4 shows that AGM-ODE generates plausible images even with a very small NFE. On AFHQv2 it outperforms EDM (Karras et al., 2022) both visually and numerically when the NFE is extremely small (NFE ≤ 15). Table 5 compares AGM-ODE with other fast sampling algorithms built upon DM on CIFAR-10, where it is competitive. Notably, AGM-ODE outperforms the CLD baseline that uses the same EI sampler by a large margin. We hypothesize that the improvement comes from the straightened trajectory, which is easier for the ODE solver to integrate.
Conditional Generation: AGM can generate conditional samples from an unconditional model (Fig. 5) when conditional information is incorporated into the prior velocity variable. Instead of a randomly sampled initial velocity, we use a linear combination of the random velocity and a desired velocity computed from the conditioning data, with a scalar mixing coefficient. Fig. 5 shows that AGM generates conditional data without augmentation or additional fine-tuning. This property extends to the inpainting task as well; details can be found in Appendix F.
Sampler | Model Name | NFE | FID |
---|---|---|---|
ODE | EDM (Karras et al., 2022) | 35 | 1.84 |
ODE | CLD+EI (Zhang et al., 2022) | 50 | 2.26 |
ODE | FM-OT (Lipman et al., 2022) | 142 | 6.35 |
ODE | AGM-ODE (ours) | 50 | 2.46 |
SDE | VP (Song et al., 2020b) | 1000 | 2.66 |
SDE | VE (Song et al., 2020b) | 1000 | 2.43 |
SDE | CLD (Dockhorn et al., 2021) | 1000 | 2.44 |
SDE | AGM-SDE (ours) | 1000 | 2.46 |
Dynamics Order | Model Name | FID (NFE = 5) | FID (NFE = 10) | FID (NFE = 20) |
---|---|---|---|---|
1st-order dynamics | EDM (Karras et al., 2022) | 100 | 15.78 | 2.23 |
1st-order dynamics | VP+EI (Zhang & Chen, 2022) | 15.37 | 4.17 | 3.03 |
1st-order dynamics | DDIM (Song et al., 2020a) | 26.91 | 11.14 | 3.50 |
1st-order dynamics | Analytic-DPM (Bao et al., 2022) | 51.47 | 14.06 | 6.74 |
2nd-order dynamics | CLD+EI (Zhang et al., 2022) | N/A | 13.41 | 3.39 |
2nd-order dynamics | AGM-ODE (ours) | 11.93 | 4.60 | 2.60 |
5 Conclusion and Limitation
In this paper, we introduce a novel Acceleration Generative Modeling (AGM) framework rooted in SOC theory. Within this framework, we design more favorable, straighter trajectories for the momentum system. Building on the intrinsic characteristics of the momentum system, we use the additional velocity information to accelerate sampling through the sampling-hop technique, which significantly reduces the time needed to reach accurate predictions of realistic data points. Our experiments on toy and image datasets in unconditional generative tasks show promising results for fast sampling.
However, it is essential to acknowledge that our approach’s performance lags behind state-of-the-art methods in scenarios with sufficient NFE. This observation suggests avenues for enhancing AGM performance. Such improvements could be achieved by enhancing the training quality through the adoption of techniques proposed in Karras et al. (2022) including data augmentation, fine-tuned noise scheduling, and network preconditioning, among others.
References
- Anderson (1982) Brian DO Anderson. Reverse-time diffusion equation models. Stochastic Processes and their Applications, 12(3):313–326, 1982.
- Bao et al. (2022) Fan Bao, Chongxuan Li, Jun Zhu, and Bo Zhang. Analytic-dpm: an analytic estimate of the optimal reverse variance in diffusion probabilistic models. arXiv preprint arXiv:2201.06503, 2022.
- Bryson (1975) Arthur Earl Bryson. Applied optimal control: optimization, estimation and control. CRC Press, 1975.
- Chen et al. (2023) Tianrong Chen, Guan-Horng Liu, Molei Tao, and Evangelos A Theodorou. Deep momentum multi-marginal schrödinger bridge. arXiv preprint arXiv:2303.01751, 2023.
- Chen & Georgiou (2015) Yongxin Chen and Tryphon Georgiou. Stochastic bridges of linear systems. IEEE Transactions on Automatic Control, 61(2):526–531, 2015.
- Choi et al. (2020) Yunjey Choi, Youngjung Uh, Jaejun Yoo, and Jung-Woo Ha. Stargan v2: Diverse image synthesis for multiple domains. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 8188–8197, 2020.
- De Bortoli et al. (2023) Valentin De Bortoli, Guan-Horng Liu, Tianrong Chen, Evangelos A Theodorou, and Weili Nie. Augmented bridge matching. arXiv preprint arXiv:2311.06978, 2023.
- Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp. 248–255. Ieee, 2009.
- Dhariwal & Nichol (2021) Prafulla Dhariwal and Alex Nichol. Diffusion models beat gans on image synthesis. arXiv preprint arXiv:2105.05233, 2021.
- Dockhorn et al. (2021) Tim Dockhorn, Arash Vahdat, and Karsten Kreis. Score-based generative modeling with critically-damped langevin diffusion. arXiv preprint arXiv:2112.07068, 2021.
- Haussmann & Pardoux (1986) Ulrich G Haussmann and Etienne Pardoux. Time reversal of diffusions. The Annals of Probability, pp. 1188–1205, 1986.
- Heng et al. (2021) Jeremy Heng, Valentin De Bortoli, Arnaud Doucet, and James Thornton. Simulating diffusion bridges with score matching. arXiv preprint arXiv:2111.07243, 2021.
- Heusel et al. (2017) Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in neural information processing systems, 30, 2017.
- Ho et al. (2020) Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. arXiv preprint arXiv:2006.11239, 2020.
- Hochbruck & Ostermann (2010) Marlis Hochbruck and Alexander Ostermann. Exponential integrators. Acta Numerica, 19:209–286, 2010.
- Inc. (2022) The MathWorks Inc. Matlab version: 9.13.0 (r2022b), 2022. URL https://www.mathworks.com.
- Kappen (2008) HJ Kappen. Stochastic optimal control theory. ICML, Helsinki, Radbound University, Nijmegen, Netherlands, 2008.
- Karras et al. (2022) Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. Elucidating the design space of diffusion-based generative models. Advances in Neural Information Processing Systems, 35:26565–26577, 2022.
- Kirk (2004) Donald E Kirk. Optimal control theory: an introduction. Courier Corporation, 2004.
- Kloeden et al. (1992) Peter E Kloeden and Eckhard Platen. Stochastic differential equations. Springer, 1992.
- Krizhevsky et al. (2009) Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
- Léonard et al. (2014) Christian Léonard, Sylvie Rœlly, and Jean-Claude Zambrini. Reciprocal processes. a measure-theoretical point of view. 2014.
- Lipman et al. (2022) Yaron Lipman, Ricky TQ Chen, Heli Ben-Hamu, Maximilian Nickel, and Matt Le. Flow matching for generative modeling. arXiv preprint arXiv:2210.02747, 2022.
- Liu et al. (2023) Guan-Horng Liu, Arash Vahdat, De-An Huang, Evangelos A Theodorou, Weili Nie, and Anima Anandkumar. I2sb: Image-to-image schrödinger bridge. arXiv preprint arXiv:2302.05872, 2023.
- Liu et al. (2022) Xingchao Liu, Lemeng Wu, Mao Ye, and Qiang Liu. Let us build bridges: Understanding and extending diffusion generative models. arXiv preprint arXiv:2208.14699, 2022.
- Loshchilov & Hutter (2017) Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.
- O’Connell (2003) Neil O’Connell. Conditioned random walks and the rsk correspondence. Journal of Physics A: Mathematical and General, 36(12):3049, 2003.
- Øksendal (2003) Bernt Øksendal. Stochastic differential equations. In Stochastic differential equations, pp. 65–84. Springer, 2003.
- Pandey et al. (2023) Kushagra Pandey, Maja Rudolph, and Stephan Mandt. Efficient integrators for diffusion generative models. arXiv preprint arXiv:2310.07894, 2023.
- Peluchetti (2021) Stefano Peluchetti. Non-denoising forward-time diffusions. 2021.
- Peluchetti (2023) Stefano Peluchetti. Diffusion bridge mixture transports, schrödinger bridge problems and generative modeling. arXiv preprint arXiv:2304.00917, 2023.
- Pooladian et al. (2023) Aram-Alexandre Pooladian, Heli Ben-Hamu, Carles Domingo-Enrich, Brandon Amos, Yaron Lipman, and Ricky Chen. Multisample flow matching: Straightening flows with minibatch couplings. arXiv preprint arXiv:2304.14772, 2023.
- Revuz & Yor (2013) Daniel Revuz and Marc Yor. Continuous martingales and Brownian motion, volume 293. Springer Science & Business Media, 2013.
- Särkkä & Solin (2019) Simo Särkkä and Arno Solin. Applied stochastic differential equations, volume 10. Cambridge University Press, 2019.
- Shi et al. (2022) Yuyang Shi, Valentin De Bortoli, George Deligiannidis, and Arnaud Doucet. Conditional simulation using diffusion Schrödinger bridges. In Uncertainty in Artificial Intelligence, pp. 1792–1802. PMLR, 2022.
- Shi et al. (2023) Yuyang Shi, Valentin De Bortoli, Andrew Campbell, and Arnaud Doucet. Diffusion Schrödinger bridge matching. arXiv preprint arXiv:2303.16852, 2023.
- Somnath et al. (2023) Vignesh Ram Somnath, Matteo Pariset, Ya-Ping Hsieh, Maria Rodriguez Martinez, Andreas Krause, and Charlotte Bunne. Aligned diffusion Schrödinger bridges. arXiv preprint arXiv:2302.11419, 2023.
- Song et al. (2020a) Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. arXiv preprint arXiv:2010.02502, 2020a.
- Song et al. (2020b) Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456, 2020b.
- Song et al. (2021) Yang Song, Conor Durkan, Iain Murray, and Stefano Ermon. Maximum likelihood training of score-based diffusion models. arXiv e-prints, pp. arXiv–2101, 2021.
- Stengel (1994) Robert F Stengel. Optimal control and estimation. Courier Corporation, 1994.
- Tong et al. (2023a) Alexander Tong, Nikolay Malkin, Kilian Fatras, Lazar Atanackovic, Yanlei Zhang, Guillaume Huguet, Guy Wolf, and Yoshua Bengio. Simulation-free Schrödinger bridges via score and flow matching. arXiv preprint arXiv:2307.03672, 2023a.
- Tong et al. (2023b) Alexander Tong, Nikolay Malkin, Guillaume Huguet, Yanlei Zhang, Jarrid Rector-Brooks, Kilian Fatras, Guy Wolf, and Yoshua Bengio. Improving and generalizing flow-based generative models with minibatch optimal transport. In ICML Workshop on New Frontiers in Learning, Control, and Dynamical Systems, 2023b.
- Yong & Zhou (1999) Jiongmin Yong and Xun Yu Zhou. Stochastic controls: Hamiltonian systems and HJB equations, volume 43. Springer Science & Business Media, 1999.
- Zhang & Chen (2022) Qinsheng Zhang and Yongxin Chen. Fast sampling of diffusion models with exponential integrator. arXiv preprint arXiv:2204.13902, 2022.
- Zhang et al. (2022) Qinsheng Zhang, Molei Tao, and Yongxin Chen. gDDIM: Generalized denoising diffusion implicit models. arXiv preprint arXiv:2206.05564, 2022.
- Zhang et al. (2023) Qinsheng Zhang, Jiaming Song, and Yongxin Chen. Improved order analysis and design of exponential integrator for diffusion models sampling. arXiv preprint arXiv:2308.02157, 2023.
Appendix A Supplementary Summary
Appendix B Assumptions
We use the following assumptions to construct the proposed method. These assumptions are adopted from the stochastic analysis of SGM (Song et al., 2021; Yong & Zhou, 1999; Anderson, 1982):
- (i) and with finite second-order moment.
- (ii) is a continuous function, and is uniformly lower-bounded w.r.t. .
- (iii) , we have Lipschitz and at most linear growth w.r.t. and .
Assumptions (i) and (ii) are standard conditions in stochastic analysis that ensure the existence and uniqueness of solutions to the SDEs; hence, they also appear in SGM analysis (Song et al., 2021).
Appendix C Stochastic Optimal Control (SOC) in the Wild
In this section, we provide a gentle introduction to Stochastic Optimal Control (SOC). Our work relies heavily on Chen & Georgiou (2015), in which some technical details are omitted. We first clarify several core derivations that may help a broader audience understand Chen & Georgiou (2015) and our work.
C.1 Linear Quadratic Stochastic Optimal Control
SOC has wide applications in finance, robotics, and manufacturing. Here we focus on Linear Quadratic SOC, which usually refers to the Linear Quadratic Regulator because the dynamics are linear and the objective function is quadratic (Bryson, 1975; Stengel, 1994). The problem is stated as:
(13)
In this formulation, denotes the state and is the control variable. Conceptually, the SOC problem aims to design a controller that drives the system from point to with minimum effort. For a first-order system, the control is the optimal vector field ; for a second-order system, the control is the optimal acceleration . The stochasticity introduced by the Wiener process prevents the system from converging precisely to the Dirac mass . To balance converging to against minimizing the overall control effort , a terminal cost is imposed.
One special case is . Intuitively, it means that the controlled dynamics must converge exactly to . However, the stochastic trajectory connecting and is not unique in this case. Under this constraint (pinned at and at the two boundaries), the SOC problem selects the solution with minimum control effort, which can be understood as a regularization of the trajectories; hence, the stochastic trajectory becomes unique while the controller is still regularized. Such a pinned-down SDE is also connected to the well-known Doob h-transform; readers unfamiliar with these topics may consult Heng et al. (2021) and O’Connell (2003).
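The effect of the terminal cost can be illustrated on a scalar toy problem of our own choosing, not the exact setup of eq.(13): dynamics dx = u dt with cost ∫ ½u² dt + ½r(x_1 − x̄)², where r and x̄ are illustrative names. Its Riccati solution is P(t) = r / (1 + r(1 − t)), the optimal feedback is u = −P(t)(x − x̄), and the terminal miss is (x_0 − x̄)/(1 + r). The sketch below checks this by Euler simulation; the miss shrinks as r grows, mirroring the pinned special case discussed above.

```python
import numpy as np


def terminal_miss(x0: float, x_target: float, r: float, n_steps: int = 10_000) -> float:
    """Euler-simulate dx = u dt under the LQR feedback u = -P(t) (x - x_target).

    Toy problem (assumed): minimize int 0.5 u^2 dt + 0.5 r (x_1 - x_target)^2.
    P(t) = r / (1 + r (1 - t)) solves dP/dt = P^2 with P(1) = r.
    Returns the terminal miss x_1 - x_target.
    """
    h = 1.0 / n_steps
    x = x0
    for k in range(n_steps):
        p = r / (1.0 + r * (1.0 - k * h))
        x = x - h * p * (x - x_target)
    return x - x_target


for r in (1.0, 10.0, 100.0):
    simulated = terminal_miss(x0=0.0, x_target=1.0, r=r)
    predicted = (0.0 - 1.0) / (1.0 + r)   # closed form (x0 - x_target) / (1 + r)
    print(f"r = {r:6.1f}: simulated {simulated:+.4f}, closed form {predicted:+.4f}")
```

A finite r trades terminal accuracy for control effort; only in the limit of an infinite terminal weight does the trajectory land exactly on the target.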
The classical procedure for solving the SOC problem is:
1. Write down the Hamilton–Jacobi–Bellman (HJB) PDE, which explicitly describes how the value function propagates over time.
2. Construct the Riccati/Lyapunov equation.
3. Solve the Riccati/Lyapunov equation and obtain the optimal control.
C.2 Value Function, Hamilton–Jacobi–Bellman Equation, and Riccati Equation
We adopt the classical SOC notation for the value function. Specifically, a subscript on the value function denotes its partial derivative. For example, , and denote the first-order derivative of w.r.t. time , the first-order derivative w.r.t. state , and the second-order derivative of w.r.t. , respectively. We first define the value function as:
and the dynamics are,
Applying Bellman’s principle of optimality to the value function, one gets:
One then obtains:
The optimal control is obtained by
Plugging it back, one obtains the HJB PDE:
We assume that there exists a matrix such that . By matching terms of the same power in the HJB equation, one can write:
(14)
with boundary condition:
(15)
Since , one arrives at the Riccati equation:
(16)
Recalling that the optimal solution is and , the optimal control can be expressed in terms of the solution of the Riccati equation: .
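Steps 2 and 3 of the procedure can be sketched numerically. We assume the standard LQR form with running cost ½‖u‖², terminal cost ½xᵀRx, dynamics dx = (Ax + Bu) dt + G dW, and value function ½xᵀPx plus a state-independent term contributed by the noise. Under these conventions, which may differ in sign or scaling from eq.(16), the Riccati equation is −dP/dt = AᵀP + PA − PBBᵀP with P(1) = R, and u* = −BᵀPx. G does not enter P (certainty equivalence). The helper name solve_riccati is ours; the sketch integrates backward with RK4 and checks the scalar case against P(t) = r/(1 + r(1 − t)).

```python
import numpy as np


def solve_riccati(A, B, R, n_steps=2000):
    """RK4 backward integration of -dP/dt = A^T P + P A - P B B^T P, P(1) = R.

    Returns the ascending time grid ts and Ps with Ps[i] approximating P(ts[i]).
    """
    def dP_dt(P):
        return -(A.T @ P + P @ A - P @ B @ B.T @ P)

    h = 1.0 / n_steps
    ts = np.linspace(0.0, 1.0, n_steps + 1)
    Ps = np.empty((n_steps + 1,) + R.shape)
    P = R.astype(float)
    Ps[-1] = P
    for i in range(n_steps, 0, -1):
        # RK4 step of size -h (backward in time); the right-hand side does not depend on t
        k1 = dP_dt(P)
        k2 = dP_dt(P - 0.5 * h * k1)
        k3 = dP_dt(P - 0.5 * h * k2)
        k4 = dP_dt(P - h * k3)
        P = P - (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        Ps[i - 1] = P
    return ts, Ps


# Scalar check: A = 0, B = 1, R = r gives P(t) = r / (1 + r (1 - t)).
r = 10.0
ts, Ps = solve_riccati(np.zeros((1, 1)), np.ones((1, 1)), np.array([[r]]))
print(np.max(np.abs(Ps[:, 0, 0] - r / (1.0 + r * (1.0 - ts)))))

# Double integrator (dx = v dt, dv = u dt): feedback gain K(t) = B^T P(t), u* = -K(t) z.
A = np.array([[0.0, 1.0], [0.0, 0.0]])
B = np.array([[0.0], [1.0]])
ts, Ps = solve_riccati(A, B, 100.0 * np.eye(2))
print(B.T @ Ps[0])
```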
C.3 Riccati Equation and Lyapunov Equation
Here we establish the connection between the Riccati equation and the Lyapunov equation in the current setup.
Lemma 6.
Define , in which is the solution of the Riccati equation (eq.16). Then it solves the Lyapunov equation:
(17)
For notational consistency, we denote the elements of matrix as,
Proof.
By plugging into the Lyapunov equation , one obtains:
∎
By Lemma.6, the optimal control can also be expressed through the solution of the Lyapunov equation: , which is the optimal control term used in Chen & Georgiou (2015) after adopting their notation; it is also the same as the optimal control term used in Lemma.12 without base-dynamics compensation.
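Lemma.6 can be checked numerically under an assumed sign convention that may differ from eq.(17). Take the Riccati form −dP/dt = AᵀP + PA − PBBᵀP with P(1) = R. Then Q = P⁻¹ satisfies dQ/dt = −Q(dP/dt)Q = AQ + QAᵀ − BBᵀ with Q(1) = R⁻¹, a linear Lyapunov-type equation. The sketch integrates both equations backward for a double integrator and compares Q(0) with P(0)⁻¹. The matrices, the terminal weight R, and the helper names are illustrative.

```python
import numpy as np

A = np.array([[0.0, 1.0], [0.0, 0.0]])   # double integrator: dx = v dt, dv = u dt
B = np.array([[0.0], [1.0]])
R = np.diag([50.0, 5.0])                 # finite terminal weight, chosen for illustration


def riccati_rhs(P):
    return -(A.T @ P + P @ A - P @ B @ B.T @ P)


def lyapunov_rhs(Q):
    return A @ Q + Q @ A.T - B @ B.T


def integrate_backward(rhs, X1, n_steps=4000):
    """RK4 for dX/dt = rhs(X), stepping from t = 1 back to t = 0; returns X(0)."""
    h = 1.0 / n_steps
    X = X1.copy()
    for _ in range(n_steps):
        k1 = rhs(X)
        k2 = rhs(X - 0.5 * h * k1)
        k3 = rhs(X - 0.5 * h * k2)
        k4 = rhs(X - h * k3)
        X = X - (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return X


P0 = integrate_backward(riccati_rhs, R)
Q0 = integrate_backward(lyapunov_rhs, np.linalg.inv(R))
print(np.max(np.abs(Q0 - np.linalg.inv(P0))))   # small: Q(0) agrees with P(0)^{-1}
```

The practical appeal is that the Lyapunov equation is linear in Q, so it often admits closed-form solutions where the quadratic Riccati equation does not.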
C.4 SOC Connection with Schrödinger Bridge
The optimal control solution is also the solution of the Schrödinger Bridge when the terminal condition degenerates to a point mass (see the Brownian Bridge example in Appendix.D.1). It is also the solution of the Schrödinger Bridge when the optimal pairing is available (see Proposition 2 in De Bortoli et al. (2023)).
Therefore, we are not solving the momentum Schrödinger Bridge of Chen et al. (2023) (see also fig.6), even though the problem formulation is similar. Specifically, AGM is a special case of the momentum Schrödinger Bridge in which the boundary conditions degenerate to Dirac distributions.
Appendix D Technical Details of Section.3
D.1 Brownian Bridge as the solution of Stochastic Optimal Control
We adopt the presentation of Kappen (2008) and consider the control problem:
where is the terminal cost coefficient. Following the recipe of the Pontryagin Maximum Principle (PMP; Kirk (2004)), one can construct the Hamiltonian:
By setting:
the optimized Hamiltonian is:
Then we solve the Hamiltonian equations of motion:
Note that the solution for is the constant ; hence, the solution for is .
When , we arrive at the optimal control . By certainty equivalence, this is also the optimal control law for
By plugging it back into the dynamics, we obtain the well-known Brownian Bridge:
Remark 7.
If there is no stochasticity , one obtains , which is the vector field constructed by Lipman et al. (2022) during training.
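Both the Brownian Bridge and Remark 7 can be illustrated with one Euler–Maruyama routine. This is a minimal sketch assuming the scalar form dx_t = (x_1 − x_t)/(1 − t) dt + σ dW_t started at x_0. The marginal at time t then has mean (1 − t)x_0 + t x_1 and variance σ²t(1 − t). With σ = 0 (Remark 7), every Euler step stays exactly on the straight line (1 − t)x_0 + t x_1, because along that line the vector field equals the constant x_1 − x_0. The values of σ, the grid, the path count, and the helper name simulate_bridge are illustrative choices.

```python
import numpy as np


def simulate_bridge(x0, x1, sigma, n_paths, n_steps, seed=0):
    """Euler–Maruyama for dx = (x1 - x) / (1 - t) dt + sigma dW on [0, 1].

    Returns an array of shape (n_steps + 1, n_paths) with the state at every grid time.
    """
    rng = np.random.default_rng(seed)
    h = 1.0 / n_steps
    x = np.full(n_paths, float(x0))
    path = [x.copy()]
    for k in range(n_steps):
        t = k * h                        # t <= 1 - h, so 1 - t never vanishes
        noise = sigma * np.sqrt(h) * rng.standard_normal(n_paths)
        x = x + h * (x1 - x) / (1.0 - t) + noise
        path.append(x.copy())
    return np.array(path)


n_steps = 200
ts = np.linspace(0.0, 1.0, n_steps + 1)

# sigma = 0 (Remark 7): with x0 = 0 and x1 = 1 the straight line is simply x = t.
noise_free = simulate_bridge(x0=0.0, x1=1.0, sigma=0.0, n_paths=1, n_steps=n_steps)
print(np.max(np.abs(noise_free[:, 0] - ts)))   # zero up to floating-point rounding

# sigma = 1: compare the t = 0.5 statistics with the exact mean 0.5 and variance 0.25.
paths = simulate_bridge(x0=0.0, x1=1.0, sigma=1.0, n_paths=10_000, n_steps=n_steps)
mid = paths[n_steps // 2]
print(mid.mean(), mid.var(), paths[-1].mean())
```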
D.2 Proof of Proposition.3
Proposition 8.
The solution of the stochastic bridge problem of linear momentum system (Chen & Georgiou, 2015) is
(18)
Proof.
From Lemma.12, the optimal control for this problem is
where the state transition function is given by Lemma.11, and is the solution of the Lyapunov equation given in Lemma.9.
Then we have:
∎
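To make the proposition concrete, here is a minimal sketch under assumptions of our own. The undamped phase system is dx = v dt, dv = u dt (A = [[0, 1], [0, 0]], B = (0, 1)ᵀ). The control takes the Chen & Georgiou (2015)-style form u = −BᵀQ(t)⁻¹(z − Φ(t, 1)z_1), where z = (x, v), Φ(t, s) = I + A(t − s), and Q(t) = [[τ³/3, −τ²/2], [−τ²/2, τ]] with τ = 1 − t solves dQ/dt = AQ + QAᵀ − BBᵀ, Q(1) = 0. Expanding gives u = 6(x_1 − x)/τ² − 4v/τ − 2v_1/τ. This version pins both position and velocity at t = 1, which need not match the paper's terminal treatment (Lemma.9 and Remark 10 allow the velocity to stay loose). The target values and the helper name bridge_control are illustrative.

```python
import numpy as np

A = np.array([[0.0, 1.0], [0.0, 0.0]])   # assumed phase dynamics: dx = v dt, dv = u dt
B = np.array([0.0, 1.0])                 # control enters the velocity channel only


def bridge_control(t, z, z1):
    """u = -B^T Q(t)^{-1} (z - Phi(t, 1) z1) for z = (x, v); valid for t < 1."""
    tau = 1.0 - t
    Q = np.array([[tau**3 / 3, -tau**2 / 2],
                  [-tau**2 / 2, tau]])      # solves dQ/dt = AQ + QA^T - BB^T, Q(1) = 0
    phi_t1 = np.eye(2) + A * (t - 1.0)      # Phi(t, 1) = expm(A (t - 1))
    return -B @ np.linalg.solve(Q, z - phi_t1 @ z1)


z1 = np.array([1.0, 0.5])                   # illustrative target (x1, v1)

# The matrix formula matches the expanded form 6(x1 - x)/tau^2 - 4v/tau - 2v1/tau.
t_chk, z_chk = 0.3, np.array([0.2, -0.4])
tau = 1.0 - t_chk
expanded = 6 * (z1[0] - z_chk[0]) / tau**2 - 4 * z_chk[1] / tau - 2 * z1[1] / tau
print(bridge_control(t_chk, z_chk, z1), expanded)

# Noise-free Euler rollout from rest; stop at t = 0.99 because the gains grow like 1/tau^2.
h = 1e-4
z = np.zeros(2)
for k in range(int(round(0.99 / h))):
    z = z + h * (A @ z + B * bridge_control(k * h, z, z1))
print(z)   # approaches (1.0, 0.5) as t -> 1
```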
Lemma 9.
The Lyapunov equation corresponding to the optimization problem shown in Lemma.12:
is given by
(19)
When , the solution of the Lyapunov equation above, with terminal condition
(20)
However, the force need not drive the velocity to converge exactly to , because we only care about the quality of the generated . Here we give a general case in which keeps a small value for the velocity channel:
(21)
Then the solution is given by
and the inverse of is,
Thus,
Proof.
One can plug the solution into the Lyapunov equation and verify that it is indeed the solution.
Remark 10.
Here we provide a general form for the case where the terminal condition of the Lyapunov equation is not a zero matrix. Explicitly, this allows the velocity not to converge exactly to the predefined . Setting recovers the results shown in the paper.
∎
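Lemma.9 and Remark 10 can be sanity-checked numerically. We assume the Lyapunov form dQ/dt = AQ + QAᵀ − BBᵀ for the double integrator (our sign convention, which may differ from eq.(19)) and a terminal condition Q(1) = diag(0, ε) that leaves a small value in the velocity channel. The solution is then Q(t) = Φ(t, 1)Q(1)Φ(t, 1)ᵀ + [[τ³/3, −τ²/2], [−τ²/2, τ]], with τ = 1 − t and Φ(t, 1) = [[1, −τ], [0, 1]]; setting ε = 0 recovers the fully pinned case. The sketch compares this closed form with a backward RK4 integration. The values of ε and the helper names are illustrative.

```python
import numpy as np

A = np.array([[0.0, 1.0], [0.0, 0.0]])    # assumed undamped phase dynamics
BBt = np.array([[0.0, 0.0], [0.0, 1.0]])  # B B^T for B = (0, 1)^T


def lyapunov_rhs(Q):
    return A @ Q + Q @ A.T - BBt


def q_closed_form(t, eps):
    """Phi(t,1) diag(0, eps) Phi(t,1)^T + Gramian, with tau = 1 - t."""
    tau = 1.0 - t
    phi_t1 = np.array([[1.0, -tau], [0.0, 1.0]])
    gramian = np.array([[tau**3 / 3, -tau**2 / 2],
                        [-tau**2 / 2, tau]])
    return phi_t1 @ np.diag([0.0, eps]) @ phi_t1.T + gramian


def q_numeric(t, eps, n_steps=1000):
    """Backward RK4 from Q(1) = diag(0, eps) down to time t."""
    h = (1.0 - t) / n_steps
    Q = np.diag([0.0, eps])
    for _ in range(n_steps):
        k1 = lyapunov_rhs(Q)
        k2 = lyapunov_rhs(Q - 0.5 * h * k1)
        k3 = lyapunov_rhs(Q - 0.5 * h * k2)
        k4 = lyapunov_rhs(Q - h * k3)
        Q = Q - (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return Q


for eps in (0.0, 0.05):                   # eps = 0 is the fully pinned case
    err = max(np.max(np.abs(q_numeric(t, eps) - q_closed_form(t, eps)))
              for t in (0.0, 0.5, 0.9))
    print(f"eps = {eps}: max deviation {err:.1e}")
```

Because the exact solution is a cubic polynomial in t and the equation is linear with constant coefficients, RK4 reproduces it up to rounding.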
Lemma 11.
The state transition function of the following dynamics,
is,
Proof.
One can easily verify that such satisfies . ∎
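As a small illustration, assume the undamped phase dynamics dx = v dt, dv = u dt. This is an assumption, since the dynamics of Lemma.11 are not reproduced above. The drift matrix A = [[0, 1], [0, 0]] is nilpotent (A² = 0), so Φ(t, s) = exp(A(t − s)) = I + A(t − s) exactly. The sketch checks dΦ/dt = AΦ with a central finite difference and the composition property Φ(t, s)Φ(s, r) = Φ(t, r); the helper name transition is ours.

```python
import numpy as np

A = np.array([[0.0, 1.0], [0.0, 0.0]])   # assumed undamped phase dynamics


def transition(t, s):
    """Phi(t, s) = expm(A (t - s)); the series stops after the linear term since A @ A = 0."""
    return np.eye(2) + A * (t - s)


t, s, r, d = 0.7, 0.4, 0.1, 1e-6
dphi_dt = (transition(t + d, s) - transition(t - d, s)) / (2 * d)
print(np.allclose(A @ A, 0.0))                                             # True
print(np.allclose(dphi_dt, A @ transition(t, s)))                          # True
print(np.allclose(transition(t, s) @ transition(s, r), transition(t, r)))  # True
```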
Lemma 12 (Chen & Georgiou (2015)).
When , the optimal control of the following problem,
is given by
where follows the Lyapunov equation (eq.19) with boundary condition , and is the transition matrix from time-step to time-step under the uncontrolled dynamics.
It is indeed the stochastic bridge of the following system:
(22)
(23)
Proof.
See page 8 in Chen & Georgiou (2015). ∎
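The bridge property can be probed by simulation. The sketch below is a minimal Euler–Maruyama experiment under assumptions of our own. The phase dynamics are dx = v dt, dv = u dt + g dW, started at rest, with u = 6(x_1 − x)/τ² − 4v/τ − 2v_1/τ and τ = 1 − t. This control is what u = −BᵀQ(t)⁻¹(z − Φ(t, 1)z_1) reduces to for A = [[0, 1], [0, 0]], B = (0, 1)ᵀ and Q(t) = [[τ³/3, −τ²/2], [−τ²/2, τ]]. The values of g, x_1, v_1, the step size, the stopping time t = 0.99, and the helper name simulate_phase_bridge are illustrative. Although noise keeps entering the velocity, the spread of the position around x_1 collapses as t → 1.

```python
import numpy as np


def simulate_phase_bridge(x1, v1, g=1.0, n_paths=5000, h=1e-3, t_stop=0.99, seed=0):
    """Euler–Maruyama for dx = v dt, dv = u dt + g dW, with
    u = 6 (x1 - x) / tau**2 - 4 v / tau - 2 v1 / tau and tau = 1 - t.

    Every path starts at (x, v) = (0, 0). Returns the position samples at t = 0.5
    and at t = t_stop.
    """
    rng = np.random.default_rng(seed)
    x = np.zeros(n_paths)
    v = np.zeros(n_paths)
    n_steps = int(round(t_stop / h))
    k_mid = int(round(0.5 / h))
    x_mid = None
    for k in range(n_steps):
        tau = 1.0 - k * h
        u = 6.0 * (x1 - x) / tau**2 - 4.0 * v / tau - 2.0 * v1 / tau
        noise = g * np.sqrt(h) * rng.standard_normal(n_paths)
        x, v = x + h * v, v + h * u + noise
        if k + 1 == k_mid:
            x_mid = x.copy()
    return x_mid, x


x_mid, x_end = simulate_phase_bridge(x1=1.0, v1=0.5)
print(x_mid.std(), x_end.std())   # the position spread shrinks sharply as t -> 1
print(x_end.mean())               # near x1 = 1.0
```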
D.3 Mean and Covariance of SDE
Plugging the optimal control into the system, one obtains:
We follow the recipe of Särkkä & Solin (2019): the mean and covariance matrix of the random variable obey the following ordinary differential equations (ODEs), respectively:
One can solve them by numerically simulating the two ODEs, whose dimension is only two, or by using software such as Inc. (2022) to obtain analytic solutions. With the latter approach, one gets:
Remark 13.
Since the expressions above are lengthy, we provide Python functions in Appendix.E.1 for a general initial covariance and diffusion coefficient, for easy copy-paste. The equations above are the ones used throughout this paper; readers are welcome to experiment with other hyperparameters.
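The numerical route can be sketched as follows. For a linear SDE dz = (F(t)z + c(t)) dt + L(t) dW, the standard moment equations (Särkkä & Solin, 2019) are dm/dt = Fm + c and dΣ/dt = FΣ + ΣFᵀ + LLᵀ. For a check with a known answer, the sketch uses the uncontrolled phase system dx = v dt, dv = g dW started from a point mass, whose covariance is g²[[t³/3, t²/2], [t²/2, t]]. This test system and the helper name moment_odes are our own choices, not the paper's closed-loop dynamics; for those, F and c would come from the optimal control above.

```python
import numpy as np


def moment_odes(F, c, L, m0, S0, t_end=1.0, n_steps=1000):
    """RK4 integration of dm/dt = F(t) m + c(t) and dS/dt = F S + S F^T + L L^T."""
    def rhs(t, m, S):
        Ft, Lt = F(t), L(t)
        return Ft @ m + c(t), Ft @ S + S @ Ft.T + Lt @ Lt.T

    h = t_end / n_steps
    m, S = m0.astype(float), S0.astype(float)
    for k in range(n_steps):
        t = k * h
        dm1, dS1 = rhs(t, m, S)
        dm2, dS2 = rhs(t + h / 2, m + h / 2 * dm1, S + h / 2 * dS1)
        dm3, dS3 = rhs(t + h / 2, m + h / 2 * dm2, S + h / 2 * dS2)
        dm4, dS4 = rhs(t + h, m + h * dm3, S + h * dS3)
        m = m + h / 6 * (dm1 + 2 * dm2 + 2 * dm3 + dm4)
        S = S + h / 6 * (dS1 + 2 * dS2 + 2 * dS3 + dS4)
    return m, S


g = 0.7
A = np.array([[0.0, 1.0], [0.0, 0.0]])
m1, S1 = moment_odes(
    F=lambda t: A,
    c=lambda t: np.zeros(2),
    L=lambda t: np.array([[0.0], [g]]),
    m0=np.array([1.0, -0.5]),
    S0=np.zeros((2, 2)),
)
exact = g**2 * np.array([[1 / 3, 1 / 2], [1 / 2, 1.0]])
print(np.max(np.abs(S1 - exact)))   # zero up to rounding: the solution is cubic in t
print(m1)                           # mean at t = 1 is (1 - 0.5, -0.5)
```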
D.4 Derivation from SDE to ODE for phase dynamics
D.5 Decomposition of Covariance Matrix and representation of score
Here we follow the procedure in Dockhorn et al. (2021). Given the covariance matrix , the Cholesky decomposition of this symmetric positive definite matrix is,
(30)
where,
(31)
Borrowing results from Dockhorn et al. (2021), the score function reads,
(Cholesky decomposition of the covariance matrix)
The form of reads,
and the transpose inverse of reads,
Hence, the score function reads,
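The decomposition and the score can be sketched for a 2×2 phase-space covariance Σ = [[Σxx, Σxv], [Σxv, Σvv]]. The lower Cholesky factor L has Lxx = √Σxx, Lxv = Σxv/Lxx and Lvv = √(Σvv − Σxv²/Σxx). A sample z = μ + Lε has score ∇log p(z) = −Σ⁻¹(z − μ) = −L⁻ᵀε, so its velocity component is −ε_v/Lvv. This is the standard Gaussian identity behind the CLD parameterization of Dockhorn et al. (2021), not code from that work; the numbers and the helper name cholesky_2x2 are arbitrary.

```python
import numpy as np


def cholesky_2x2(S):
    """Lower-triangular L with L @ L.T == S for a 2x2 symmetric positive definite S."""
    l_xx = np.sqrt(S[0, 0])
    l_xv = S[0, 1] / l_xx
    l_vv = np.sqrt(S[1, 1] - S[0, 1] ** 2 / S[0, 0])
    return np.array([[l_xx, 0.0], [l_xv, l_vv]])


S = np.array([[0.30, -0.12], [-0.12, 0.25]])   # illustrative covariance
mu = np.array([0.4, -1.0])
L = cholesky_2x2(S)

rng = np.random.default_rng(0)
eps = rng.standard_normal(2)
z = mu + L @ eps                                  # reparameterized sample

score_direct = -np.linalg.solve(S, z - mu)        # -Sigma^{-1} (z - mu)
score_from_eps = -np.linalg.solve(L.T, eps)       # -L^{-T} eps
print(np.allclose(L, np.linalg.cholesky(S)))      # True
print(np.allclose(score_direct, score_from_eps))  # True
print(np.isclose(score_from_eps[1], -eps[1] / L[1, 1]))  # True: velocity score
```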
D.6 Representation of acceleration
As shown in Proposition.3, the optimal control can be represented as,
D.7 Loss Reweight
In practice, we use the following loss function
(32)
(33)
We acknowledge that this might not be an optimal choice. The motivation is simply to increase the training weight when and to normalize the label with the normalizer .
D.8 Normalizer of AGM-SDE and AGM-ODE
Since the optimal control term can be represented as,
Then we introduce the normalizer as
where
D.9 Exponential Integrator Derivation
As suggested by Zhang & Chen (2022), one can write the discretized dynamics as,
(34)
After plugging in the transition kernel , one can easily obtain the results shown in (11).
Remark 14.
For momentum systems, there are numerous methods that achieve high numerical accuracy. However, their practical performance in generative modeling remains untested. We recommend that readers consult classical numerical physics textbooks or recent momentum dynamics solvers (Pandey et al., 2023; Dockhorn et al., 2021).
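The exponential-integrator idea can be sketched for the phase system under simplifying assumptions of our own. We take undamped base dynamics dx = v dt, dv = a dt with the acceleration a held constant over each step, i.e. a zeroth-order hold, the lowest multistep order; eq.(34) and the paper's multistep variant are not reproduced here. The linear part is propagated exactly with Φ(h) = I + Ah, and the held acceleration enters through ∫₀ʰ Φ(h − s)B ds = (h²/2, h). For a truly constant acceleration the update is exact for any step size, whereas explicit Euler is not. The helper names are ours.

```python
import numpy as np


def ei_step(x, v, a, h):
    """Exponential-integrator step for dx = v dt, dv = a dt with a held fixed over the step."""
    return x + h * v + 0.5 * h**2 * a, v + h * a


def euler_step(x, v, a, h):
    return x + h * v, v + h * a


x0, v0, a, n_steps = 0.0, 1.0, -2.0, 4   # a few large steps on purpose
h = 1.0 / n_steps
x_ei, v_ei = x0, v0
x_eu, v_eu = x0, v0
for _ in range(n_steps):
    x_ei, v_ei = ei_step(x_ei, v_ei, a, h)
    x_eu, v_eu = euler_step(x_eu, v_eu, a, h)

x_exact = x0 + v0 + 0.5 * a              # x(1) = x0 + v0 t + a t^2 / 2 at t = 1
print(x_ei, x_eu, x_exact)
```

In a sampler, a would be the network's acceleration estimate at the start of each step, so the update is no longer exact, but the linear part of the dynamics still contributes no discretization error.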
D.10 Proof of Proposition.5
The estimated data point can be represented as
(35)
for the SDE and probabilistic ODE dynamics, respectively, with and .
Proof.
The representation of for the SDE follows directly from the fact that the network essentially estimates:
The probabilistic ODE case is slightly more complicated. We notice that
In the probabilistic ODE case, the force term can be represented as,
To use a linear combination of and to represent , one needs to match the stochastic term in by using
The solution can be obtained by:
Substituting it back into , one obtains:
∎
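The SDE case of the proof rests on the predicted acceleration being affine in the data point for a fixed current state. As a hypothetical illustration only, take the pinned double-integrator bridge F = 6(x_1 − x)/τ² − 4v/τ − 2v_1/τ with τ = 1 − t and a known v_1. This is an assumed parameterization, not the paper's exact form of F, which is not reproduced above. A force predicted by a network can then be inverted for an estimate of x_1 from the current (x, v). The helper names bridge_force and predict_x1 are ours.

```python
import numpy as np


def bridge_force(x, v, x1, v1, t):
    """Hypothetical acceleration, affine in x1 for fixed (x, v, t)."""
    tau = 1.0 - t
    return 6.0 * (x1 - x) / tau**2 - 4.0 * v / tau - 2.0 * v1 / tau


def predict_x1(force, x, v, v1, t):
    """Invert the affine map x1 -> force for the current (x, v)."""
    tau = 1.0 - t
    return x + tau**2 / 6.0 * (force + 4.0 * v / tau + 2.0 * v1 / tau)


rng = np.random.default_rng(0)
x, v, x1 = rng.standard_normal((3, 4))     # a batch of four scalar states and targets
v1, t = 0.0, 0.35
f = bridge_force(x, v, x1, v1, t)
print(np.allclose(predict_x1(f, x, v, v1, t), x1))   # True
```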
Appendix E Experimental Details
Training: We use the hyperparameters introduced in section.4. We use AdamW (Loshchilov & Hutter, 2017) as our optimizer and an Exponential Moving Average (EMA) with a decay rate of 0.9999. We use 8 Nvidia A100 GPUs for all experiments. For further details on the training setup, please refer to Table.6.
Dataset | Training Iterations | Learning Rate | Batch Size | Network Architecture |
---|---|---|---|---|
toy | 0.05M | 1e-3 | 1024 | ResNet (Dockhorn et al., 2021) |
CIFAR-10 | 0.5M | 1e-3 | 512 | NCSN++ (Karras et al., 2022) |
AFHQv2 | 0.5M | 1e-3 | 512 | NCSN++ (Karras et al., 2022) |
ImageNet-64 | 1.6M | 2e-4 | 512 | ADM (Dhariwal & Nichol, 2021) |
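The EMA with decay rate 0.9999 can be sketched as follows. The parameter layout (a dict of NumPy arrays) and the helper name ema_update are our own assumptions and are not taken from the released code.

```python
import numpy as np


def ema_update(ema_params, params, decay=0.9999):
    """In-place EMA: ema <- decay * ema + (1 - decay) * param, for every entry."""
    for name, p in params.items():
        ema_params[name] *= decay
        ema_params[name] += (1.0 - decay) * p


params = {"w": np.ones(3)}
ema = {"w": np.zeros(3)}
for _ in range(10_000):
    ema_update(ema, params)
print(ema["w"])   # 1 - 0.9999**10000, about 0.632 in every entry
```

With decay 0.9999, the average has an effective horizon of roughly 1/(1 − 0.9999) = 10,000 updates, which is short relative to the 0.5M to 1.6M training iterations in the table.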
Sampling: For the Exponential Integrator, we use the same multistep order for all experiments. Unlike previous work (Dockhorn et al., 2021; Karras et al., 2022; Zhang et al., 2023), we use a quadratic timestep scheme with :
This is the opposite of classical DMs: the time steps become larger as the dynamics approach the data. For numerical stability, we use for all experiments. For , we use and , . For the rest of the sampling, we use .
Because EDM (Karras et al., 2022) uses a second-order ODE solver, in practice we allow it one extra NFE beyond the number reported in all tables.
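One plausible form of the quadratic schedule is t_i = t_start + (t_end − t_start)(i/N)^κ with κ = 2. This form is an assumption, because the formula above is not reproduced. Consecutive step sizes then grow toward t_end, that is, toward the data end of the trajectory, and t_end is kept slightly below 1 for numerical stability. The function name and the values t_start = 0 and t_end = 0.999 are hypothetical.

```python
import numpy as np


def quadratic_timesteps(n_steps, t_start=0.0, t_end=0.999, power=2.0):
    """Hypothetical schedule t_i = t_start + (t_end - t_start) * (i / N) ** power."""
    i = np.arange(n_steps + 1)
    return t_start + (t_end - t_start) * (i / n_steps) ** power


ts = quadratic_timesteps(10)
steps = np.diff(ts)
print(np.round(ts, 4))
print(np.all(np.diff(steps) > 0))   # True: steps grow as t approaches the data end
```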
E.1 Code Example for Covariance
We slightly abuse notation in this coding section. Here we provide example code for computing the covariance matrix, considering the general case where and the diffusion coefficient is , where is the scaling coefficient and is the damping coefficient.
```python
import numpy as np
import torch

# Closed-form covariance entries of the phase-space process (Appendix D.3).
# t: torch tensor of times in [0, 1); the formulas contain log(1 - t), so keep t < 1.
# p: scaling coefficient and tt: damping coefficient of the diffusion (see text above).
# m, n: initial position and velocity variances (Sigmaxx(0) = m, Sigmavv(0) = n).
# k: coefficient of the initial position-velocity cross term; the original snippet
#    used k without passing it in, so it is added here as an explicit argument.


def Sigmaxx(t, p, tt, m, n, k):
    return (t - 1)**2 * (
        30*m*(t**3 - 3*t**2 + 3*t + 3)**2
        - 60*p**2*(t - 1)**3*torch.log(1 - t)
        - t*(60*k*np.sqrt(m*n)*(t**5 - 6*t**4 + 15*t**3 - 15*t**2 + 9)
             - 30*n*t*(t**2 - 3*t + 3)**2
             + p**2*(t**5*(6*tt**2 + 3*tt + 1)
                     - 6*t**4*(6*tt**2 + 3*tt + 1)
                     + 15*t**3*(6*tt**2 + 3*tt + 1)
                     - 10*t**2*(9*tt**2 + 11) + 150*t - 60))
    ) / 270


def Sigmaxv(t, p, tt, m, n, k):
    return (1/270 - t/270) * (
        30*k*np.sqrt(m*n)*(8*t**6 - 48*t**5 + 120*t**4 - 135*t**3 + 45*t**2 + 27*t - 9)
        + 150*p**2*(t - 1)**3*torch.log(1 - t)
        + t*(-120*m*(t**5 - 6*t**4 + 15*t**3 - 15*t**2 + 9)
             - 30*n*(4*t**5 - 24*t**4 + 60*t**3 - 75*t**2 + 45*t - 9)
             + p**2*(4*t**5*(6*tt**2 + 3*tt + 1) - 24*t**4*(6*tt**2 + 3*tt + 1)
                     + 60*t**3*(6*tt**2 + 3*tt + 1) - 5*t**2*(81*tt**2 + 18*tt + 55)
                     + 15*t*(9*tt**2 + 25) - 150))
    )


def Sigmavv(t, p, tt, m, n, k):
    return (
        n*(-4*t**3 + 12*t**2 - 12*t + 3)**2/9
        - 8*p**2*(t - 1)**3*torch.log(1 - t)/9
        + t*(-120*k*np.sqrt(m*n)*(4*t**5 - 24*t**4 + 60*t**3 - 75*t**2 + 45*t - 9)
             + 240*m*t*(t**2 - 3*t + 3)**2
             + p**2*(-8*t**5*(6*tt**2 + 3*tt + 1) + 48*t**4*(6*tt**2 + 3*tt + 1)
                     - 120*t**3*(6*tt**2 + 3*tt + 1) + 5*t**2*(180*tt**2 + 72*tt + 53)
                     - 15*t*(36*tt**2 + 9*tt + 20) + 135*tt**2 + 120))/135
    )
```
Appendix F Conditional Generation Details
Here we provide the details of conditional generation.
F.1 Stroke-Based Generation
For stroke-based generation, we provide two types of conditional generation.
Initial Velocity (IV): Please refer to section.4.
Dynamics Velocity (dyn-V): Since the mean and variance of the velocity and position are available, one can specify a valid velocity. In this case, we set the velocity as
(36)
in which,
(37)
(38)
when . Here, is the guidance length, which we typically set to .
F.2 Inpainting
For inpainting, we apply a strategy similar to dyn-V. Specifically, is represented as:
(39)
where MASK is the mask matrix that zeroes out pixels of the original image. Such then serves as the source for estimating in eq.37.
F.3 Inpainting-Based Generation
Appendix G Ablation Study of Stroke-Based Conditional Generation
To investigate the diversity and faithfulness of stroke-based conditional generation, we conduct an ablation study with respect to the hyperparameter .
Appendix H Additional Figures
We show samples for different datasets with varying NFE.