
License: CC BY 4.0
arXiv:2302.03660v3 [cs.LG] 26 Feb 2024

Flow Matching on General Geometries

Ricky T. Q. Chen
FAIR, Meta
[email protected]

Yaron Lipman
FAIR, Meta and Weizmann Institute of Science
[email protected]
Abstract

We propose Riemannian Flow Matching (RFM), a simple yet powerful framework for training continuous normalizing flows on manifolds. Existing methods for generative modeling on manifolds either require expensive simulation, are inherently unable to scale to high dimensions, or use approximations for limiting quantities that result in biased training objectives. Riemannian Flow Matching bypasses these limitations and offers several advantages over previous approaches: it is simulation-free on simple geometries, does not require divergence computation, and computes its target vector field in closed form. The key ingredient behind RFM is the construction of a relatively simple premetric for defining target vector fields, which encompasses the existing Euclidean case. To extend to general geometries, we rely on spectral decompositions to efficiently compute premetrics on the fly. Our method achieves state-of-the-art performance on many real-world non-Euclidean datasets, and we demonstrate tractable training on general geometries, including triangular meshes with highly non-trivial curvature and boundaries.

1 Introduction

[Figure 1 panels: Geodesic | Biharmonic]
Figure 1: Our approach makes use of user-specified premetrics on general manifolds to define flows. On select simple manifolds, the geodesic can be computed exactly and leads to a simulation-free algorithm. On general manifolds, where geodesics are not only computationally expensive but can also lead to degeneracies (e.g., along boundaries), we propose the use of spectral distances (e.g., biharmonic), which can be computed efficiently after a one-time preprocessing cost.

While generative models have recently made great advances in fitting data distributions in Euclidean spaces, challenges remain for data residing in non-Euclidean spaces, specifically on general manifolds. These challenges include scalability to high dimensions (e.g., Rozen et al., 2021), the requirement for simulation or iterative sampling during training even for simple geometries like hyperspheres (e.g., Mathieu & Nickel, 2020; De Bortoli et al., 2022), and difficulties in constructing simple and scalable training objectives.

In this work, we introduce Riemannian Flow Matching (RFM), a simple yet powerful methodology for learning continuous normalizing flows (CNFs; Chen et al., 2018) on general Riemannian manifolds $\mathcal{M}$. RFM builds upon the Flow Matching framework (Lipman et al., 2023; Albergo & Vanden-Eijnden, 2023; Liu et al., 2023) and learns a CNF by regressing an implicitly defined target vector field $u_t(x)$ that pushes a base distribution $p$ towards a target distribution $q$ defined by the training examples. To address the intractability of $u_t(x)$, we employ a similar approach to Conditional Flow Matching (Lipman et al., 2023), where we regress onto conditional vector fields $u_t(x|x_1)$ that push $p$ towards individual training examples $x_1$.

A key observation underlying our Riemannian generalization is that the conditional vector field needed to train the CNF can be expressed explicitly in terms of a “premetric” $\mathrm{d}(x,y)$, which distinguishes pairs of points $x$ and $y$ on the manifold. A natural choice for such a premetric is the geodesic distance function, which in Euclidean space recovers the straight-line trajectories used by prior approaches.

On simple geometries, where geodesics are known in closed form (e.g., Euclidean space, hypersphere, hyperbolic space, torus, or any of their product spaces), Riemannian Flow Matching remains completely simulation-free. Even on general geometries, it only requires forward simulation of a relatively simple ordinary differential equation (ODE), without differentiation through the solver, stochastic iterative sampling, or divergence estimation.

Simulation-free on simple geo. Closed-form target vector field Does not require divergence
Ben-Hamu et al. (2022) -
De Bortoli et al. (2022) (DSM)
De Bortoli et al. (2022) (ISM) -
Huang et al. (2022) -
Riemannian FM (Ours)
Table 1: Comparison of closely related methods for training continuous-time generative models on Riemannian manifolds. Additionally, ours is the only one of these works to consider and tackle general geometries.

On all types of geometries, Riemannian Flow Matching offers several advantages over recently proposed Riemannian diffusion models (De Bortoli et al., 2022; Huang et al., 2022). These include avoiding iterative simulation of a noising process during training for geometries with analytic geodesic formulas; not relying on approximations of score functions or of divergences of the parametric vector field; and not needing to solve stochastic differential equations (SDEs) on manifolds, which are generally more challenging to approximate than ODEs (Kloeden et al., 2002; Hairer et al., 2006; Hairer, 2011). Table 1 summarizes the key differences with relevant prior methods, which we expand on further in Section 4 (Related Work).

Empirically, we find that Riemannian Flow Matching achieves state-of-the-art performance on manifold datasets across various settings, matching or outperforming competitive baselines. We also demonstrate that our approach scales to higher dimensions without sacrificing performance, thanks to our scalable closed-form training objective. Moreover, we present the first successful training of continuous-time deep generative models on non-trivial geometries, including those imposed by discrete triangular meshes and maze-shaped manifolds whose non-trivial boundaries represent challenging constraints.

2 Preliminaries

Riemannian manifolds.

This paper considers complete, connected, smooth Riemannian manifolds $\mathcal{M}$ with metric $g$ as the basic domain over which the generative model is learned. The tangent space to $\mathcal{M}$ at $x\in\mathcal{M}$ is denoted $T_x\mathcal{M}$, and $g$ defines an inner product over $T_x\mathcal{M}$ denoted $\langle u,v\rangle_g$, $u,v\in T_x\mathcal{M}$. $T\mathcal{M}=\bigcup_{x\in\mathcal{M}}\{x\}\times T_x\mathcal{M}$ is the tangent bundle that collects all the tangent planes of the manifold. $\mathcal{U}=\{u_t\}$ denotes the space of time-dependent smooth vector fields (VFs) $u_t:[0,1]\times\mathcal{M}\rightarrow T\mathcal{M}$, where $u_t(x)\in T_x\mathcal{M}$ for all $x\in\mathcal{M}$; $\mathrm{div}_g(u_t)$ is the Riemannian divergence w.r.t. the spatial ($x$) argument. We denote by $d\mathrm{vol}_x$ the volume element over $\mathcal{M}$, and the integral of a function $f:\mathcal{M}\rightarrow\mathbb{R}$ over $\mathcal{M}$ is denoted $\int f(x)\,d\mathrm{vol}_x$. For a more comprehensive background on Riemannian manifolds, we recommend Gallot et al. (1990).

Probability paths and flows on manifolds.

Probability densities over $\mathcal{M}$ are continuous non-negative functions $p:\mathcal{M}\rightarrow\mathbb{R}_+$ such that $\int p(x)\,d\mathrm{vol}_x=1$. The space of probability densities over $\mathcal{M}$ is denoted $\mathcal{P}$. A probability path $p_t$ is a curve in probability space, $p_t:[0,1]\rightarrow\mathcal{P}$; such paths will be used as the supervision signal for training our generative models. A flow is a diffeomorphism $\Psi:\mathcal{M}\rightarrow\mathcal{M}$ defined by integrating instantaneous deformations represented by a time-dependent vector field $u_t\in\mathcal{U}$. Specifically, a time-dependent flow $\psi_t:\mathcal{M}\rightarrow\mathcal{M}$ is defined by solving the following ordinary differential equation (ODE) on $\mathcal{M}$ over $t\in[0,1]$,

$$\frac{d}{dt}\psi_t(x) = u_t(\psi_t(x)), \qquad \psi_0(x) = x, \tag{1}$$

and the final diffeomorphism is defined by setting $\Psi(x)=\psi_1(x)$. A probability density path $p_t$ is said to be generated by $u_t$ from $p$ if $\psi_t$ pushes $p_0=p$ to $p_t$ for all $t\in[0,1]$. More formally,

$$\log p_t(x) = \log\left([\psi_t]_\sharp p\right)(x) = \log p\left(\psi_t^{-1}(x)\right) - \int_0^t \mathrm{div}_g(u_s)(x_s)\,\mathrm{d}s \tag{2}$$

where the $\sharp$ symbol denotes the standard push-forward operation and $x_s=\psi_s(\psi_t^{-1}(x))$. This formula can be derived from the Riemannian version of the instantaneous change of variables formula (see equation 22 in Ben-Hamu et al., 2022). Previously, Chen et al. (2018) suggested modeling the flow $\psi_t$ implicitly by parameterizing the vector field $u_t$. This results in a deep generative model of the flow $\psi_t$, called a Continuous Normalizing Flow (CNF), which models a probability path $p_t$ through a continuous-time deformation of a base distribution $p$. A number of works have formulated manifold variants (Mathieu & Nickel, 2020; Lou et al., 2020; Falorsi, 2020) that require simulation in order to enable training, while some simulation-free variants (Rozen et al., 2021; Ben-Hamu et al., 2022) scale poorly to high dimensions and do not readily adapt to general geometries.
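Equation 2 can be sanity-checked in the simplest possible setting: one-dimensional Euclidean space, where $\mathrm{div}_g$ reduces to the ordinary derivative. The sketch below is a minimal illustration under our own assumptions: the linear field $u_t(x)=ax$ with a standard normal base $p$, for which the push-forward is $\mathcal{N}(0, e^{2at})$ in closed form. The function name and the integrator (midpoint steps backwards in time, trapezoidal rule for the divergence integral) are hypothetical choices, not part of the paper.

```python
import numpy as np

def log_density_via_change_of_variables(x, t, u, div_u, log_p0, n_steps=2000):
    """Evaluate equation 2 in 1D Euclidean space, where div_g is d/dx.

    Integrates the trajectory x_s backwards from s = t (where x_t = x) to s = 0
    with midpoint steps, accumulating the divergence integral with the trapezoidal rule.
    """
    h = t / n_steps
    xs, s, integral = float(x), float(t), 0.0
    for _ in range(n_steps):
        k1 = u(s, xs)
        k2 = u(s - 0.5 * h, xs - 0.5 * h * k1)
        xs_prev = xs - h * k2                              # x_{s-h}
        integral += 0.5 * h * (div_u(s, xs) + div_u(s - h, xs_prev))
        xs, s = xs_prev, s - h
    return log_p0(xs) - integral                           # xs is now psi_t^{-1}(x)

# Assumed example: u_t(x) = a x and p = N(0, 1), so psi_t(x) = x e^{at} and p_t = N(0, e^{2at}).
a, x, t = 0.7, 1.3, 1.0
numeric = log_density_via_change_of_variables(
    x, t,
    u=lambda s, y: a * y,
    div_u=lambda s, y: a,
    log_p0=lambda y: -0.5 * y**2 - 0.5 * np.log(2 * np.pi),
)
exact = -0.5 * x**2 * np.exp(-2 * a * t) - 0.5 * np.log(2 * np.pi) - a * t
print(numeric, exact)
```

The two values should agree to within the integrator's error; the $-at$ term is exactly the accumulated divergence, which is what makes likelihood training of CNFs expensive in general.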

3 Method

We aim to train a generative model on a complete, connected, smooth Riemannian manifold $\mathcal{M}$ endowed with a metric $g$. Concretely, we are given a set of training samples $x_1\in\mathcal{M}$ from some unknown data distribution $q(x_1)$, $q\in\mathcal{P}$. Our goal is to learn a parametric map $\Phi:\mathcal{M}\rightarrow\mathcal{M}$ that pushes a simple base distribution $p\in\mathcal{P}$ to $q$.

3.1 Flow Matching on Manifolds

Flow Matching (Lipman et al., 2023) is a method for training Continuous Normalizing Flows (CNFs) in Euclidean space that sidesteps likelihood computation during training and scales extremely well, similar to diffusion models (Ho et al., 2020; Song et al., 2020b), while allowing the design of more general noise processes, which is what enables this work. We provide a brief summary and make the necessary adaptations to formulate Flow Matching on Riemannian manifolds. Derivations of the manifold case with full technical details are in Appendix A.

Riemannian Flow Matching.

Flow Matching (FM) trains a CNF by fitting a vector field $v\in\mathcal{U}$, i.e., $v_t(x)\in T_x\mathcal{M}$, with parameters $\theta\in\mathbb{R}^p$, to an a priori defined target vector field $u\in\mathcal{U}$ that is known to generate a probability density path $p_t\in\mathcal{P}$ over $\mathcal{M}$ satisfying $p_0=p$ and $p_1=q$. On a manifold endowed with a Riemannian metric $g$, the Flow Matching objective compares the tangent vectors $v_t(x),u_t(x)\in T_x\mathcal{M}$ using the Riemannian metric $g$ at that tangent space:

$$\mathcal{L}_{\text{RFM}}(\theta) = \mathbb{E}_{t,\,p_t(x)} \left\| v_t(x) - u_t(x) \right\|_g^2 \tag{3}$$

where $t\sim\mathcal{U}[0,1]$, the uniform distribution over $[0,1]$.
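The only geometry-specific ingredient of equation 3 is the norm $\|\cdot\|_g$. As an illustration, assume the Poincaré ball model of hyperbolic space (one of the simple geometries listed in the introduction), whose metric is the Euclidean one rescaled by $\lambda_x^2$ with $\lambda_x = 2/(1-\|x\|^2)$. The minimal sketch below, with hypothetical function names, estimates the objective from samples $x\sim p_t$ at which both fields have already been evaluated.

```python
import numpy as np

def poincare_sq_norm(x, v):
    """Squared Riemannian norm ||v||_g^2 of tangent vectors v at points x in the Poincare ball.

    The metric is conformal: g_x = lambda_x^2 * I with lambda_x = 2 / (1 - ||x||^2).
    x, v: arrays of shape (batch, dim) with ||x|| < 1.
    """
    lam = 2.0 / (1.0 - np.sum(x**2, axis=-1))
    return lam**2 * np.sum(v**2, axis=-1)

def rfm_loss_monte_carlo(x, v_pred, u_target):
    """Monte Carlo estimate of equation 3 from samples x ~ p_t (t already sampled)."""
    return np.mean(poincare_sq_norm(x, v_pred - u_target))

x = np.array([[0.0, 0.0], [0.9, 0.0]])
diff = np.array([[0.1, 0.0], [0.1, 0.0]])
print(poincare_sq_norm(x, diff))
```

The same Euclidean discrepancy of 0.1 is weighted about 28 times more heavily at $\|x\|=0.9$ than at the origin, which is why a plain Euclidean squared error is not a substitute for $\|\cdot\|_g$.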

Probability path construction.

Riemannian Flow Matching therefore requires constructing a probability density path $p_t\in\mathcal{P}$, $t\in[0,1]$, that satisfies the boundary conditions

$$p_0 = p, \qquad p_1 = q \tag{4}$$

and a corresponding vector field (VF) $u_t(x)$ which generates $p_t(x)$ from $p$ in the sense of equation 2. One way to construct such a pair is to create per-sample conditional probability paths $p_t(x|x_1)$ satisfying

$$p_0(x|x_1) = p(x), \quad p_1(x|x_1) \approx \delta_{x_1}(x), \tag{5}$$

where $\delta_{x_1}(x)$ is the Dirac distribution over $\mathcal{M}$ centered at $x_1$. One can then define $p_t(x)$ as the marginalization of these conditional probability paths over $q(x_1)$,

$$p_t(x) = \int_{\mathcal{M}} p_t(x|x_1)\, q(x_1)\, d\mathrm{vol}_{x_1}, \tag{6}$$

which satisfies equation 4 by construction. Following Lipman et al. (2023), whose result we verify for the manifold setting, we then define $u_t(x)$ as the “marginalization” of conditional vector fields $u_t(x|x_1)$ that generate $p_t(x|x_1)$ (in the sense detailed in Section 2),

$$u_t(x) = \int_{\mathcal{M}} u_t(x|x_1)\, \frac{p_t(x|x_1)\, q(x_1)}{p_t(x)}\, d\mathrm{vol}_{x_1}, \tag{7}$$

which provably generates $p_t(x)$. However, plugging $u_t(x)$ directly into equation 3 is impractical, because computing $u_t(x)$ is intractable: it requires integrating over the entire data distribution $q$.
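To see what equation 7 involves, consider the Euclidean special case with an empirical target $q$, a uniform mixture over $n$ training points. We assume the Gaussian conditional path $p_t(x|x_1)=\mathcal{N}(x;\, t x_1, (1-t)^2 I)$ with $p=\mathcal{N}(0,I)$, which is generated by $u_t(x|x_1)=(x_1-x)/(1-t)$; this is an illustrative choice, not the premetric construction of Section 3.2. The integral becomes a posterior-weighted average, so every evaluation of $u_t(x)$ touches every training example, and for a continuous $q$ no such finite sum is available at all. All function names are hypothetical.

```python
import numpy as np

def conditional_log_prob(x, t, x1):
    """log N(x; t * x1, (1 - t)^2 I): the assumed conditional path p_t(x|x1) with p = N(0, I)."""
    d = x.shape[-1]
    sigma = 1.0 - t
    return (-0.5 * np.sum((x - t * x1) ** 2, axis=-1) / sigma**2
            - d * np.log(sigma) - 0.5 * d * np.log(2 * np.pi))

def conditional_vf(x, t, x1):
    """u_t(x|x1) = (x1 - x) / (1 - t), the velocity of x_t = (1 - t) x0 + t x1."""
    return (x1 - x) / (1.0 - t)

def marginal_vf(x, t, data):
    """Equation 7 for a uniform empirical q: average u_t(x|x1_i) with posterior weights."""
    logw = conditional_log_prob(x[None, :], t, data)      # shape (n,)
    w = np.exp(logw - logw.max())                         # numerically stable softmax
    w /= w.sum()                                          # p_t(x|x1_i) q(x1_i) / p_t(x)
    return np.sum(w[:, None] * conditional_vf(x[None, :], t, data), axis=0)

data = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 2.0]])  # hypothetical training set
print(marginal_vf(np.array([0.3, 0.4]), 0.5, data))
```

With a single training point the marginal field reduces to the conditional one, and for two points placed symmetrically about the query the contributions cancel.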

Riemannian Conditional Flow Matching.

A key insight from Lipman et al. (2023) is that when the targets $p_t$ and $u_t$ are defined as in equations 6 and 7, the FM objective is equivalent to the following Conditional Flow Matching objective (the two have identical gradients with respect to $\theta$),

$$\mathcal{L}_{\text{RCFM}}(\theta) = \mathbb{E}_{t,\,q(x_1),\,p_t(x|x_1)} \left\| v_t(x) - u_t(x|x_1) \right\|_g^2 \tag{8}$$

as long as $u_t(x|x_1)$ is a vector field that generates $p_t(x|x_1)$ from $p$.
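The equivalence can be checked pointwise. For fixed $t$ and $x$, averaging the integrand of equation 8 over the posterior of $x_1$ gives the integrand of equation 3 plus the posterior variance of $u_t(x|x_1)$, a term that does not depend on $v$; hence both objectives have the same gradients. The minimal numerical check below assumes a Euclidean metric ($g=I$) and the illustrative Gaussian path $p_t(x|x_1)=\mathcal{N}(t x_1,(1-t)^2 I)$ with $u_t(x|x_1)=(x_1-x)/(1-t)$; the data and query point are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
t = 0.4
data = rng.normal(size=(5, 2))          # hypothetical training points x1_i (uniform q)
x = rng.normal(size=2)                  # a fixed query point x

# Posterior weights p_t(x|x1_i) q(x1_i) / p_t(x); Gaussian normalizers are equal and cancel.
logw = -0.5 * np.sum((x - t * data) ** 2, axis=1) / (1 - t) ** 2
w = np.exp(logw - logw.max())
w /= w.sum()

u_cond = (data - x) / (1 - t)           # u_t(x|x1_i) for every training point
u_marg = w @ u_cond                     # u_t(x), equation 7

def fm_term(v):
    """Integrand of equation 3 at (t, x)."""
    return np.sum((v - u_marg) ** 2)

def cfm_term(v):
    """Integrand of equation 8 at (t, x), averaged over the posterior of x1."""
    return np.sum(w * np.sum((v - u_cond) ** 2, axis=1))

gaps = [cfm_term(v) - fm_term(v) for v in rng.normal(size=(4, 2))]
print(np.round(gaps, 6))                # the same value for every candidate v
```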

Algorithm 1 Riemannian CFM
Require: base $p$, target $q$, scheduler $\kappa$
  Initialize parameters $\theta$ of $v_t$
  while not converged do
     sample time $t\sim\mathcal{U}(0,1)$
     sample training example $x_1\sim q$
     sample noise $x_0\sim p$
     if simple geometry then
        $x_t = \texttt{exp}_{x_1}(\kappa(t)\,\texttt{log}_{x_1}(x_0))$
     else if general geometry then
        $x_t = \texttt{solve\_ODE}([0,t], x_0, u_t(x|x_1))$
     end if
     $\ell(\theta) = \left\| v_t(x_t;\theta) - \dot{x}_t \right\|_g^2$
     $\theta = \texttt{optimizer\_step}(\ell(\theta))$
  end while
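Algorithm 1 can be made concrete on the unit sphere $S^2\subset\mathbb{R}^3$, where the induced metric makes $\|\cdot\|_g$ the Euclidean norm of tangent vectors. The following is a minimal NumPy sketch of the simple-geometry branch under several assumptions of ours: $\kappa(t)=1-t$, a uniform base $p$, a hypothetical target $q$ concentrated near the north pole, and a linear model projected onto the tangent plane in place of a neural network, trained with hand-derived SGD gradients. The time derivative $\dot{x}_t$ is obtained by differentiating the exponential-map formula in closed form, which avoids dividing by $1-t$ near $t=1$.

```python
import numpy as np

rng = np.random.default_rng(0)

def sphere_exp(x, v):
    """Exponential map on the unit sphere S^2, batched over the leading axis."""
    n = np.linalg.norm(v, axis=-1, keepdims=True)
    return np.cos(n) * x + np.sin(n) * v / np.where(n > 1e-12, n, 1.0)

def sphere_log(x, y):
    """Logarithm map log_x(y) on S^2 (undefined for exactly antipodal x and y)."""
    c = np.clip(np.sum(x * y, axis=-1, keepdims=True), -1.0, 1.0)
    w = y - c * x
    wn = np.linalg.norm(w, axis=-1, keepdims=True)
    return np.arccos(c) * w / np.where(wn > 1e-12, wn, 1.0)

def tangent(x, v):
    """Project ambient vectors v onto the tangent planes T_x S^2."""
    return v - np.sum(x * v, axis=-1, keepdims=True) * x

def sample_p(n):
    """Base p: uniform distribution on S^2."""
    z = rng.normal(size=(n, 3))
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

def sample_q(n):
    """Hypothetical target q: samples concentrated around the north pole."""
    z = rng.normal(scale=0.2, size=(n, 3)) + np.array([0.0, 0.0, 1.0])
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

def kappa(t):
    return 1.0 - t                      # kappa(0) = 1, kappa(1) = 0

def dkappa(t):
    return -np.ones_like(t)             # derivative of kappa

def sample_xt_and_velocity(n):
    """Simple-geometry branch: x_t = exp_{x1}(kappa(t) log_{x1}(x0)) and dx_t/dt."""
    t = rng.uniform(size=(n, 1))
    x0, x1 = sample_p(n), sample_q(n)
    w = sphere_log(x1, x0)
    theta = np.linalg.norm(w, axis=-1, keepdims=True)   # geodesic distance d(x0, x1)
    k = kappa(t)
    xt = sphere_exp(x1, k * w)          # = cos(k theta) x1 + sin(k theta) w / theta
    xdot = dkappa(t) * (-theta * np.sin(k * theta) * x1 + np.cos(k * theta) * w)
    return t, xt, xdot

def features(t, x):
    return np.concatenate([x, t, np.ones_like(t)], axis=-1)   # shape (n, 5)

def v_model(W, t, x):
    """Hypothetical linear model v_t(x; W), projected so that it outputs tangent vectors."""
    return tangent(x, features(t, x) @ W.T)

def loss(W, t, x, xdot):
    """Batch estimate of ||v_t(x_t) - xdot_t||_g^2 (Euclidean norm on the embedded sphere)."""
    return np.mean(np.sum((v_model(W, t, x) - xdot) ** 2, axis=-1))

W = np.zeros((3, 5))
t_eval, x_eval, xdot_eval = sample_xt_and_velocity(4096)
initial_loss = loss(W, t_eval, x_eval, xdot_eval)

lr, batch = 0.1, 256
for _ in range(500):
    t, xt, xdot = sample_xt_and_velocity(batch)
    f = features(t, xt)
    resid = tangent(xt, tangent(xt, f @ W.T) - xdot)   # dl/d(raw output), up to the factor 2/batch
    W -= lr * 2.0 * (resid.T @ f) / batch               # plain SGD step (optimizer_step)

final_loss = loss(W, t_eval, x_eval, xdot_eval)
print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```

Changing the schedule only requires swapping `kappa` and `dkappa`; a general geometry would replace `sample_xt_and_velocity` with a numerical ODE solve of the conditional vector field.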

To simplify this loss, consider the conditional flow, which we denote via the shorthand,

$$x_t = \psi_t(x_0|x_1), \tag{9}$$

defined as the solution to the ODE in equation 1 with the VF $u_t(x|x_1)$ and the initial condition $\psi_0(x_0|x_1)=x_0$. Furthermore, since sampling from $p_t(x|x_1)$ can be done with $\psi_t(x_0|x_1)$, where $x_0\sim p(x_0)$, we can reparametrize equation 8 as

$$\mathcal{L}_{\text{RCFM}}(\theta) = \mathbb{E}_{t,\,q(x_1),\,p(x_0)} \left\| v_t(x_t) - \dot{x}_t \right\|_g^2 \tag{10}$$

where $\dot{x}_t = \frac{d}{dt}x_t = u_t(x_t|x_1)$.
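On the sphere both branches of Algorithm 1 are available, which makes it a convenient test case for equations 9 and 10: integrating the conditional vector field with a generic ODE solver should reproduce the closed-form $x_t$, and the time derivative of $x_t$ should equal $u_t(x_t|x_1)$. The sketch assumes the geodesic path with $\kappa(t)=1-t$, whose conditional field is $\log_x(x_1)/(1-t)$ (the point heads toward $x_1$ and covers the remaining distance in the remaining time). The projected Euler step followed by renormalization is our hypothetical stand-in for `solve_ODE`, not the solver used in the paper.

```python
import numpy as np

def sphere_exp(x, v):
    n = np.linalg.norm(v)
    return x.copy() if n < 1e-12 else np.cos(n) * x + np.sin(n) * v / n

def sphere_log(x, y):
    c = np.clip(np.dot(x, y), -1.0, 1.0)
    w = y - c * x
    wn = np.linalg.norm(w)
    return np.zeros_like(x) if wn < 1e-12 else np.arccos(c) * w / wn

def conditional_vf(x, t, x1):
    """u_t(x|x1) = log_x(x1) / (1 - t) for the geodesic path with kappa(t) = 1 - t."""
    return sphere_log(x, x1) / (1.0 - t)

def solve_ode(x0, x1, t_end, n_steps=2000):
    """General-geometry branch: integrate dx/dt = u_t(x|x1) from 0 to t_end."""
    x, dt = x0.copy(), t_end / n_steps
    for k in range(n_steps):
        x = x + dt * conditional_vf(x, k * dt, x1)   # Euler step in the ambient space
        x = x / np.linalg.norm(x)                    # retract back onto the sphere
    return x

def closed_form_xt(t, x0, x1):
    """Simple-geometry branch: x_t = exp_{x1}((1 - t) log_{x1}(x0))."""
    return sphere_exp(x1, (1.0 - t) * sphere_log(x1, x0))

x0 = np.array([1.0, 0.0, 0.0])
x1 = np.array([0.0, 0.6, 0.8])
t = 0.7
xt_ode = solve_ode(x0, x1, t)
xt_exact = closed_form_xt(t, x0, x1)

# Equation 10: the derivative of x_t along the conditional flow equals u_t(x_t|x1).
h = 1e-5
xdot_fd = (closed_form_xt(t + h, x0, x1) - closed_form_xt(t - h, x0, x1)) / (2 * h)
xdot_vf = conditional_vf(xt_exact, t, x1)
print(np.abs(xt_ode - xt_exact).max(), np.abs(xdot_fd - xdot_vf).max())
```

Both printed discrepancies should be tiny; the speed $\|\dot{x}_t\|$ equals the geodesic distance $d(x_0,x_1)=\pi/2$ for this pair, as expected for a constant-speed geodesic.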

Riemannian Conditional Flow Matching (RCFM) has three requirements: a parametric vector field $v_t$ that outputs vectors on the tangent planes, the use of the norm $\|\cdot\|_g$ induced by the appropriate Riemannian metric, and the design of a (computationally tractable) conditional flow $\psi_t(x|x_1)$ whose probability path satisfies the boundary conditions in equation 5. We discuss this last point in the next section. Overall, compared to existing methods for training generative models on manifolds, RCFM is both simple and highly scalable; the training procedure is summarized in Algorithm 1, and a detailed comparison can be found in Appendix D.

3.2 Constructing Conditional Flows through Premetrics

We discuss the construction of conditional flows $\psi_t(x|x_1)$ on $\mathcal{M}$ that concentrate all mass at $x_1$ at time $t=1$; Figure 2 provides an illustration. This ensures that equation 5 will hold (regardless of the choice of $p$), since all points are mapped to $x_1$ at time $t=1$, namely

$$\psi_1(x|x_1) = x_1, \quad \text{for all } x\in\mathcal{M}. \tag{11}$$
Figure 2: The conditional vector field $u_t(x|x_1)$ defined in equation 13 transports all points $x\neq x_1$ to $x_1$ at exactly $t=1$.

On general manifolds, directly constructing a $\psi_t$ that satisfies equation 11 can be overly cumbersome. Instead, we propose designing a premetric, whose simple properties, when satisfied, characterize conditional flows that satisfy equation 11. Specifically, we define a premetric as a function $\mathrm{d}:\mathcal{M}\times\mathcal{M}\rightarrow\mathbb{R}$ satisfying:

  1. Non-negative: $\mathrm{d}(x,y)\geq 0$ for all $x,y\in\mathcal{M}$.

  2. Positive: $\mathrm{d}(x,y)=0$ iff $x=y$.

  3. Non-degenerate: $\nabla\mathrm{d}(x,y)\neq 0$ iff $x\neq y$.
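Whether a candidate function satisfies these three properties can be spot-checked numerically. The hypothetical helper below samples random pairs in $[-1,1]^n$ (a Euclidean setting is assumed, with gradients taken in the first argument) and reports the first violation it finds. Sampling can expose violations on sets of positive measure but cannot prove that the properties hold. As the usage shows, the squared Euclidean distance also qualifies as a premetric.

```python
import math
import random


def check_premetric(d, grad_d, dim, n_pairs=1000, seed=0, tol=1e-9):
    """Spot-check the premetric properties on random pairs in [-1, 1]^dim.

    d(x, y) returns a float; grad_d(x, y) returns the gradient of d in x.
    Passing is evidence, not a proof: only violations can be demonstrated.
    """
    rng = random.Random(seed)
    for _ in range(n_pairs):
        x = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
        y = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
        if min(d(x, y), d(x, x)) < -tol:                   # 1. non-negative
            return "non-negative violated"
        if d(x, x) > tol or (x != y and d(x, y) <= tol):   # 2. positive
            return "positive violated"
        if x != y:                                          # 3. non-degenerate
            g = grad_d(x, y)
            if math.sqrt(sum(gi * gi for gi in g)) <= tol:
                return "non-degenerate violated"
    return "ok"


def euclid(x, y):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def euclid_grad(x, y):
    r = euclid(x, y)
    return [(a - b) / r for a, b in zip(x, y)]


def sq_euclid(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y))


def sq_euclid_grad(x, y):
    return [2.0 * (a - b) for a, b in zip(x, y)]


print(check_premetric(euclid, euclid_grad, 3))        # expected: ok
print(check_premetric(sq_euclid, sq_euclid_grad, 3))  # expected: ok
# Adding a constant breaks positivity, since d(x, x) = 0.1.
print(check_premetric(lambda x, y: euclid(x, y) + 0.1, euclid_grad, 3))  # expected: positive violated
```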

We use the convention $\nabla\mathrm{d}(x,y)=\nabla_x\mathrm{d}(x,y)$. Such a premetric measures the closeness of a point $x$ to $x_1$, and we aim to design a conditional flow $\psi_t(x|x_1)$ that monotonically decreases this premetric. That is, given a monotonically decreasing differentiable function $\kappa(t)$ satisfying $\kappa(0)=1$ and $\kappa(1)=0$, we want to find a $\psi_t$ that decreases $\mathrm{d}(\cdot,x_1)$ according to

$$\mathrm{d}(\psi_t(x_0|x_1),x_1) = \kappa(t)\,\mathrm{d}(x_0,x_1). \tag{12}$$

Here $\kappa(t)$ acts as a scheduler that determines the rate at which $\mathrm{d}(\cdot,x_1)$ decreases. Note that at $t=1$ equation 11 is necessarily satisfied, since $x_1$ is the unique solution to $\mathrm{d}(\cdot,x_1)=0$ by the “positive” property of the premetric. Our next theorem shows that the following vector field generates a $\psi_t(x|x_1)$ satisfying equation 12:

$$u_t(x|x_1) = \frac{d\log\kappa(t)}{dt}\,\mathrm{d}(x,x_1)\,\frac{\nabla\mathrm{d}(x,x_1)}{\|\nabla\mathrm{d}(x,x_1)\|_g^2}. \tag{13}$$

The “non-degenerate” property guarantees that this conditional vector field is defined for all $x\neq x_1$.
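A minimal sketch of equation 13 on $\mathbb{R}^n$, assuming the Euclidean metric so that $\|\cdot\|_g$ is the usual 2-norm; the premetric, its gradient, and $\frac{d}{dt}\log\kappa(t)$ are passed in as callables. The anisotropic premetric with weights `A` is a hypothetical illustration, not one used in the paper. The key identity $\langle\nabla\mathrm{d},u_t\rangle = \frac{d\log\kappa}{dt}\,\mathrm{d}$ holds for any valid premetric.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def conditional_vf(x, x1, t, d, grad_d, dlog_kappa):
    """Equation 13 on R^n with the Euclidean metric (||.||_g is the 2-norm).

    d(x, y) is the premetric, grad_d(x, y) its gradient in x, and
    dlog_kappa(t) the time derivative of log kappa(t). Defined for x != x1.
    """
    g = grad_d(x, x1)
    scale = dlog_kappa(t) * d(x, x1) / dot(g, g)
    return [scale * gi for gi in g]


def dlog_linear(t):
    return -1.0 / (1.0 - t)  # kappa(t) = 1 - t


def d_euclid(x, y):
    return math.sqrt(sum((xi - yi) ** 2 for xi, yi in zip(x, y)))


def grad_euclid(x, y):
    r = d_euclid(x, y)
    return [(xi - yi) / r for xi, yi in zip(x, y)]


A = (4.0, 1.0)  # hypothetical positive diagonal weights for an anisotropic premetric


def d_aniso(x, y):
    return math.sqrt(sum(a * (xi - yi) ** 2 for a, xi, yi in zip(A, x, y)))


def grad_aniso(x, y):
    r = d_aniso(x, y)
    return [a * (xi - yi) / r for a, xi, yi in zip(A, x, y)]


x, x1, t = [1.0, 2.0], [3.0, -1.0], 0.5
print(conditional_vf(x, x1, t, d_euclid, grad_euclid, dlog_linear))  # approx. (x1 - x) / (1 - t) = [4, -6]
print(conditional_vf(x, x1, t, d_aniso, grad_aniso, dlog_linear))
```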

Theorem 3.1.

The flow $\psi_t(x|x_1)$ defined by the vector field $u_t(x|x_1)$ in equation 13 satisfies equation 12, and therefore also equation 11. Conversely, out of all conditional vector fields that satisfy equation 12, this $u_t(x|x_1)$ is the minimal norm solution.

A more concise statement and full proof of this result can be found in Appendix B. Here we prove the first part. Consider the scalar function $a(t)=\mathrm{d}(x_t,x_1)$, where $x_t=\psi_t(x|x_1)$ is the flow defined by the VF in equation 13. Differentiating with respect to time gives

$$\frac{d}{dt}a(t) = \langle\nabla\mathrm{d}(x_t,x_1),\dot{x}_t\rangle_g = \langle\nabla\mathrm{d}(x_t,x_1),u_t(x_t|x_1)\rangle_g = \frac{d\log\kappa(t)}{dt}\,a(t).$$

With the initial condition $a(0)=\mathrm{d}(x,x_1)$ and $\kappa(0)=1$, the solution of this ODE is $a(t)=\kappa(t)\,\mathrm{d}(x,x_1)$, which can be verified by substitution; hence $\mathrm{d}(x_t,x_1)=\kappa(t)\,\mathrm{d}(x,x_1)$. Intuitively, $u_t(x|x_1)$ is the minimal norm solution because it is parallel to $\nabla\mathrm{d}(x,x_1)$: any additional component orthogonal to the gradient would increase the norm without decreasing the premetric.
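The first part of Theorem 3.1 can also be checked numerically. The sketch below integrates equation 13 with classical RK4 for a hypothetical anisotropic premetric $\mathrm{d}(x,y)=\sqrt{(x-y)^\top A(x-y)}$ on $\mathbb{R}^2$, with diagonal $A$ and the Euclidean metric. The premetric shrinks by the factor $\kappa(t)=1-t$, even though the resulting path is not a straight line.

```python
import math

A = (4.0, 1.0)  # hypothetical positive diagonal weights; A = (1, 1) recovers Euclidean


def d_aniso(x, y):
    """Anisotropic premetric sqrt((x - y)^T A (x - y)) with diagonal A."""
    return math.sqrt(sum(a * (xi - yi) ** 2 for a, xi, yi in zip(A, x, y)))


def u_cond(x, x1, t):
    """Equation 13 with kappa(t) = 1 - t and the Euclidean metric on R^2."""
    r = d_aniso(x, x1)
    g = [a * (xi - yi) / r for a, xi, yi in zip(A, x, x1)]  # gradient of d in x
    scale = (-1.0 / (1.0 - t)) * r / sum(gi * gi for gi in g)
    return [scale * gi for gi in g]


def axpy(x, h, k):
    return [xi + h * ki for xi, ki in zip(x, k)]


def flow(x0, x1, t_end, n_steps=1000):
    """Integrate dx/dt = u_t(x | x1) from t = 0 to t_end with classical RK4."""
    h = t_end / n_steps
    x, t = list(x0), 0.0
    for _ in range(n_steps):
        k1 = u_cond(x, x1, t)
        k2 = u_cond(axpy(x, 0.5 * h, k1), x1, t + 0.5 * h)
        k3 = u_cond(axpy(x, 0.5 * h, k2), x1, t + 0.5 * h)
        k4 = u_cond(axpy(x, h, k3), x1, t + h)
        x = [xi + h * (a + 2.0 * b + 2.0 * c + e) / 6.0
             for xi, a, b, c, e in zip(x, k1, k2, k3, k4)]
        t += h
    return x


x0, x1 = [1.0, 1.0], [0.0, 0.0]
for t_end in (0.25, 0.5, 0.9):
    xt = flow(x0, x1, t_end)
    # The ratio should be close to kappa(t_end) = 1 - t_end; xt[0] < xt[1] shows the path bends.
    print(t_end, d_aniso(xt, x1) / d_aniso(x0, x1), xt)
```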

A simple choice we make in this paper for the scheduler is $\kappa(t)=1-t$, resulting in a conditional flow that linearly decreases the premetric between $x_t$ and $x_1$. In this case $\frac{d\log\kappa(t)}{dt}=-\frac{1}{1-t}$ and, by equation 12, $\mathrm{d}(x_t,x_1)=(1-t)\,\mathrm{d}(x_0,x_1)$ with $x_t=\psi_t(x_0|x_1)$, so the factor $1-t$ cancels in equation 13. Using this, we arrive at a more explicit form of the RCFM objective,

$$\mathcal{L}_{\text{RCFM}}(\theta) = \mathbb{E}_{t,q(x_1),p(x_0)}\left\| v_t(x_t) + \mathrm{d}(x_0,x_1)\,\frac{\nabla\mathrm{d}(x_t,x_1)}{\|\nabla\mathrm{d}(x_t,x_1)\|_g^2}\right\|_g^2. \tag{15}$$
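A per-sample version of equation 15 is sketched below for $\mathbb{R}^n$ with the Euclidean metric, an assumption made so that the example is self-contained. The function name is hypothetical, and the full objective would average this quantity over $t$, $x_0\sim p$, and $x_1\sim q$. The Euclidean check uses $\psi_t(x_0|x_1)=(1-t)\,x_0+t\,x_1$, for which the ideal model output is $x_1-x_0$.

```python
import math


def rcfm_sample_loss(v, x0, x1, xt, d, grad_d):
    """Single-sample RCFM loss (equation 15, kappa(t) = 1 - t), Euclidean metric.

    v is the model output v_t(x_t) and xt = psi_t(x0 | x1). The regression
    target -d(x0, x1) * grad / ||grad||^2 has no explicit 1 / (1 - t) factor.
    """
    g = grad_d(xt, x1)
    c = d(x0, x1) / sum(gi * gi for gi in g)
    return sum((vi + c * gi) ** 2 for vi, gi in zip(v, g))


def d_euclid(x, y):
    return math.sqrt(sum((xi - yi) ** 2 for xi, yi in zip(x, y)))


def grad_euclid(x, y):
    r = d_euclid(x, y)
    return [(xi - yi) / r for xi, yi in zip(x, y)]


# Euclidean check: psi_t(x0 | x1) = (1 - t) x0 + t x1.
x0, x1, t = [0.0, 0.0], [2.0, 4.0], 0.3
xt = [(1 - t) * a + t * b for a, b in zip(x0, x1)]
print(rcfm_sample_loss([2.0, 4.0], x0, x1, xt, d_euclid, grad_euclid))  # approx. 0
print(rcfm_sample_loss([0.0, 0.0], x0, x1, xt, d_euclid, grad_euclid))  # approx. ||x1 - x0||^2 = 20
```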

For general manifolds $\mathcal{M}$ and premetrics $\mathrm{d}$, training with Riemannian CFM requires simulation to solve for $x_t$, although it does not require differentiating through $x_t$. However, on simple geometries RCFM can become completely simulation-free by choosing the premetric to be the geodesic distance, as we discuss next.

Geodesic distance.

A natural choice for the premetric $\mathrm{d}(x,y)$ over a Riemannian manifold $\mathcal{M}$ is the geodesic distance $\mathrm{d}_g(x,y)$. First, we note that with the geodesic distance as the premetric, the flow $\psi_t(x_0|x_1)$, being the minimal norm solution, is equivalent to the geodesic path, i.e., the shortest path, connecting $x_0$ and $x_1$.

Proposition 3.2.

Consider a complete, connected smooth Riemannian manifold $(\mathcal{M},g)$ with geodesic distance $\mathrm{d}_g(x,y)$. If $\mathrm{d}(x,y)=\mathrm{d}_g(x,y)$, then $x_t=\psi_t(x_0|x_1)$ defined by the conditional VF in equation 13 with the scheduler $\kappa(t)=1-t$ is a constant speed geodesic connecting $x_0$ to $x_1$.
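For the geodesic premetric, equation 13 takes an explicit form. The short derivation below is a sketch, not the proof in the appendix. It uses the standard identity $\nabla_x \tfrac{1}{2}\mathrm{d}_g(x,x_1)^2 = -\log_x(x_1)$, valid for $x\neq x_1$ outside the cut locus of $x_1$ (where $\mathrm{d}_g$ is smooth), together with $\|\log_x(x_1)\|_g=\mathrm{d}_g(x,x_1)$ and $\kappa(t)=1-t$:

```latex
\begin{aligned}
\nabla \mathrm{d}_g(x,x_1) &= -\frac{\log_x(x_1)}{\mathrm{d}_g(x,x_1)},
\qquad \big\|\nabla \mathrm{d}_g(x,x_1)\big\|_g = 1, \\
u_t(x|x_1) &= -\frac{1}{1-t}\,\mathrm{d}_g(x,x_1)\Big(-\frac{\log_x(x_1)}{\mathrm{d}_g(x,x_1)}\Big)
= \frac{\log_x(x_1)}{1-t}, \\
\big\|u_t(x_t|x_1)\big\|_g &= \frac{\mathrm{d}_g(x_t,x_1)}{1-t} = \mathrm{d}_g(x_0,x_1).
\end{aligned}
```

The last line uses equation 12 and shows that the speed along $x_t$ is constant, consistent with the proposition.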

This makes it easy to compute $x_t$ on simple manifolds, which we define as manifolds with closed-form geodesics, e.g., Euclidean space, the hypersphere, hyperbolic space, the high-dimensional torus, and some matrix Lie groups. In particular, the geodesic connecting $x_0$ and $x_1$ can be expressed in terms of the exponential and logarithm maps,

$$x_t = \exp_{x_1}\!\big(\kappa(t)\log_{x_1}(x_0)\big), \quad t\in[0,1]. \tag{16}$$

This formula can simply be plugged into equation 10, resulting in a highly scalable training objective. A list of simple manifolds that we consider can be found in Table 5.
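On the unit sphere, equation 16 only needs the closed-form exponential and logarithm maps. The sketch below (function names hypothetical) computes $x_t$ and the conditional target $u_t(x_t|x_1)=\log_{x_t}(x_1)/(1-t)$, which is what equation 13 gives for the geodesic distance with $\kappa(t)=1-t$ away from the cut locus. The logarithm map is left undefined at the antipode $y=-x$.

```python
import math


def _dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def sphere_exp(x, v):
    """Exponential map on the unit sphere embedded in R^n."""
    nv = math.sqrt(_dot(v, v))
    if nv < 1e-12:
        return list(x)
    return [math.cos(nv) * xi + math.sin(nv) * vi / nv for xi, vi in zip(x, v)]


def sphere_log(x, y):
    """Logarithm map log_x(y) on the unit sphere (undefined for y = -x)."""
    c = max(-1.0, min(1.0, _dot(x, y)))
    theta = math.acos(c)
    w = [yi - c * xi for xi, yi in zip(x, y)]  # component of y tangent at x
    nw = math.sqrt(_dot(w, w))
    if nw < 1e-12:
        if c > 0:
            return [0.0] * len(x)  # y == x
        raise ValueError("log map is undefined at the antipode")
    return [theta * wi / nw for wi in w]


def sphere_dist(x, y):
    return math.acos(max(-1.0, min(1.0, _dot(x, y))))


def geodesic_xt(x0, x1, t):
    """Equation 16 with kappa(t) = 1 - t: x_t = exp_{x1}((1 - t) log_{x1}(x0))."""
    v = sphere_log(x1, x0)
    return sphere_exp(x1, [(1.0 - t) * vi for vi in v])


def target_vf(xt, x1, t):
    """Conditional VF for the geodesic premetric: log_{x_t}(x1) / (1 - t)."""
    return [vi / (1.0 - t) for vi in sphere_log(xt, x1)]


s = 1.0 / math.sqrt(2.0)
x0, x1 = [1.0, 0.0, 0.0], [0.0, s, s]
xt = geodesic_xt(x0, x1, 0.3)
print(sphere_dist(xt, x1) / sphere_dist(x0, x1))  # approx. 1 - 0.3 = 0.7
print(target_vf(xt, x1, 0.3))
```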

Euclidean geometry.

With Euclidean geometry $\mathcal{M}=\mathbb{R}^n$ and the Euclidean distance $\mathrm{d}(x,y)=\|x-y\|_2$, the conditional VF (equation 13) with scheduler $\kappa(t)=1-t$ reduces to the VF used by Lipman et al. (2023), $u_t(x|x_1)=\frac{x_1-x}{1-t}$, and the RCFM objective takes the form

$$\mathcal{L}_{\text{RCFM}}(\theta) = \mathbb{E}_{t,q(x_1),p(x_0)}\big\|v_t(x_t) + x_0 - x_1\big\|_2^2,$$

which coincides with the Euclidean case of Flow Matching presented in prior works (Lipman et al., 2023; Liu et al., 2023).
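The Euclidean objective follows from equation 16 with the Euclidean exponential and logarithm maps; a short worked step shows that the conditional target along $x_t$ does not depend on $t$:

```latex
\begin{aligned}
\exp_x(v) = x + v,\quad \log_x(y) = y - x
&\;\Longrightarrow\; x_t = x_1 + (1-t)(x_0 - x_1) = (1-t)\,x_0 + t\,x_1, \\
u_t(x_t|x_1) = \frac{x_1 - x_t}{1-t} &= x_1 - x_0 .
\end{aligned}
```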

3.3 Spectral Distances on General Geometries

Geodesics can be difficult to compute efficiently on general geometries, especially since they need to be computed for any possible pair of points. Hence, we propose using premetrics that can be computed quickly for any pair of points on $\mathcal{M}$, contingent on a one-time upfront cost. In particular, for general Riemannian manifolds, we consider approximate spectral distances as an alternative to the geodesic distance. Spectral distances offer several benefits over the geodesic distance, such as robustness to topological noise, smoothness, and global geometry awareness (Lipman et al., 2010). Note, however, that spectral distances do not define minimizing (geodesic) paths, and computing the conditional flows $x_t$ will require simulating $u_t(x|x_1)$.

Let $\varphi_i:\mathcal{M}\rightarrow\mathbb{R}$ be the eigenfunctions of the Laplace-Beltrami operator $\Delta_g$ over $\mathcal{M}$ with corresponding eigenvalues $\lambda_i$, i.e., $\Delta_g\varphi_i=\lambda_i\varphi_i$ for $i=1,2,\dots$. Spectral distances are then of the form

$$\mathrm{d}_w(x,y)^2 = \sum_{i=1}^{\infty} w(\lambda_i)\big(\varphi_i(x)-\varphi_i(y)\big)^2, \tag{17}$$

where $w:\mathbb{R}\rightarrow\mathbb{R}_+$ is some monotonically decreasing weighting function. Popular instances of spectral distances include:

  1. Diffusion distance (Coifman & Lafon, 2006): $w(\lambda)=\exp(-2\tau\lambda)$, with a parameter $\tau$.

  2. Biharmonic distance (Lipman et al., 2010): $w(\lambda)=\lambda^{-2}$.
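The unit circle is the simplest manifold on which equation 17 has closed-form eigenfunctions, so the sketch below evaluates truncated diffusion and biharmonic distances there. The circle, the sign convention $\lambda_i\ge 0$, and `k_max = 10` are illustrative choices, not the mesh setup used in the paper.

```python
import math


def diffusion_weight(lam, tau=1.0):
    return math.exp(-2.0 * tau * lam)


def biharmonic_weight(lam):
    return lam ** -2


def circle_spectral_dist(theta_x, theta_y, weight, k_max=10):
    """Truncated spectral distance (equation 17) on the unit circle.

    Nonconstant Laplace-Beltrami eigenfunctions on the circle are
    cos(k theta)/sqrt(pi) and sin(k theta)/sqrt(pi), both with eigenvalue k^2
    (convention lambda >= 0). The constant eigenfunction contributes nothing,
    since phi(x) - phi(y) = 0.
    """
    c = 1.0 / math.sqrt(math.pi)
    total = 0.0
    for k in range(1, k_max + 1):
        w = weight(float(k * k))
        for phi in (math.cos, math.sin):
            total += w * (c * phi(k * theta_x) - c * phi(k * theta_y)) ** 2
    return math.sqrt(total)


for delta in (math.pi / 4, math.pi / 2, math.pi):
    print(round(delta, 3),
          circle_spectral_dist(0.0, delta, biharmonic_weight),
          circle_spectral_dist(0.0, delta, lambda lam: diffusion_weight(lam, tau=0.25)))
```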

In practice, we truncate the infinite series in equation 17 to the terms with the $k$ smallest eigenvalues. These $k$ eigenfunctions can be computed numerically as a one-time preprocessing step prior to training. Furthermore, we note that an approximation of the spectral distance with finite $k$ is sufficient for satisfying the properties of the premetric, leading to no bias in the training procedure. Lastly, we consider manifolds with boundaries and show that solving for the eigenfunctions with the natural, or Neumann, boundary conditions ensures that the resulting $u_t(x|x_1)$ does not leave the interior of the manifold. Detailed discussions of these points can be found in Appendix G. Figure 3 visualizes contour plots of these spectral distances for manifolds with non-trivial curvature.
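Because spectral distances do not yield geodesics, $x_t$ must be obtained by integrating $u_t(x|x_1)$. A minimal sketch, again on the unit circle with a truncated biharmonic distance (an illustrative assumption), integrates equation 13 with RK4 and checks numerically that $\mathrm{d}(x_t,x_1)=(1-t)\,\mathrm{d}(x_0,x_1)$ holds with a finite number of eigenfunctions. The example starts away from the antipodal point $\theta_1+\pi$, where, by symmetry, the derivative of this distance on the circle vanishes.

```python
import math

K = 10  # retained frequencies, i.e. 2K nonconstant eigenfunctions (illustrative choice)


def biharmonic_circle(delta):
    """Truncated biharmonic distance on the unit circle and its derivative in theta.

    delta = theta - theta1. With eigenfunctions cos(k theta)/sqrt(pi),
    sin(k theta)/sqrt(pi) and w(lambda) = lambda^-2 = k^-4, equation 17 reduces to
    d^2 = (2 / pi) * sum_k k^-4 * (1 - cos(k delta)).
    """
    d2 = (2.0 / math.pi) * sum((1.0 - math.cos(k * delta)) / k**4 for k in range(1, K + 1))
    d = math.sqrt(d2)
    dd = sum(math.sin(k * delta) / k**3 for k in range(1, K + 1)) / (math.pi * d)
    return d, dd


def u_cond(theta, theta1, t):
    """Equation 13 with kappa(t) = 1 - t; in the arc-length coordinate g = 1."""
    d, dd = biharmonic_circle(theta - theta1)
    return (-1.0 / (1.0 - t)) * d * dd / dd**2


def simulate(theta0, theta1, t_end, n_steps=2000):
    """RK4 simulation of the conditional flow from t = 0 to t_end."""
    h = t_end / n_steps
    theta, t = theta0, 0.0
    for _ in range(n_steps):
        k1 = u_cond(theta, theta1, t)
        k2 = u_cond(theta + 0.5 * h * k1, theta1, t + 0.5 * h)
        k3 = u_cond(theta + 0.5 * h * k2, theta1, t + 0.5 * h)
        k4 = u_cond(theta + h * k3, theta1, t + h)
        theta += h * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
        t += h
    return theta


theta0, theta1 = 2.0, 0.0  # start away from the antipode theta1 + pi
theta_T = simulate(theta0, theta1, 0.9)
ratio = biharmonic_circle(theta_T - theta1)[0] / biharmonic_circle(theta0 - theta1)[0]
print(theta_T, ratio)  # the ratio should be close to kappa(0.9) = 0.1
```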

Figure 3 (panels: Geodesic; Biharmonic; Diffusion $\tau=1$; Diffusion $\tau=\tfrac{1}{4}$; Diffusion $\tau=\tfrac{1}{10}$): Contour plots of geodesic and spectral distances (to a source point) on general manifolds. Geodesics are expensive to compute online and are globally non-smooth. The biharmonic distance behaves smoothly, while the diffusion distance requires careful tuning of the hyperparameter $\tau$.

4 Related Work

Table 2: Test NLL on Earth and climate science datasets (lower is better). Standard deviation estimated over 5 runs.

| Method | Volcano | Earthquake | Flood | Fire |
|---|---|---|---|---|
| Dataset size (train + val + test) | 827 | 6120 | 4875 | 12809 |
| *CNF-based* | | | | |
| Riemannian CNF (Mathieu & Nickel, 2020) | -6.05 ± 0.61 | 0.14 ± 0.23 | 1.11 ± 0.19 | -0.80 ± 0.54 |
| Moser Flow (Rozen et al., 2021) | -4.21 ± 0.17 | -0.16 ± 0.06 | 0.57 ± 0.10 | -1.28 ± 0.05 |
| CNF Matching (Ben-Hamu et al., 2022) | -2.38 ± 0.17 | -0.38 ± 0.01 | 0.25 ± 0.02 | -1.40 ± 0.02 |
| Riemannian Score-Based (De Bortoli et al., 2022) | -4.92 ± 0.25 | -0.19 ± 0.07 | 0.48 ± 0.17 | -1.33 ± 0.06 |
| *ELBO-based* | | | | |
| Riemannian Diffusion Model (Huang et al., 2022) | -6.61 ± 0.96 | -0.40 ± 0.05 | 0.43 ± 0.07 | -1.38 ± 0.05 |
| *Ours* | | | | |
| Riemannian Flow Matching w/ Geodesic | -7.93 ± 1.67 | -0.28 ± 0.08 | 0.42 ± 0.05 | -1.86 ± 0.11 |

Deep generative models on Riemannian manifolds.

Some initial works suggested constructing normalizing flows that map between manifolds and Euclidean spaces of the same intrinsic dimension (Gemici et al., 2016; Rezende et al., 2020; Bose et al., 2020), often relying on the tangent space at some pre-specified origin. However, this approach is problematic when the manifold is not homeomorphic to Euclidean space, resulting in both theoretical and numerical issues. On the other hand, continuous-time models such as continuous normalizing flows bypass such topological constraints and flow directly on the manifold itself. To this end, a number of works have formulated continuous normalizing flows on simple manifolds (Mathieu & Nickel, 2020; Lou et al., 2020; Falorsi, 2020), but these rely on maximum likelihood for training, a costly simulation-based procedure. More recently, simulation-free training methods for continuous normalizing flows on manifolds have been proposed (Rozen et al., 2021; Ben-Hamu et al., 2022); however, these scale poorly to high dimensions and do not adapt to general geometries.

Riemannian diffusion models.

With the influx of diffusion models that allow efficient simulation-free training on Euclidean space (Ho et al., 2020; Song et al., 2020b), multiple works have attempted to adapt diffusion models to manifolds (Mathieu & Nickel, 2020; Huang et al., 2022). However, due to their reliance on stochastic differential equations (SDEs) and denoising score matching (Vincent, 2011), these approaches necessitate in-training simulation and approximations when applied to non-Euclidean manifolds.

First and foremost, they lose the simulation-free sampling of $x_t\sim p_t(x|x_1)$ that is offered in the Euclidean regime; this is because the manifold analog of the Ornstein–Uhlenbeck SDE does not have closed-form solutions. Hence, diffusion-based methods have to resort to simulated random walks as a noising process even on simple manifolds (De Bortoli et al., 2022; Huang et al., 2022).

Furthermore, even on simple manifolds, the conditional score function is not known analytically, so De Bortoli et al. (2022) proposed approximating the conditional score function with either an eigenfunction expansion or Varadhan's heat-kernel approximation. These approximations lead to biased gradients in the denoising score matching framework. We find that the heat-kernel approximations can potentially be extremely biased, even with hundreds of eigenfunctions (see Figure 11). In contrast, we show in Figure 11 that our framework can satisfy all premetric requirements even with a small number of eigenfunctions for the spectral distance approximation, hence guaranteeing that the optimal model distribution is the data distribution. See detailed discussions in Section G.1.

A way to bypass the conditional score function is to use implicit score matching (Hyvärinen & Dayan, 2005), which Huang et al. (2022) adopt for the manifold case, but this instead requires computing the divergence of large neural networks during training. Using the Hutchinson estimator (Hutchinson, 1989; Skilling, 1989; Grathwohl et al., 2019; Song et al., 2020a) for divergence estimation results in a more scalable algorithm, but the variance of the Hutchinson estimator scales poorly with dimension (Hutchinson, 1989) and is further exacerbated on non-Euclidean manifolds (Mathieu & Nickel, 2020).
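To illustrate the variance issue, the sketch below applies Hutchinson's estimator with Gaussian probes to the hypothetical vector field $f(x)=x$, whose divergence equals the dimension $n$. For Gaussian probes and a symmetric Jacobian $J$, the per-probe variance is $2\|J\|_F^2$, which here equals $2n$. With Rademacher probes this particular diagonal Jacobian would give zero variance, so the growth depends on both the probe distribution and the Jacobian; the example only illustrates the scaling and does not reproduce any cited analysis.

```python
import random


def hutchinson_samples(jvp, dim, n_probes, rng):
    """Per-probe Hutchinson samples z^T J z for div f = tr(J), with z ~ N(0, I).

    jvp(z) must return the Jacobian-vector product J z; each sample is an
    unbiased estimate of the divergence.
    """
    samples = []
    for _ in range(n_probes):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        jz = jvp(z)
        samples.append(sum(zi * ji for zi, ji in zip(z, jz)))
    return samples


def mean_and_var(xs):
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)


rng = random.Random(0)
for dim in (2, 16, 128):
    # f(x) = x has Jacobian I, so J z = z and the true divergence is dim.
    m, v = mean_and_var(hutchinson_samples(lambda z: z, dim, 2000, rng))
    print(dim, round(m, 2), round(v, 1))  # mean near dim, variance near 2 * dim
```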

Finally, the use of SDEs as a noising process requires carefully constructing suitable reverse-time processes that approximate either just the probability path (Anderson, 1982) or the actual sample trajectories (Li et al., 2020), whereas ODE solutions are generally well-defined in both forward and reverse directions (Murray & Miller, 2013).

In contrast to these methods, Riemannian Flow Matching is simulation-free on simple geometries, has exact conditional vector fields, and does not require divergence computation during training. These properties are summarized in Table 1, and a detailed comparison of algorithmic differences to diffusion-based approaches is presented in Appendix D. Lastly, for general Riemannian manifolds we show that the design of a relatively simple premetric is sufficient, allowing the use of general distance functions that don’t satisfy all axioms of a metric—such as approximate spectral distances with finite truncation—going beyond what is currently possible with existing Riemannian diffusion methods.

Table 3: Test NLL on protein datasets (lower is better). Standard deviation estimated over 5 runs.

| Method | General (2D) | Glycine (2D) | Proline (2D) | Pre-Pro (2D) | RNA (7D) |
|---|---|---|---|---|---|
| Dataset size (train + val + test) | 138208 | 13283 | 7634 | 6910 | 9478 |
| Mixture of Power Spherical (Huang et al., 2022) | 1.15 ± 0.002 | 2.08 ± 0.009 | 0.27 ± 0.008 | 1.34 ± 0.019 | 4.08 ± 0.368 |
| Riemannian Diffusion Model (Huang et al., 2022) | 1.04 ± 0.012 | 1.97 ± 0.012 | 0.12 ± 0.011 | 1.24 ± 0.004 | -3.70 ± 0.592 |
| Riemannian Flow Matching w/ Geodesic | 1.01 ± 0.025 | 1.90 ± 0.055 | 0.15 ± 0.027 | 1.18 ± 0.055 | -5.20 ± 0.067 |

Euclidean Flow Matching.

Riemannian Flow Matching is built on top of recent simulation-free methods that work with ODEs instead of SDEs, regressing directly onto generating vector fields instead of score functions (Lipman et al., 2023; Albergo & Vanden-Eijnden, 2023; Liu et al., 2023; Neklyudov et al., 2022), resulting in an arguably simpler approach to continuous-time generative modeling without the intricacies of dealing with stochastic differential equations. In particular, Lipman et al. (2023) shows that this approach encompasses and broadens the probability paths used by diffusion models while remaining simulation-free; Albergo & Vanden-Eijnden (2023) discusses an interpretation based on the use of interpolants (equivalent to our conditional flows $\psi_t(x|x_1)$, except that we also make explicit the construction of the marginal probability path $p_t(x)$ and vector field $u_t(x)$); Liu et al. (2023) shows that repeatedly fitting to a model's own samples leads to straighter trajectories; and Neklyudov et al. (2022) formulates an implicit objective when $u_t(x)$ is a gradient field.

5 Experiments

We consider data from earth and climate science, protein structures, high-dimensional tori, complicated synthetic distributions on general closed manifolds, and distributions on maze-shaped manifolds that require navigation across non-trivial boundaries. Details regarding the training setup are discussed in Appendix H. Due to space constraints, additional experiments on a hyperbolic manifold and on a manifold of matrices, both endowed with nontrivial Riemannian metrics, can be found in Appendix J. Table 5 provides all details of the simple manifolds and their geometries. Details regarding the more complex mesh manifolds can be found in the open source code, which we release for reproducibility (https://github.com/facebookresearch/riemannian-fm).

Earth and climate science datasets on the sphere.

We make use of the publicly sourced datasets (NOAA, 2020a; b; Brakenridge, 2017; EOSDIS, 2020) compiled by Mathieu & Nickel (2020). These data points lie on the 2-D sphere, a simple manifold with closed-form exponential and logarithm maps. We therefore use the geodesic distance and compute geodesics in closed form as in equation 16. Table 2 shows the results alongside prior methods. We achieve a sizable improvement over prior works on the volcano and fire datasets, which have highly concentrated regions that require high fidelity. Figure 8 shows the density of our trained models.

Figure 4 (rows: Eigenfunction, Model; columns: Bunny ($k$=10), Bunny ($k$=50), Bunny ($k$=100), Spot ($k$=10), Spot ($k$=50), Spot ($k$=100)): Visualization of (top) the eigenfunctions that were used to construct target distributions, and (bottom) the learned density & samples from trained models with the Biharmonic distance.

Protein datasets on the torus.

We make use of the preprocessed protein (Lovell et al., 2003) and RNA (Murray et al., 2003) datasets compiled by Huang et al. (2022). These datasets consist of torsion angles and can be represented on the 2D and 7D tori. We represent the data on a flat torus, which is isometric to the product of 1-D spheres used by prior works (Huang et al., 2022; De Bortoli et al., 2022); due to this isometry, the resulting densities are directly comparable. Results are displayed in Table 3, and we show learned densities of the protein datasets in Figure 9. Compared to Huang et al. (2022), we see a significant gain in performance, particularly on the higher-dimensional 7D torus, due to the higher complexity of the dataset.
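On a flat torus the exponential and logarithm maps act coordinate-wise, so equation 16 stays cheap in any dimension. A minimal sketch follows, assuming angles are stored in $[0,2\pi)^D$; this representation and the function names are assumptions, and the released code may represent the torus differently.

```python
import math

TWO_PI = 2.0 * math.pi


def wrap(a):
    """Map an angle difference to [-pi, pi)."""
    return (a + math.pi) % TWO_PI - math.pi


def torus_log(x, y):
    """log_x(y) on the flat torus: per-coordinate shortest signed difference."""
    return [wrap(yi - xi) for xi, yi in zip(x, y)]


def torus_exp(x, v):
    return [(xi + vi) % TWO_PI for xi, vi in zip(x, v)]


def torus_dist(x, y):
    return math.sqrt(sum(d * d for d in torus_log(x, y)))


def torus_xt(x0, x1, t):
    """Equation 16 with kappa(t) = 1 - t on the flat torus."""
    return torus_exp(x1, [(1.0 - t) * v for v in torus_log(x1, x0)])


# Both coordinates wrap around 0 = 2 pi, so the geodesic crosses the seam.
x0, x1 = [0.1, 6.0], [6.2, 0.3]
print(torus_dist(x0, x1), torus_xt(x0, x1, 0.4))
```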

Scaling to high dimensions.

We next consider the scalability of our method in the case of high-dimensional tori, following the exact setup in De Bortoli et al. (2022). We compare to Moser Flow (Rozen et al., 2021), which does not scale well into high dimensions, and Riemannian Score-based (De Bortoli et al., 2022) using implicit score matching (ISM).

Figure 5: Riemannian Flow Matching scales well to higher dimensions, as it is simulation-free and all quantities required for training are computed exactly on simple geometries such as tori. Log-likelihoods are in bits.

As shown in Table 1, the ISM objective avoids the need to approximate conditional score functions, but it requires stochastic divergence estimation, which introduces more variance at higher dimensions. In Figure 5 we plot log-likelihood values for these two baselines and for our method with the geodesic construction. Our method performs consistently, with no significant drop in performance at higher dimensions, since it does not rely on any approximations.
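Stochastic divergence estimators are typically built on Hutchinson's trace estimator (Hutchinson, 1989; Skilling, 1989), applied to the Jacobian of the network. The minimal sketch below uses a random matrix as a hypothetical stand-in for that Jacobian; it implements the estimator with Rademacher probes together with the exact variance of a single probe, $2\sum_{i\neq j}S_{ij}^2$ with $S=(A+A^\top)/2$. For a matrix with i.i.d. standard normal entries this variance equals $d(d-1)$ in expectation, which illustrates why such estimates become noisier as the dimension $d$ grows. It illustrates the general effect and is not a reproduction of the baseline's training code.

```python
import numpy as np

def hutchinson_trace(A, num_probes, rng):
    """Unbiased estimate of tr(A) from Rademacher probes z, using E[z^T A z] = tr(A)."""
    d = A.shape[0]
    z = rng.choice([-1.0, 1.0], size=(num_probes, d))
    return np.mean(np.einsum("ni,ij,nj->n", z, A, z))  # average of z^T A z over probes

def single_probe_variance(A):
    """Exact variance of z^T A z for one Rademacher probe: 2 * sum_{i != j} S_ij^2, S = (A + A^T) / 2."""
    S = 0.5 * (A + A.T)
    off_diag = S - np.diag(np.diag(S))
    return 2.0 * np.sum(off_diag ** 2)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for d in (2, 8, 32, 128):
        A = rng.standard_normal((d, d))  # hypothetical stand-in for a network Jacobian
        print(d, single_probe_variance(A))
```

Averaging more probes reduces the variance only by a factor of the probe count, whereas an exact target, as in the geodesic construction on tori, has no such estimator noise.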

Manifolds with non-trivial curvature.

We next experiment with general closed manifolds using spectral distances as described in Section 3.3. Specifically, we experiment on manifolds described by triangular meshes. For meshes, computing geodesic distances on the fly is too expensive for our use case, which requires hundreds of evaluations per training iteration. Fast approximations to the geodesic distance between two points cost $\mathcal{O}(n\log n)$ (Kimmel & Sethian, 1998), while exact geodesic distances require $\mathcal{O}(n^2\log n)$ (Surazhsky et al., 2005), where $n$ is the number of edges. In contrast, computing spectral distances costs $\mathcal{O}(k)$, where $k \ll n$ is the number of eigenfunctions used; i.e., after a one-time preprocessing step, the cost does not scale with the complexity of the manifold. As our manifolds, we use the Stanford Bunny (Turk & Levoy, 1994) and Spot the Cow (Crane et al., 2013). Similar to Rozen et al. (2021), we construct target distributions by computing the $k$-th eigenfunction, thresholding it, and then sampling proportionally to the thresholded eigenfunction. This is done on a high-resolution mesh so that the distribution is non-trivial on each triangle. Figure 4 visualizes the eigenfunctions, the learned densities, and samples from trained models that transport a uniform base distribution to the target. For the spectral distance we used $k=200$ eigenfunctions, which we found sufficient for our method to produce high-fidelity samples.
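The target-distribution construction can be sketched as follows. This is a minimal NumPy version under stated assumptions, not the authors' code: "thresholding" is taken to mean clamping negative eigenfunction values to zero, the density is treated as piecewise constant per triangle (area times the mean of its vertex values), and the function name is hypothetical. Triangles are drawn in proportion to that mass, and points are then drawn uniformly inside each triangle with the square-root barycentric trick.

```python
import numpy as np

def sample_from_vertex_function(V, F, f, n, rng):
    """Draw n points on a triangle mesh with density roughly proportional to max(f, 0).

    V: (num_vertices, 3) positions; F: (num_faces, 3) vertex indices; f: (num_vertices,) values,
    e.g. one eigenfunction of the mesh Laplacian evaluated at the vertices.
    """
    g = np.maximum(f, 0.0)  # hypothetical "thresholding": clamp negative values to zero
    A, B, C = V[F[:, 0]], V[F[:, 1]], V[F[:, 2]]
    areas = 0.5 * np.linalg.norm(np.cross(B - A, C - A), axis=1)
    mass = areas * g[F].mean(axis=1)  # piecewise-constant density: area x mean vertex value
    faces = rng.choice(len(F), size=n, p=mass / mass.sum())  # pick triangles by mass
    r1, r2 = rng.random(n), rng.random(n)
    s = np.sqrt(r1)  # square-root trick gives uniform points inside a triangle
    w_a, w_b, w_c = 1.0 - s, s * (1.0 - r2), s * r2  # barycentric weights
    return w_a[:, None] * A[faces] + w_b[:, None] * B[faces] + w_c[:, None] * C[faces]

if __name__ == "__main__":
    # Unit square split into two triangles; the vertex values favour the lower-right triangle.
    V = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
    F = np.array([[0, 1, 2], [0, 2, 3]])
    f = np.array([0.0, 1.0, 1.0, -5.0])
    X = sample_from_vertex_function(V, F, f, 1000, np.random.default_rng(0))
    print(X.shape)
```

On a high-resolution mesh the per-triangle approximation is fine-grained, which is one reason to build the distribution there.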

                          Stanford Bunny                        Spot the Cow
                          k=10        k=50        k=100         k=10        k=50        k=100
Riemannian CFM
   w/ Diffusion (τ=1/4)   1.16±0.02   1.48±0.01   1.53±0.01     0.87±0.07   0.95±0.16   1.08±0.05
   w/ Biharmonic          1.06±0.05   1.55±0.01   1.49±0.01     1.02±0.06   1.08±0.05   1.29±0.05
Table 4: Test NLL on mesh datasets (k is the index of the eigenfunction used to construct the target distribution).

In Table 4, we report the test NLL of models trained using either the diffusion distance or the biharmonic distance. The diffusion distance required careful tuning of its hyperparameter $\tau$, whereas the biharmonic distance worked out of the box and has better smoothness properties (see Figure 3).
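The two spectral distances compared in Table 4 differ only in how each eigenpair is weighted. The sketch below assumes the standard forms $d_w(x,y)^2=\sum_i w(\lambda_i)\,(\varphi_i(x)-\varphi_i(y))^2$ with $w(\lambda)=\lambda^{-2}$ (biharmonic) and $w(\lambda)=e^{-2\tau\lambda}$ (diffusion), and uses the combinatorial Laplacian of a cycle graph as a hypothetical stand-in for a mesh Laplacian (a real mesh would use, e.g., a cotangent Laplacian with a mass matrix). The default `tau=0.25` echoes the $\tau=1/4$ of Table 4 only nominally, since its effect depends on how the Laplacian is normalized; all names are hypothetical.

```python
import numpy as np

def cycle_graph_laplacian(n):
    """Combinatorial Laplacian L = D - A of an n-vertex cycle graph (stand-in for a mesh Laplacian)."""
    idx = np.arange(n)
    adj = np.zeros((n, n))
    adj[idx, (idx + 1) % n] = 1.0
    adj[(idx + 1) % n, idx] = 1.0
    return np.diag(adj.sum(axis=1)) - adj

def spectral_eigenpairs(L, k):
    """One-time preprocessing: the k smallest non-zero eigenpairs (graph assumed connected)."""
    evals, evecs = np.linalg.eigh(L)  # eigenvalues in ascending order
    return evals[1:k + 1], evecs[:, 1:k + 1]  # drop the constant eigenvector (eigenvalue 0)

def biharmonic_weight(lam):
    return 1.0 / lam ** 2

def diffusion_weight(lam, tau=0.25):
    return np.exp(-2.0 * tau * lam)

def spectral_distance(phi_x, phi_y, evals, weight):
    """d_w(x, y) = sqrt(sum_i w(lambda_i) * (phi_i(x) - phi_i(y))^2); O(k) work per query."""
    diff = phi_x - phi_y
    return np.sqrt(np.sum(weight(evals) * diff ** 2))

if __name__ == "__main__":
    L = cycle_graph_laplacian(12)
    evals, evecs = spectral_eigenpairs(L, k=11)  # row i of evecs: all eigenfunction values at vertex i
    print(spectral_distance(evecs[0], evecs[3], evals, biharmonic_weight))
    print(spectral_distance(evecs[0], evecs[3], evals, diffusion_weight))
```

With all non-zero eigenpairs, the biharmonic distance equals $\sqrt{(e_x-e_y)^\top (L^{+})^2 (e_x-e_y)}$; truncating to the first $k$ eigenpairs is what keeps each query at $\mathcal{O}(k)$, while the diffusion weight additionally depends on $\tau$, which is the hyperparameter that needed tuning.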

Figure 6: (a, c) Source (cyan) and target (yellow) distributions on a manifold with non-trivial boundaries. (b, d) Sample trajectories from a CNF model trained through RCFM with the Biharmonic distance.

Manifolds with boundaries.

Lastly, we experiment with manifolds that have boundaries. Specifically, we consider randomly generated mazes, visualized in Figure 6. We set the base distribution to be a Gaussian in the middle of the maze and the target distribution to be a mixture of densities at the corners of the maze. The mazes are represented as triangular meshes, and we use the biharmonic distance with k=30 eigenfunctions. Once trained, the model represents a single vector field that transports all mass from the source distribution to the target distribution without crossing paths. We plot sample trajectories in Figure 6 (b) and (d), which show that the learned vector field avoids the boundaries of the manifold and successfully navigates to the different modes of the target distribution.

6 Conclusion

We propose Riemannian Flow Matching as a highly scalable approach for training continuous normalizing flows on manifolds. On simple geometries with closed-form geodesics, our method is completely simulation-free and introduces zero approximation error. We also introduce benchmark problems for general manifolds and showcase, for the first time, tractable training on general geometries, including both closed manifolds and manifolds with boundaries.

Acknowledgements

Ricky T. Q. Chen would like to thank Chin-Wei Huang for helpful discussions. Additionally, we acknowledge the Python community (Van Rossum & Drake Jr, 1995; Oliphant, 2007) for developing the core set of tools that enabled this work, including PyTorch (Paszke et al., 2019), PyTorch Lightning (Falcon & team, 2019), Hydra (Yadan, 2019), Jupyter (Kluyver et al., 2016), Matplotlib (Hunter, 2007), seaborn (Waskom et al., 2018), NumPy (Oliphant, 2006; Van Der Walt et al., 2011), SciPy (Jones et al., 2014), pandas (McKinney, 2012), geopandas (Jordahl et al., 2020), torchdiffeq (Chen, 2018), libigl (Panozzo & Jacobson, 2014), and PyEVTK (Herrera, 2019).

References

  • Albergo & Vanden-Eijnden (2023) Michael S Albergo and Eric Vanden-Eijnden. Building normalizing flows with stochastic interpolants. International Conference on Learning Representations, 2023.
  • Anderson (1982) Brian DO Anderson. Reverse-time diffusion equation models. Stochastic Processes and their Applications, 12(3):313–326, 1982.
  • Barachant et al. (2013) Alexandre Barachant, Stéphane Bonnet, Marco Congedo, and Christian Jutten. Classification of covariance matrices using a Riemannian-based kernel for BCI applications. Neurocomputing, 112:172–178, 2013.
  • Belkin & Niyogi (2003) Mikhail Belkin and Partha Niyogi. Laplacian eigenmaps for dimensionality reduction and data representation. Neural Computation, 15(6):1373–1396, 2003.
  • Ben-Hamu et al. (2022) Heli Ben-Hamu, Samuel Cohen, Joey Bose, Brandon Amos, Aditya Grover, Maximilian Nickel, Ricky T. Q. Chen, and Yaron Lipman. Matching normalizing flows and probability paths on manifolds. International Conference on Machine Learning, 2022.
  • Blankertz et al. (2007) Benjamin Blankertz, Guido Dornhege, Matthias Krauledat, Klaus-Robert Müller, and Gabriel Curio. The non-invasive Berlin brain–computer interface: fast acquisition of effective performance in untrained subjects. NeuroImage, 37(2):539–550, 2007.
  • Bose et al. (2020) Joey Bose, Ariella Smofsky, Renjie Liao, Prakash Panangaden, and Will Hamilton. Latent variable modelling with hyperbolic normalizing flows. In Hal Daumé III and Aarti Singh (eds.), Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pp.  1045–1055. PMLR, 13–18 Jul 2020.
  • Brakenridge (2017) G.R. Brakenridge. Global active archive of large flood events. http://floodobservatory.colorado.edu/Archives/index.html, 2017. Dartmouth Flood Observatory, University of Colorado.
  • Brunner et al. (2008) Clemens Brunner, Robert Leeb, Gernot Müller-Putz, Alois Schlögl, and Gert Pfurtscheller. BCI competition 2008–Graz data set A. Institute for Knowledge Discovery (Laboratory of Brain-Computer Interfaces), Graz University of Technology, 16:1–6, 2008.
  • Chen (2018) Ricky T. Q. Chen. torchdiffeq, 2018. URL https://github.com/rtqichen/torchdiffeq.
  • Chen et al. (2018) Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differential equations. Advances in Neural Information Processing Systems, 31, 2018.
  • Coifman & Lafon (2006) Ronald R Coifman and Stéphane Lafon. Diffusion maps. Applied and Computational Harmonic Analysis, 21(1):5–30, 2006.
  • Crane et al. (2013) Keenan Crane, Ulrich Pinkall, and Peter Schröder. Robust fairing via conformal curvature flow. ACM Transactions on Graphics (TOG), 32(4):1–10, 2013.
  • De Bortoli et al. (2022) Valentin De Bortoli, Emile Mathieu, Michael Hutchinson, James Thornton, Yee Whye Teh, and Arnaud Doucet. Riemannian score-based generative modeling. Advances in Neural Information Processing Systems, 2022.
  • Deng et al. (2022) Zhijie Deng, Jiaxin Shi, Hao Zhang, Peng Cui, Cewu Lu, and Jun Zhu. Neural eigenfunctions are structured representation learners. arXiv preprint arXiv:2210.12637, 2022.
  • EOSDIS (2020) EOSDIS. Active fire data. https://earthdata.nasa.gov/earth-observation-data/near-real-time/firms/active-fire-data, 2020. Land, Atmosphere Near real-time Capability for EOS (LANCE) system operated by NASA’s Earth Science Data and Information System (ESDIS).
  • Falcon & team (2019) William Falcon and The PyTorch Lightning team. PyTorch Lightning, 2019. URL https://github.com/Lightning-AI/lightning.
  • Falorsi (2020) Luca Falorsi. Continuous normalizing flows on manifolds. PhD thesis, University of Amsterdam, 2020.
  • Gallot et al. (1990) Sylvestre Gallot, Dominique Hulin, and Jacques Lafontaine. Riemannian geometry, volume 2. Springer, 1990.
  • Gemici et al. (2016) Mevlana C Gemici, Danilo Rezende, and Shakir Mohamed. Normalizing flows on Riemannian manifolds. arXiv preprint arXiv:1611.02304, 2016.
  • Grathwohl et al. (2019) Will Grathwohl, Ricky T. Q. Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. FFJORD: Free-form continuous dynamics for scalable reversible generative models. International Conference on Learning Representations, 2019.
  • Hairer (2011) Ernst Hairer. Solving differential equations on manifolds. Lecture Notes, Université de Genève, 2011.
  • Hairer et al. (2006) Ernst Hairer, Marlis Hochbruck, Arieh Iserles, and Christian Lubich. Geometric numerical integration. Oberwolfach Reports, 3(1):805–882, 2006.
  • Herrera (2019) Paulo Herrera. Pyevtk, 2019. URL https://github.com/paulo-herrera/PyEVTK.
  • Ho et al. (2020) Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851, 2020.
  • Huang et al. (2022) Chin-Wei Huang, Milad Aghajohari, Avishek Joey Bose, Prakash Panangaden, and Aaron Courville. Riemannian diffusion models. Advances in Neural Information Processing Systems, 2022.
  • Hunter (2007) John D Hunter. Matplotlib: A 2D graphics environment. Computing in Science & Engineering, 9(3):90, 2007.
  • Hutchinson (1989) M.F. Hutchinson. A stochastic estimator of the trace of the influence matrix for Laplacian smoothing splines. 18:1059–1076, 1989.
  • Hyvärinen & Dayan (2005) Aapo Hyvärinen and Peter Dayan. Estimation of non-normalized statistical models by score matching. Journal of Machine Learning Research, 6(4), 2005.
  • Jones et al. (2014) Eric Jones, Travis Oliphant, and Pearu Peterson. SciPy: Open source scientific tools for Python. 2014.
  • Jones et al. (2008) Peter W Jones, Mauro Maggioni, and Raanan Schul. Manifold parametrizations by eigenfunctions of the Laplacian and heat kernels. Proceedings of the National Academy of Sciences, 105(6):1803–1808, 2008.
  • Jordahl et al. (2020) Kelsey Jordahl, Joris Van den Bossche, Martin Fleischmann, Jacob Wasserman, James McBride, Jeffrey Gerard, Jeff Tratner, Matthew Perry, Adrian Garcia Badaracco, Carson Farmer, Geir Arne Hjelle, Alan D. Snow, Micah Cochran, Sean Gillies, Lucas Culbertson, Matt Bartos, Nick Eubank, maxalbert, Aleksey Bilogur, Sergio Rey, Christopher Ren, Dani Arribas-Bel, Leah Wasser, Levi John Wolf, Martin Journois, Joshua Wilson, Adam Greenhall, Chris Holdgraf, Filipe, and François Leblanc. geopandas/geopandas: v0.8.1, 2020.
  • Kimmel & Sethian (1998) Ron Kimmel and James A Sethian. Computing geodesic paths on manifolds. Proceedings of the National Academy of Sciences, 95(15):8431–8435, 1998.
  • Kloeden et al. (2002) Peter Eris Kloeden, Eckhard Platen, and Henri Schurz. Springer Science & Business Media, 2002.
  • Kluyver et al. (2016) Thomas Kluyver, Benjamin Ragan-Kelley, Fernando Pérez, Brian E Granger, Matthias Bussonnier, Jonathan Frederic, Kyle Kelley, Jessica B Hamrick, Jason Grout, Sylvain Corlay, et al. Jupyter Notebooks – a publishing format for reproducible computational workflows. In ELPUB, pp.  87–90, 2016.
  • Leeb et al. (2008) R Leeb, C Brunner, G Müller-Putz, A Schlögl, and G Pfurtscheller. BCI competition 2008–Graz data set B. Graz University of Technology, Austria, pp.  1–6, 2008.
  • Li et al. (2020) Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, and David Duvenaud. Scalable gradients for stochastic differential equations. In International Conference on Artificial Intelligence and Statistics, pp.  3870–3882. PMLR, 2020.
  • Lipman et al. (2010) Yaron Lipman, Raif M Rustamov, and Thomas A Funkhouser. Biharmonic distance. ACM Transactions on Graphics (TOG), 29(3):1–11, 2010.
  • Lipman et al. (2023) Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, and Matt Le. Flow matching for generative modeling. International Conference on Learning Representations, 2023.
  • Liu et al. (2023) Xingchao Liu, Chengyue Gong, and Qiang Liu. Flow straight and fast: Learning to generate and transfer data with rectified flow. International Conference on Learning Representations, 2023.
  • Lou et al. (2020) Aaron Lou, Derek Lim, Isay Katsman, Leo Huang, Qingxuan Jiang, Ser Nam Lim, and Christopher M De Sa. Neural manifold ordinary differential equations. Advances in Neural Information Processing Systems, 33:17548–17558, 2020.
  • Lovell et al. (2003) Simon C Lovell, Ian W Davis, W Bryan Arendall III, Paul IW De Bakker, J Michael Word, Michael G Prisant, Jane S Richardson, and David C Richardson. Structure validation by Cα geometry: φ, ψ and Cβ deviation. Proteins: Structure, Function, and Bioinformatics, 50(3):437–450, 2003.
  • Mathieu & Nickel (2020) Emile Mathieu and Maximilian Nickel. Riemannian continuous normalizing flows. Advances in Neural Information Processing Systems, 33:2503–2515, 2020.
  • McCann (2001) Robert J McCann. Polar factorization of maps on Riemannian manifolds. Geometric & Functional Analysis GAFA, 11(3):589–608, 2001.
  • McKinney (2012) Wes McKinney. Python for data analysis: Data wrangling with Pandas, NumPy, and IPython. O’Reilly Media, Inc., 2012.
  • Moakher & Batchelor (2006) Maher Moakher and Philipp G Batchelor. Symmetric positive-definite matrices: From geometry to applications and visualization. Visualization and processing of tensor fields, pp.  285–298, 2006.
  • Murray & Miller (2013) Francis J Murray and Kenneth S Miller. Existence theorems for ordinary differential equations. Courier Corporation, 2013.
  • Murray et al. (2003) Laura JW Murray, W Bryan Arendall III, David C Richardson, and Jane S Richardson. RNA backbone is rotameric. Proceedings of the National Academy of Sciences, 100(24):13904–13909, 2003.
  • Neklyudov et al. (2022) Kirill Neklyudov, Daniel Severo, and Alireza Makhzani. Action matching: A variational method for learning stochastic dynamics from samples. arXiv preprint arXiv:2210.06662, 2022.
  • NOAA (2020a) NOAA. Global significant earthquake database. https://data.nodc.noaa.gov/cgi-bin/iso?id=gov.noaa.ngdc.mgg.hazards:G012153, 2020a. National Geophysical Data Center / World Data Service (NGDC/WDS): NCEI/WDS Global Significant Earthquake Database. NOAA National Centers for Environmental Information.
  • NOAA (2020b) NOAA. Global significant volcanic eruptions database. https://data.nodc.noaa.gov/cgi-bin/iso?id=gov.noaa.ngdc.mgg.hazards:G10147, 2020b. National Geophysical Data Center / World Data Service (NGDC/WDS): NCEI/WDS Global Significant Volcanic Eruptions Database. NOAA National Centers for Environmental Information.
  • Oliphant (2006) Travis E Oliphant. A guide to NumPy, volume 1. Trelgol Publishing USA, 2006.
  • Oliphant (2007) Travis E Oliphant. Python for scientific computing. Computing in Science & Engineering, 9(3):10–20, 2007.
  • Panozzo & Jacobson (2014) Daniele Panozzo and Alec Jacobson. Libigl: A C++ library for geometry processing without a mesh data structure. 2014.
  • Paszke et al. (2019) Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, et al. PyTorch: An imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems, pp.  8026–8037, 2019.
  • Pfau et al. (2018) David Pfau, Stig Petersen, Ashish Agarwal, David GT Barrett, and Kimberly L Stachenfeld. Spectral inference networks: Unifying deep and spectral learning. arXiv preprint arXiv:1806.02215, 2018.
  • Polyak & Juditsky (1992) Boris T. Polyak and Anatoli Juditsky. Acceleration of stochastic approximation by averaging. 1992.
  • Ramachandran et al. (2017) Prajit Ramachandran, Barret Zoph, and Quoc V Le. Searching for activation functions. arXiv preprint arXiv:1710.05941, 2017.
  • Rezende et al. (2020) Danilo Jimenez Rezende, George Papamakarios, Sébastien Racaniere, Michael Albergo, Gurtej Kanwar, Phiala Shanahan, and Kyle Cranmer. Normalizing flows on tori and spheres. In International Conference on Machine Learning, pp.  8083–8092. PMLR, 2020.
  • Rozen et al. (2021) Noam Rozen, Aditya Grover, Maximilian Nickel, and Yaron Lipman. Moser flow: Divergence-based generative modeling on manifolds. Advances in Neural Information Processing Systems, 34:17669–17680, 2021.
  • Skilling (1989) John Skilling. The eigenvalues of mega-dimensional matrices. In Maximum Entropy and Bayesian Methods, pp.  455–466. Springer, 1989.
  • Song et al. (2020a) Yang Song, Sahaj Garg, Jiaxin Shi, and Stefano Ermon. Sliced score matching: A scalable approach to density and score estimation. In Uncertainty in Artificial Intelligence, pp.  574–584. PMLR, 2020a.
  • Song et al. (2020b) Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456, 2020b.
  • Surazhsky et al. (2005) Vitaly Surazhsky, Tatiana Surazhsky, Danil Kirsanov, Steven J Gortler, and Hugues Hoppe. Fast exact and approximate geodesics on meshes. ACM Transactions on Graphics (TOG), 24(3):553–560, 2005.
  • Turk & Levoy (1994) Greg Turk and Marc Levoy. Zippered polygon meshes from range images. In Proceedings of the 21st annual conference on Computer graphics and interactive techniques, pp.  311–318, 1994.
  • Van Der Walt et al. (2011) Stefan Van Der Walt, S Chris Colbert, and Gael Varoquaux. The NumPy array: a structure for efficient numerical computation. Computing in Science & Engineering, 13(2):22, 2011.
  • Van Rossum & Drake Jr (1995) Guido Van Rossum and Fred L Drake Jr. Python reference manual. Centrum voor Wiskunde en Informatica Amsterdam, 1995.
  • Villani (2009) Cédric Villani. Optimal transport: old and new, volume 338. Springer, 2009.
  • Vincent (2011) Pascal Vincent. A connection between score matching and denoising autoencoders. Neural Computation, 23(7):1661–1674, 2011.
  • Waskom et al. (2018) Michael Waskom, Olga Botvinnik, Drew O’Kane, Paul Hobson, Joel Ostblom, Saulius Lukauskas, David C Gemperline, Tom Augspurger, Yaroslav Halchenko, John B. Cole, Jordi Warmenhoven, Julian de Ruiter, Cameron Pye, Stephan Hoyer, Jake Vanderplas, Santi Villalba, Gero Kunter, Eric Quintero, Pete Bachant, Marcel Martin, Kyle Meyer, Alistair Miles, Yoav Ram, Thomas Brunner, Tal Yarkoni, Mike Lee Williams, Constantine Evans, Clark Fitzgerald, Brian, and Adel Qalieh. mwaskom/seaborn: v0.9.0 (july 2018), July 2018. URL https://doi.org/10.5281/zenodo.1313201.
  • Yadan (2019) Omry Yadan. Hydra - a framework for elegantly configuring complex applications. Github, 2019. URL https://github.com/facebookresearch/hydra.

Appendix A Conditional Flow Matching on manifolds

We provide the necessary derivations and proofs for Conditional Flow Matching on Riemannian manifolds; the proofs and derivations of Lipman et al. (2023) are followed "as-is", with the adaptations necessary for the Riemannian setting.

Assumptions. We use the notation and setup from Section 2. Let $p(\cdot|x_1):[0,1]\rightarrow\mathcal{P}$ be a (conditional) probability path that is sufficiently smooth with integrable derivatives, strictly positive, $p_t(x|x_1)>0$, and satisfies $p_0(x|x_1)=p$, where $p\in\mathcal{P}$ is our source density. Let $u(\cdot|x_1)\in\mathfrak{U}$ be a (conditional) time-dependent vector field, sufficiently smooth with integrable derivatives and such that

$$\int_0^1\int_{\mathcal{M}}\left\|u_t(x|x_1)\right\|_g\,p_t(x|x_1)\,d\mathrm{vol}_x\,dt<\infty.$$

Further assume that $u_t(x|x_1)$ generates $p_t(x|x_1)$ from $p$ in the sense of equation 2, i.e., if we denote by $\psi_t(x|x_1)$ the solution to the ODE (equation 1):

\begin{align}
\frac{d}{dt}\psi_t(x|x_1) &= u_t\left(\psi_t(x|x_1)\,\middle|\,x_1\right) \tag{18}\\
\psi_0(x|x_1) &= x \tag{19}
\end{align}

then

$$p_t(\cdot|x_1)=\left[\psi_t(\cdot|x_1)\right]_{\#}\,p. \tag{20}$$

Proof of the marginal VF formula, equation 7. First, the Mass Conservation Formula Theorem (see, e.g., Villani (2009)) implies that $p_t(x|x_1)$ and $u_t(x|x_1)$ satisfy

$$\frac{d}{dt}p_t(x|x_1)+\mathrm{div}_g\left(p_t(x|x_1)\,u_t(x|x_1)\right)=0 \tag{21}$$

where $\mathrm{div}_g$ is the Riemannian divergence with respect to the metric $g$.

Next, we differentiate the marginal $p_t(x)$ with respect to $t$:

\begin{align*}
\frac{d}{dt}p_t(x) &= \int_{\mathcal{M}}\frac{d}{dt}p_t(x|x_1)\,q(x_1)\,d\mathrm{vol}_{x_1}\\
&= -\mathrm{div}_g\left[\int_{\mathcal{M}}u_t(x|x_1)\,p_t(x|x_1)\,q(x_1)\,d\mathrm{vol}_{x_1}\right]\\
&= -\mathrm{div}_g\left[p_t(x)\int_{\mathcal{M}}u_t(x|x_1)\,\frac{p_t(x|x_1)\,q(x_1)}{p_t(x)}\,d\mathrm{vol}_{x_1}\right]\\
&= -\mathrm{div}_g\left[p_t(x)\,u_t(x)\right]
\end{align*}

where in the first and second equalities we exchanged the order of differentiation and integration, and in the second equality we also used the mass conservation formula (equation 21) for $u_t(x|x_1)$. In the second-to-last equality we multiplied and divided by $p_t(x)$, and in the last equality we used the definition of the marginal vector field $u_t$ in equation 7.
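As a sanity check of equation 7 in the simplest setting, the sketch below takes $\mathcal{M}=\mathbb{R}$ with the Euclidean metric (so $\mathrm{div}_g$ is an ordinary derivative), a discrete $q(x_1)$, and, as an assumption, the Gaussian conditional path $p_t(x|x_1)=\mathcal{N}(x;\,t x_1,\sigma_t^2)$ with $\sigma_t=1-(1-\sigma_{\min})t$ and its generating field $u_t(x|x_1)=(x_1-(1-\sigma_{\min})x)/\sigma_t$, which starts from $p=\mathcal{N}(0,1)$ for every $x_1$. It forms the marginal field as the posterior-weighted average of equation 7 and checks by central finite differences that the marginal pair satisfies the continuity equation. The value of `SIGMA_MIN` and all names are hypothetical.

```python
import numpy as np

SIGMA_MIN = 0.1  # hypothetical choice of the standard deviation at t = 1

def gaussian(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def conditional_path(x, t, x1):
    """p_t(x|x1) = N(x; t*x1, sigma_t^2) and the field u_t(x|x1) that generates it from N(0, 1)."""
    sigma_t = 1.0 - (1.0 - SIGMA_MIN) * t
    return gaussian(x, t * x1, sigma_t), (x1 - (1.0 - SIGMA_MIN) * x) / sigma_t

def marginal(x, t, atoms, weights):
    """Marginal density p_t(x) and marginal field u_t(x) (equation 7) for a discrete q(x1)."""
    density, flux = 0.0, 0.0
    for x1, w in zip(atoms, weights):
        p_cond, u_cond = conditional_path(x, t, x1)
        density = density + w * p_cond
        flux = flux + w * p_cond * u_cond  # accumulates p_t(x) * u_t(x)
    return density, flux / density

def continuity_residual(x, t, atoms, weights, h=1e-4):
    """Central-difference estimate of d/dt p_t(x) + d/dx (p_t(x) u_t(x)); should be close to 0."""
    dp_dt = (marginal(x, t + h, atoms, weights)[0] - marginal(x, t - h, atoms, weights)[0]) / (2 * h)

    def flux(y):
        p, u = marginal(y, t, atoms, weights)
        return p * u

    dflux_dx = (flux(x + h) - flux(x - h)) / (2 * h)
    return dp_dt + dflux_dx

if __name__ == "__main__":
    xs = np.linspace(-3.0, 3.0, 13)
    res = continuity_residual(xs, 0.5, atoms=[-1.0, 0.5, 2.0], weights=[0.2, 0.5, 0.3])
    print(np.max(np.abs(res)))
```

The residual is limited only by finite-difference error; replacing the posterior weights $p_t(x|x_1)q(x_1)/p_t(x)$ with uniform weights would generally break the identity.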

RCFM loss equivalent to RFM loss. We now show the equivalence of the RCFM loss (equation 8) and the RFM loss (equation 3). First, note that the losses expand as follows:

\begin{align*}
\mathcal{L}_{\text{RFM}}(\theta) &= \mathbb{E}_{t,p_t(x)}\left\|v_t(x)-u_t(x)\right\|_g^2 = \mathbb{E}_{t,p_t(x)}\left[\left\|v_t(x)\right\|_g^2-2\left\langle v_t(x),u_t(x)\right\rangle_g+\left\|u_t(x)\right\|_g^2\right]\\
\mathcal{L}_{\text{RCFM}}(\theta) &= \mathbb{E}_{\substack{t,q(x_1)\\ p_t(x|x_1)}}\left\|v_t(x)-u_t(x|x_1)\right\|_g^2 = \mathbb{E}_{\substack{t,q(x_1)\\ p_t(x|x_1)}}\left[\left\|v_t(x)\right\|_g^2-2\left\langle v_t(x),u_t(x|x_1)\right\rangle_g+\left\|u_t(x|x_1)\right\|_g^2\right]
\end{align*}

Second, note that

$$\begin{aligned}
\mathbb{E}_{t,q(x_1),p_t(x|x_1)}\left\|v_t(x)\right\|_g^2 &= \int_0^1\int_{\mathcal{M}}\int_{\mathcal{M}}\left\|v_t(x)\right\|_g^2\,p_t(x|x_1)\,q(x_1)\,d\mathrm{vol}_x\,d\mathrm{vol}_{x_1}\,dt\\
&= \int_0^1\int_{\mathcal{M}}\left\|v_t(x)\right\|_g^2\,p_t(x)\,d\mathrm{vol}_x\,dt\\
&= \mathbb{E}_{t,p_t(x)}\left\|v_t(x)\right\|_g^2,
\end{aligned}$$

where the second equality integrates out $x_1$ using the marginal $p_t(x)=\int_{\mathcal{M}}p_t(x|x_1)\,q(x_1)\,d\mathrm{vol}_{x_1}$.

Lastly,

$$\begin{aligned}
\mathbb{E}_{t,q(x_1),p_t(x|x_1)}\left\langle v_t(x),u_t(x|x_1)\right\rangle_g &= \int_0^1\int_{\mathcal{M}}\int_{\mathcal{M}}\left\langle v_t(x),u_t(x|x_1)\right\rangle_g\,p_t(x|x_1)\,q(x_1)\,d\mathrm{vol}_x\,d\mathrm{vol}_{x_1}\,dt\\
&= \int_0^1\int_{\mathcal{M}}\left\langle v_t(x),\int_{\mathcal{M}}u_t(x|x_1)\,p_t(x|x_1)\,q(x_1)\,d\mathrm{vol}_{x_1}\right\rangle_g d\mathrm{vol}_x\,dt\\
&= \int_0^1\int_{\mathcal{M}}\left\langle v_t(x),\int_{\mathcal{M}}u_t(x|x_1)\frac{p_t(x|x_1)\,q(x_1)}{p_t(x)}\,d\mathrm{vol}_{x_1}\right\rangle_g p_t(x)\,d\mathrm{vol}_x\,dt\\
&= \int_0^1\int_{\mathcal{M}}\left\langle v_t(x),u_t(x)\right\rangle_g\,p_t(x)\,d\mathrm{vol}_x\,dt\\
&= \mathbb{E}_{t,p_t(x)}\left\langle v_t(x),u_t(x)\right\rangle_g,
\end{aligned}$$

where the second equality uses linearity of the inner product on $T_x\mathcal{M}$, the third multiplies and divides by $p_t(x)>0$, and the fourth uses the definition of the marginal vector field $u_t(x)=\int_{\mathcal{M}}u_t(x|x_1)\frac{p_t(x|x_1)\,q(x_1)}{p_t(x)}\,d\mathrm{vol}_{x_1}$.

Substituting these two identities into the expansion of $\mathcal{L}_{\text{RCFM}}(\theta)$, and expanding $\mathcal{L}_{\text{RFM}}(\theta)=\mathbb{E}_{t,p_t(x)}\left\|v_t(x)-u_t(x)\right\|_g^2$ in the same way, we find that the two losses differ only in their last terms, which do not involve $\theta$:

$$\mathcal{L}_{\text{RFM}}(\theta)-\mathcal{L}_{\text{RCFM}}(\theta)=\int_0^1\int_{\mathcal{M}}\left\|u_t(x)\right\|_g^2\,p_t(x)\,d\mathrm{vol}_x\,dt-\int_0^1\int_{\mathcal{M}}\int_{\mathcal{M}}\left\|u_t(x|x_1)\right\|_g^2\,p_t(x|x_1)\,q(x_1)\,d\mathrm{vol}_x\,d\mathrm{vol}_{x_1}\,dt,$$

a constant that does not depend on $\theta$. In particular, $\nabla_\theta\mathcal{L}_{\text{RFM}}(\theta)=\nabla_\theta\mathcal{L}_{\text{RCFM}}(\theta)$.
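The equivalence can be sanity-checked exactly in a toy discrete setting. The sketch below is our own illustration under stated assumptions, not code from the paper: time is fixed, $x$ and $x_1$ take finitely many values (so the integrals become sums), the metric is Euclidean ($g=I$), and the conditional densities and conditional vector fields are random. Helper names such as `toy_problem` and `marginals` are hypothetical. For several arbitrary "model" fields $v_t$, the gap $\mathcal{L}_{\text{RFM}}-\mathcal{L}_{\text{RCFM}}$ should come out as the same number.

```python
import random

def normalize(w):
    s = sum(w)
    return [wi / s for wi in w]

def sq_norm(a, b):
    # ||a - b||^2 with the Euclidean metric (g = I), an assumption of this toy
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def toy_problem(seed=0, n_x=5, n_x1=3, dim=2):
    rng = random.Random(seed)
    q = normalize([rng.random() + 0.1 for _ in range(n_x1)])      # q(x1)
    p_cond = [normalize([rng.random() + 0.1 for _ in range(n_x)])   # p_t(x | x1), one row per x1
              for _ in range(n_x1)]
    u_cond = [[[rng.gauss(0.0, 1.0) for _ in range(dim)]            # u_t(x | x1)
               for _ in range(n_x)] for _ in range(n_x1)]
    return q, p_cond, u_cond

def marginals(q, p_cond, u_cond):
    n_x1, n_x, dim = len(q), len(p_cond[0]), len(u_cond[0][0])
    # p_t(x) = sum_{x1} p_t(x | x1) q(x1)
    p = [sum(p_cond[j][i] * q[j] for j in range(n_x1)) for i in range(n_x)]
    # u_t(x) = sum_{x1} u_t(x | x1) p_t(x | x1) q(x1) / p_t(x)
    u = [[sum(u_cond[j][i][k] * p_cond[j][i] * q[j] for j in range(n_x1)) / p[i]
          for k in range(dim)] for i in range(n_x)]
    return p, u

def loss_rcfm(v, q, p_cond, u_cond):
    return sum(q[j] * p_cond[j][i] * sq_norm(v[i], u_cond[j][i])
               for j in range(len(q)) for i in range(len(v)))

def loss_rfm(v, p, u):
    return sum(p[i] * sq_norm(v[i], u[i]) for i in range(len(v)))

q, p_cond, u_cond = toy_problem()
p, u = marginals(q, p_cond, u_cond)

rng = random.Random(1)
gaps = []
for _ in range(3):
    # an arbitrary stand-in for the model v_t, one vector per state x
    v = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(len(p))]
    gaps.append(loss_rfm(v, p, u) - loss_rcfm(v, q, p_cond, u_cond))
print(gaps)  # three equal values, up to floating-point rounding
```

Because $u_t(x)$ is a convex combination of the $u_t(x|x_1)$, Jensen's inequality makes this constant non-positive in the toy, which the check below also confirms.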

Appendix B Proof of Theorem 3.1

Theorem 3.1.

The flow $\psi_t(x|x_1)$ defined by the vector field $u_t(x|x_1)$ in equation 13 satisfies equation 12, and therefore also equation 11. Conversely, out of all conditional flows $\tilde{\psi}_t(x|x_1)$ defined by vector fields $\tilde{u}_t(x|x_1)$ that satisfy equation 12, this $u_t(x|x_1)$ is of minimal norm.

Proof.

Let $x_t=\psi_t(x|x_1)$ be the flow defined by equation 13 in the sense of equation 1. Differentiating the time-dependent function $\mathrm{d}(x_t,x_1)$ with respect to time gives

$$\frac{d}{dt}\mathrm{d}(x_t,x_1)=\left\langle\nabla\mathrm{d}(x_t,x_1),\dot{x}_t\right\rangle_g=\left\langle\nabla\mathrm{d}(x_t,x_1),u_t(x_t|x_1)\right\rangle_g=\frac{d\log\kappa(t)}{dt}\,\mathrm{d}(x_t,x_1),\tag{22}$$

where the last equality follows by substituting equation 13, since $\left\langle\nabla\mathrm{d},\nabla\mathrm{d}\right\rangle_g/\left\|\nabla\mathrm{d}\right\|_g^2=1$.

This shows that the function $a(t)=\mathrm{d}(x_t,x_1)$ satisfies the ODE

$$\frac{d}{dt}a(t)=\frac{d\log\kappa(t)}{dt}\,a(t),$$

with the initial condition $a(0)=\mathrm{d}(x,x_1)$. The general solution of this linear ODE is $a(t)=c\,\kappa(t)$ for a constant $c$, as can be verified by substitution: $\frac{d}{dt}\big(c\,\kappa(t)\big)=c\,\frac{d\kappa(t)}{dt}=\frac{d\log\kappa(t)}{dt}\,c\,\kappa(t)$. Since $\kappa(0)=1$, the initial condition fixes the constant:

$$\mathrm{d}(x,x_1)=a(0)=c\,\kappa(0)=c.$$

Hence $a(t)=\mathrm{d}(x,x_1)\,\kappa(t)$, and by uniqueness of ODE solutions, equation 12 holds.
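The forward direction can also be observed numerically for a premetric that is not the geodesic distance. The sketch below rests on assumptions of our own choosing, for illustration only: $\mathcal{M}=\mathbb{R}^2$ with the Euclidean metric, the hypothetical premetric $\mathrm{d}(x,y)=\sqrt{(x-y)^{\mathsf T}A(x-y)}$ for a fixed symmetric positive-definite $A$, the scheduler $\kappa(t)=1-t$, and a fixed-step RK4 integrator. It integrates the conditional vector field of equation 13 and compares $\mathrm{d}(x_t,x_1)$ with the prediction $\kappa(t)\,\mathrm{d}(x_0,x_1)$ of equation 12; the trajectory itself is in general curved, since the field points along $A(x-x_1)$ rather than $x-x_1$.

```python
import math

# Hypothetical premetric on R^2: d(x, y) = sqrt((x - y)^T A (x - y)), A symmetric positive definite.
A = [[2.0, 0.5], [0.5, 1.0]]

def matvec(M, r):
    return [M[0][0] * r[0] + M[0][1] * r[1], M[1][0] * r[0] + M[1][1] * r[1]]

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

def premetric(x, y):
    r = [x[0] - y[0], x[1] - y[1]]
    return math.sqrt(dot(r, matvec(A, r)))

def grad_premetric(x, y):
    # Euclidean gradient w.r.t. x: A (x - y) / d(x, y)
    r = [x[0] - y[0], x[1] - y[1]]
    d = premetric(x, y)
    Ar = matvec(A, r)
    return [Ar[0] / d, Ar[1] / d]

def kappa(t):
    return 1.0 - t

def dlog_kappa(t):
    return -1.0 / (1.0 - t)

def u_cond(t, x, x1):
    # equation 13: u_t(x|x1) = (d log kappa / dt) * d(x, x1) * grad d / ||grad d||^2
    g = grad_premetric(x, x1)
    s = dlog_kappa(t) * premetric(x, x1) / dot(g, g)
    return [s * g[0], s * g[1]]

def flow(x0, x1, t_end, n_steps=1000):
    # classical fixed-step RK4 for dx/dt = u_t(x | x1), starting at x0 at t = 0
    x, t, h = list(x0), 0.0, t_end / n_steps
    for _ in range(n_steps):
        k1 = u_cond(t, x, x1)
        k2 = u_cond(t + h / 2, [x[i] + h / 2 * k1[i] for i in range(2)], x1)
        k3 = u_cond(t + h / 2, [x[i] + h / 2 * k2[i] for i in range(2)], x1)
        k4 = u_cond(t + h, [x[i] + h * k3[i] for i in range(2)], x1)
        x = [x[i] + h / 6 * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i]) for i in range(2)]
        t += h
    return x

x0, x1 = [1.0, -0.5], [-0.3, 0.8]
results = {}
for t_end in (0.25, 0.5, 0.9):
    xt = flow(x0, x1, t_end)
    results[t_end] = (premetric(xt, x1), kappa(t_end) * premetric(x0, x1))
    print(t_end, *results[t_end])  # measured vs. predicted premetric to x1
```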

Conversely, consider $x_t=\psi_t(x|x_1)$ satisfying equation 12. Differentiating both sides of this equation with respect to $t$ and then using equation 12 again, we get

$$\left\langle\nabla\mathrm{d}(x_t,x_1),\dot{x}_t\right\rangle_g=\frac{d\kappa(t)}{dt}\,\mathrm{d}(x,x_1)=\frac{d\kappa(t)}{dt}\frac{1}{\kappa(t)}\,\mathrm{d}(x_t,x_1)=\frac{d\log\kappa(t)}{dt}\,\mathrm{d}(x_t,x_1).$$

If we let $u_t(x|x_1)$ denote the vector field defining the diffeomorphism $\psi_t(x|x_1)$ in the sense of equation 1, then the last equation takes the form

$$\left\langle\nabla\mathrm{d}(x_t,x_1),u_t(x_t|x_1)\right\rangle_g=\frac{d\log\kappa(t)}{dt}\,\mathrm{d}(x_t,x_1).\tag{23}$$

At every point $y=x_t$ with non-zero probability, equation 23 is a single linear equation in the unknown $u_t(y|x_1)\in T_y\mathcal{M}$, and hence an under-determined linear system; in our case this applies to every $y\in\mathcal{M}$, since we assume $p_t(y|x_1)>0$ for all $y\in\mathcal{M}$. As equation 22 shows, the $u_t(x|x_1)$ defined in equation 13 also satisfies this equation. Furthermore, since the $u_t(x|x_1)$ defined in equation 13 is parallel to $\nabla\mathrm{d}(x,x_1)$, it is the minimal-norm solution of the linear system in equation 23: any other solution differs from it by a vector that is $g$-orthogonal to $\nabla\mathrm{d}(x,x_1)$, and adding such a vector can only increase the norm. ∎
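The minimal-norm argument is ordinary linear algebra in a single tangent space. In the sketch below (our illustration, with arbitrary numbers) a fixed symmetric positive-definite Gram matrix `G` stands in for the metric $g$ at the point $y$ in some coordinate basis, the vector `a` stands in for $\nabla\mathrm{d}(y,x_1)$, and the scalar `b` for $\frac{d\log\kappa(t)}{dt}\,\mathrm{d}(y,x_1)$ (negative for a decreasing scheduler). The solution parallel to `a`, $u^\star=b\,a/\|a\|_g^2$, which has the form of equation 13, satisfies the constraint; adding any $g$-orthogonal component keeps the constraint but increases the norm.

```python
import math

G = [[2.0, 0.3, 0.0],   # hypothetical metric g at the point y (symmetric positive definite)
     [0.3, 1.0, 0.1],
     [0.0, 0.1, 1.5]]

def inner(a, b):
    # <a, b>_g = a^T G b
    return sum(a[i] * G[i][j] * b[j] for i in range(3) for j in range(3))

def norm(a):
    return math.sqrt(inner(a, a))

a = [1.0, 2.0, -1.0]   # stands in for grad d(y, x1)
b = -0.7               # stands in for (d log kappa / dt) * d(y, x1)

# solution parallel to a, as in equation 13
u_star = [b * ai / inner(a, a) for ai in a]

# a direction g-orthogonal to a: remove the a-component of an arbitrary vector z
z = [0.5, -1.0, 2.0]
w = [zi - inner(z, a) / inner(a, a) * ai for zi, ai in zip(z, a)]

checks = []
for s in (-2.0, -0.5, 0.0, 1.0, 3.0):
    u = [us + s * wi for us, wi in zip(u_star, w)]
    checks.append((inner(a, u), norm(u)))
    print(s, inner(a, u), norm(u))  # constraint value stays b; norm is smallest at s = 0
```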

Appendix C Proof of Proposition 3.2

Proposition 3.2.

Consider a complete, connected, smooth Riemannian manifold $(\mathcal{M},g)$ with geodesic distance $\mathrm{d}_g(x,y)$. If $\mathrm{d}(x,y)=\mathrm{d}_g(x,y)$, then $x_t=\psi_t(x_0|x_1)$, defined by the conditional vector field in equation 13 with the scheduler $\kappa(t)=1-t$, is a geodesic connecting $x_0$ to $x_1$.

Proof.

First, note that $\psi_0(x_0|x_1)=x_0$ by definition of the flow, and $\psi_1(x_0|x_1)=x_1$ by Theorem 3.1 (equation 12), since $\kappa(1)=0$ and $\mathrm{d}_g(x,x_1)=0$ only for $x=x_1$. Second, from Proposition 6 in McCann (2001) we have that

$$\nabla_x\frac{1}{2}\mathrm{d}_g(x,y)^2=-\log_x(y),$$

where $\log$ is the Riemannian logarithm map. By the chain rule,

$$\nabla_x\frac{1}{2}\mathrm{d}_g(x,y)^2=\mathrm{d}_g(x,y)\,\nabla_x\mathrm{d}_g(x,y).$$

Since the logarithm map satisfies $\left\|\log_x(y)\right\|_g=\mathrm{d}_g(x,y)$, taking norms in the two identities above gives $\mathrm{d}_g(x,y)\left\|\nabla_x\mathrm{d}_g(x,y)\right\|_g=\mathrm{d}_g(x,y)$, and hence, for $x\neq y$,

$$\left\|\nabla\mathrm{d}_g(x,y)\right\|_g=1.\tag{24}$$
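Both equation 24 and the identity $\mathrm{d}_g(x,y)\,\nabla_x\mathrm{d}_g(x,y)=-\log_x(y)$ can be checked numerically on the unit sphere $S^2\subset\mathbb{R}^3$ with the induced metric, where $\mathrm{d}_g(x,y)=\arccos\langle x,y\rangle$. The sketch below is an illustration under our own assumptions, not code from the paper: it extends $\mathrm{d}_g$ to $\mathbb{R}^3\setminus\{0\}$ by normalizing $x$, takes central finite differences, and projects onto $T_xS^2$ to obtain the Riemannian gradient. The helper names are hypothetical.

```python
import math

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def normalize(a):
    n = math.sqrt(dot(a, a))
    return [ai / n for ai in a]

def dist_ext(x, y):
    # geodesic distance on S^2, extended off the sphere by normalizing x
    c = max(-1.0, min(1.0, dot(normalize(x), y)))
    return math.acos(c)

def riemannian_grad(x, y, eps=1e-6):
    # central differences in R^3, then projection onto the tangent space T_x S^2
    g = []
    for i in range(3):
        xp = list(x)
        xp[i] += eps
        xm = list(x)
        xm[i] -= eps
        g.append((dist_ext(xp, y) - dist_ext(xm, y)) / (2 * eps))
    gx = dot(g, x)
    return [gi - gx * xi for gi, xi in zip(g, x)]

def sphere_log(x, y):
    # Riemannian log map on S^2 (x, y unit vectors, not equal or antipodal)
    c = max(-1.0, min(1.0, dot(x, y)))
    w = [yi - c * xi for xi, yi in zip(x, y)]
    nw = math.sqrt(dot(w, w))
    return [math.acos(c) * wi / nw for wi in w]

x = normalize([0.3, -0.8, 0.5])
y = normalize([0.9, 0.1, -0.4])
grad = riemannian_grad(x, y)
d = dist_ext(x, y)
log_xy = sphere_log(x, y)
print(math.sqrt(dot(grad, grad)))                    # equation 24: close to 1
print([d * gi + li for gi, li in zip(grad, log_xy)])  # d * grad + log_x(y): close to 0
```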

Now, computing the length of the curve $x_t$, we get

$$\begin{aligned}
\int_0^1\left\|\dot{x}_t\right\|_g dt &= \int_0^1\left\|u_t(x_t|x_1)\right\|_g dt\\
&= \int_0^1\left\|-\frac{\mathrm{d}_g(x_t,x_1)}{1-t}\,\frac{\nabla\mathrm{d}_g(x_t,x_1)}{\left\|\nabla\mathrm{d}_g(x_t,x_1)\right\|_g^2}\right\|_g dt\\
&= \mathrm{d}_g(x_0,x_1)\int_0^1\left\|\frac{\nabla\mathrm{d}_g(x_t,x_1)}{\left\|\nabla\mathrm{d}_g(x_t,x_1)\right\|_g^2}\right\|_g dt\\
&= \mathrm{d}_g(x_0,x_1)\int_0^1 dt\\
&= \mathrm{d}_g(x_0,x_1),
\end{aligned}$$

where in the second equality we used the definition of the conditional vector field (equation 13) with $\kappa(t)=1-t$, in the third equality we used Theorem 3.1 and equation 12, i.e., $\mathrm{d}_g(x_t,x_1)=(1-t)\,\mathrm{d}_g(x_0,x_1)$, and in the fourth equality we used equation 24. Since no curve from $x_0$ to $x_1$ is shorter than $\mathrm{d}_g(x_0,x_1)$, the curve $x_t$ is length-minimizing; moreover, the same computation shows that its speed $\left\|\dot{x}_t\right\|_g=\mathrm{d}_g(x_0,x_1)$ is constant, so $x_t$ is a (minimizing) geodesic. ∎
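Proposition 3.2 can be illustrated on the unit sphere $S^2$, where geodesics are available in closed form. The sketch below is ours and rests on stated assumptions: $\kappa(t)=1-t$, a fixed-step RK4 integrator in the ambient $\mathbb{R}^3$ with renormalization onto the sphere after each step, and hypothetical helper names. Using equation 24 and $\nabla\mathrm{d}_g(x,x_1)=-\log_x(x_1)/\mathrm{d}_g(x,x_1)$, the conditional vector field of equation 13 simplifies to $u_t(x|x_1)=\log_x(x_1)/(1-t)$. The simulated point is compared with the closed-form geodesic point $\exp_{x_1}\big((1-t)\log_{x_1}(x_0)\big)$, and the accumulated curve length with $t\,\mathrm{d}_g(x_0,x_1)$.

```python
import math

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

def normalize(a):
    n = norm(a)
    return [ai / n for ai in a]

def dist(x, y):
    # geodesic distance on the unit sphere
    return math.acos(max(-1.0, min(1.0, dot(x, y))))

def sphere_log(x, y):
    c = max(-1.0, min(1.0, dot(x, y)))
    w = [yi - c * xi for xi, yi in zip(x, y)]   # component of y orthogonal to x
    nw = norm(w)
    if nw < 1e-15:
        return [0.0] * len(x)
    return [math.acos(c) * wi / nw for wi in w]

def sphere_exp(x, v):
    nv = norm(v)
    if nv < 1e-15:
        return list(x)
    return [math.cos(nv) * xi + math.sin(nv) * vi / nv for xi, vi in zip(x, v)]

def u_cond(t, x, x1):
    # equation 13 with d = d_g and kappa(t) = 1 - t, simplified using equation 24
    x = normalize(x)
    return [li / (1.0 - t) for li in sphere_log(x, x1)]

def simulate(x0, x1, t_end, n_steps=2000):
    x, t, h, length = list(x0), 0.0, t_end / n_steps, 0.0
    for _ in range(n_steps):
        k1 = u_cond(t, x, x1)
        k2 = u_cond(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k1)], x1)
        k3 = u_cond(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k2)], x1)
        k4 = u_cond(t + h, [xi + h * ki for xi, ki in zip(x, k3)], x1)
        x_new = normalize([xi + h / 6 * (a + 2 * b + 2 * c + d)
                           for xi, a, b, c, d in zip(x, k1, k2, k3, k4)])
        length += dist(x, x_new)   # sum of short geodesic segments
        x, t = x_new, t + h
    return x, length

x0 = normalize([1.0, 0.2, -0.3])
x1 = normalize([-0.2, 0.9, 0.4])
d0 = dist(x0, x1)
errors = []
for t_end in (0.5, 0.9):
    xt, length = simulate(x0, x1, t_end)
    geodesic_pt = sphere_exp(x1, [(1.0 - t_end) * li for li in sphere_log(x1, x0)])
    errors.append((norm([a - b for a, b in zip(xt, geodesic_pt)]), abs(length - t_end * d0)))
    print(t_end, errors[-1])  # (position error, length error): both should be tiny
```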

Appendix D Algorithmic comparison to Riemannian diffusion models

Algorithm 2 Riemannian Diffusion Models
Require: base distribution $p(x_T)$, target $q(x_0)$
  Initialize parameters $\theta$ of $s_t$
  while not converged do
     sample time $t\sim\mathcal{U}(0,T)$
     sample training example $x_0\sim q(x_0)$
     
     
     
     % simulate Geometric Random Walk
     $x_t=\texttt{solve\_SDE}([0,t],x_0)$
     
     
     
     if denoising score matching then
        % approximate conditional score
        $\nabla\log p_t(x_t|x_0)\approx\begin{cases}\text{eig-expansion}\\ \text{Varadhan}\end{cases}$
        $\ell(\theta)=\left\|s_t(x_t;\theta)-\nabla\log p_t(x_t|x_0)\right\|_g^2$
     else if implicit score matching then
        % estimate Riemannian divergence
        sample $\varepsilon\sim\mathcal{N}(0,I)$
        $\text{div}_g s_t\approx\varepsilon^{\mathsf{T}}\frac{\partial s_t}{\partial x_t}\varepsilon+\tfrac{1}{2}s_t^{\mathsf{T}}\frac{\partial\log\det g(x_t)}{\partial x_t}$
        $\ell(\theta)=\tfrac{1}{2}\left\|s_t(x_t;\theta)\right\|_g^2+\text{div}_g s_t$
     end if
     
     θ=optimizer_step((θ))𝜃optimizer_step𝜃\theta=\texttt{optimizer\_step}(\ell(\theta))italic_θ = optimizer_step ( roman_ℓ ( italic_θ ) )
  end while
Algorithm 3 Riemannian Flow Matching
Input: base distribution $p(x_0)$, target $q(x_1)$
  Initialize parameters $\theta$ of $v_t$
  while not converged do
     sample time $t\sim\mathcal{U}(0,1)$
     sample training example $x_1\sim q(x_1)$
     sample noise $x_0\sim p(x_0)$
     
     if simple geometry then
        $x_t=\texttt{exp}_{x_0}(\kappa(t)\,\texttt{log}_{x_0}(x_1))$
     else if general geometry then
        $x_t={\color{red}\texttt{solve\_ODE}}([0,t],x_0,u_t(x|x_1))$
     end if
     
     % closed-form regression target $u_t(x_t|x_1)$
     $\ell(\theta)=\left\|v_t(x_t;\theta)-u_t(x_t|x_1)\right\|_g^2$
     
     $\theta=\texttt{optimizer\_step}(\ell(\theta))$
  end while
Figure 7: Algorithmic comparison between Riemannian Diffusion Models (De Bortoli et al., 2022; Huang et al., 2022) and our Riemannian Flow Matching. Note that time is reversed between these formulations. In red, we denote expensive computational aspects (sequential simulation during training), biased approximations (for the score function), and stochastic estimation (for the divergence) that may not scale well. Also note that the Geometric Random Walk does not converge to the stationary prior distribution unless simulated for an infinite amount of time; in practice, this requires tuning $T$ as a hyperparameter for each manifold. On simple manifolds, Riemannian Flow Matching bypasses all of these computational inconveniences and, in particular, is completely simulation-free.
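On the sphere, the simple-geometry branch of Algorithm 3 needs only the exponential and logarithm maps listed in Table 5. The following is a minimal NumPy sketch, not the authors' implementation: it assumes the schedule $\kappa(t)=t$ in Algorithm 3's notation and computes the regression target with the constant-speed geodesic identity $u_t(x_t|x_1)=\log_{x_t}(x_1)/(1-t)$; the helper names are hypothetical.

```python
import numpy as np

def sphere_exp(x, u):
    """Exponential map on the unit sphere: x cos|u| + (u/|u|) sin|u| (Table 5)."""
    norm_u = np.linalg.norm(u, axis=-1, keepdims=True)
    safe = np.where(norm_u > 1e-12, norm_u, 1.0)  # avoid 0/0 when u = 0
    return x * np.cos(norm_u) + (u / safe) * np.sin(norm_u)

def sphere_log(x, y):
    """Logarithm map on the unit sphere: arccos(<x,y>) P_x(y - x) / |P_x(y - x)| (Table 5)."""
    cos_d = np.clip(np.sum(x * y, axis=-1, keepdims=True), -1.0, 1.0)
    proj = y - cos_d * x  # equals P_x(y - x), the tangent component of y - x at x
    norm_p = np.linalg.norm(proj, axis=-1, keepdims=True)
    safe = np.where(norm_p > 1e-12, norm_p, 1.0)
    return np.arccos(cos_d) * proj / safe

def rfm_sphere_pair(x0, x1, t):
    """Simulation-free x_t and closed-form target u_t(x_t | x_1), assuming kappa(t) = t."""
    x_t = sphere_exp(x0, t * sphere_log(x0, x1))
    u_t = sphere_log(x_t, x1) / (1.0 - t)  # valid for t < 1
    return x_t, u_t

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 3))
x0 /= np.linalg.norm(x0, axis=-1, keepdims=True)
x1 = rng.standard_normal((4, 3))
x1 /= np.linalg.norm(x1, axis=-1, keepdims=True)
t = rng.uniform(0.0, 0.99, size=(4, 1))
x_t, u_t = rfm_sphere_pair(x0, x1, t)
```

No ODE solve or divergence estimate enters this step; the loss would then be $\lVert v_t(x_t;\theta)-u_t\rVert^2$ with the ambient inner product, since the sphere's metric in Table 5 is Euclidean.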

Appendix E Limitations

As can be seen from Figure 7, our method still requires simulation of $x_t$ on general manifolds. This sequential process can be time-consuming, and a more parallel or simulation-free approach to constructing $x_t$ would be preferable. Furthermore, the spectral distances require eigenfunction solvers, which may be computationally expensive on complex manifolds. Approximate methods such as neural eigenfunctions (Pfau et al., 2018; Deng et al., 2022) may offer a way around this. One major advantage of our premetric formulation is that these eigenfunctions need not be solved exactly in order to satisfy the relatively simple properties of our premetric.

Appendix F Additional figures

[Figure 8 panels: Volcano, Earthquake, Flood, Fire]
Figure 8: Data samples and log likelihood for the Volcano, Earthquake, Flood, and Fire datasets.
[Figure 9 panels: General, Glycine, Proline, Pre-Pro]
Figure 9: Data samples and Ramachandran plots depicting log likelihood for the protein datasets.

Appendix G Additional Discussion

G.1 On the use of approximate spectral distances as the premetric

Computation cost. The smallest $k$ eigenvalues and their eigenfunctions need only be computed once, as a pre-processing step. On manifolds represented as discrete triangular meshes, this step takes a matter of seconds. Afterwards, spectral distances can be computed very efficiently for all pairs of points. Note, however, that training with RCFM still requires simulating $x_t$; but because the conditional vector fields (equation 13) do not contain neural networks, these flows can be solved efficiently in practice. The resulting cost is similar to that of diffusion-based methods (Huang et al., 2022; De Bortoli et al., 2022), which require simulation even on simple manifolds.
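The simulation of $x_t$ involves only the premetric, not a network. The sketch below illustrates this in $\mathbb{R}^2$ with the Euclidean metric and a hypothetical anisotropic premetric $\textrm{d}(x,y)=\sqrt{(x-y)^{\mathsf{T}}A(x-y)}$. It assumes equation 13 has the form $u_t(x|x_1)=\frac{d\log\kappa(t)}{dt}\,\textrm{d}(x,x_1)\frac{\nabla\textrm{d}(x,x_1)}{\lVert\nabla\textrm{d}(x,x_1)\rVert_g^2}$ with $\kappa(t)=1-t$, and uses fixed-step RK4 as a stand-in for `solve_ODE`. Under that form, $\frac{d}{dt}\textrm{d}(x_t,x_1)=\frac{d\log\kappa}{dt}\,\textrm{d}(x_t,x_1)$, so the premetric should shrink as $\kappa(t)\,\textrm{d}(x_0,x_1)$ along a curved path.

```python
import numpy as np

A = np.array([[4.0, 1.0], [1.0, 0.5]])  # hypothetical SPD matrix defining the premetric

def premetric(x, y):
    diff = x - y
    return np.sqrt(diff @ A @ diff)

def premetric_grad(x, y):
    """Gradient of d(., y) at x under the Euclidean metric g = I."""
    return A @ (x - y) / premetric(x, y)

def conditional_vf(t, x, x1):
    """u_t(x | x1) = (d/dt log kappa) d grad d / |grad d|^2 with kappa(t) = 1 - t."""
    dlog_kappa = -1.0 / (1.0 - t)
    grad = premetric_grad(x, x1)
    return dlog_kappa * premetric(x, x1) * grad / (grad @ grad)

def solve_ode(x0, x1, t_end, n_steps=200):
    """Fixed-step RK4 from t = 0 to t_end (t_end < 1); no neural network is evaluated."""
    x, t = np.array(x0, dtype=float), 0.0
    h = t_end / n_steps
    for _ in range(n_steps):
        k1 = conditional_vf(t, x, x1)
        k2 = conditional_vf(t + h / 2, x + h / 2 * k1, x1)
        k3 = conditional_vf(t + h / 2, x + h / 2 * k2, x1)
        k4 = conditional_vf(t + h, x + h * k3, x1)
        x, t = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4), t + h
    return x

x0, x1 = np.array([1.0, -2.0]), np.array([-0.5, 1.5])
x_t = solve_ode(x0, x1, 0.7)
print(premetric(x_t, x1) / premetric(x0, x1))  # close to kappa(0.7) = 0.3
```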

Figure 10: A visualization of the approximated heat kernel in equation 25 on a 2D sphere for different $k$ values. Above each visualization we show the relative error of the conditional score function, $\mathbb{E}_{p_t(x_t|x_0)}\left[\left\|\nabla_{x_t}\log\tilde{p}_t(x_t|x_0)-\nabla_{x_t}\log p_t(x_t|x_0)\right\|/\left\|\nabla_{x_t}\log p_t(x_t|x_0)\right\|\right]$, where $x_0$ is located at the top of the sphere. Especially at small $t$ values (near the data distribution), there is a significant error in the score function approximation even when using hundreds of eigenfunctions.
[Figure 10 panels: rows $t=0.1$, $t=0.05$, $t=0.01$; columns $k=3$, $8$, $15$, $35$, $120$, $440$]
Figure 11: Spectral distances (biharmonic), i.e., $\textrm{d}(x,\cdot)$ where $x$ is at the top of the sphere, visualized using isolines. The properties of a premetric are easily satisfied with very few eigenfunctions, importantly the Positive property, $\textrm{d}(x,y)=0$ iff $x=y$, which allows us to concentrate probability perfectly at every point of the manifold, and the Non-degenerate property, $\nabla\textrm{d}(x,y)\neq 0$ iff $x\neq y$ (satisfied here everywhere except at the antipodal point), which allows the flow in equation 13 to be properly defined almost everywhere.

Sufficiency with finite $k$. One may wonder whether we pay any approximation cost when using a finite $k$; the answer is no. We only need as many eigenfunctions as it takes to distinguish (almost) every pair of points on the manifold. Put differently, we do not need to compute the spectral distances exactly, only accurately enough that the conditions of our premetric are satisfied. How large $k$ must be is only partially understood: for local neighborhoods, the number of required eigenfunctions equals the manifold dimension (though these are not necessarily the first ones), a property proven in (Jones et al., 2008, Theorem 2). Nevertheless, using spectral distances computed with the $k$ smallest eigenvalues is equivalent to computing Euclidean distances in a $k$-dimensional Euclidean embedding given by the same eigenfunctions; this embedding is known to preserve neighborhoods optimally (Belkin & Niyogi, 2003).
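A small illustration of this point, under assumptions of our own: we discretize the circle $S^1$ as an $n$-node cycle graph, use its combinatorial graph Laplacian as a stand-in for the Laplace-Beltrami operator, and take the biharmonic weighting $w(\lambda)=\lambda^{-2}$ in $\textrm{d}_w(x,y)^2=\sum_{i=1}^{k}w(\lambda_i)(\varphi_i(x)-\varphi_i(y))^2$ (our restatement of the spectral distance). The eigendecomposition is the one-time preprocessing step, and with only $k=2$ eigenfunctions every pair of distinct nodes is already separated. Function names are hypothetical.

```python
import numpy as np

def cycle_laplacian(n):
    """Combinatorial Laplacian L = D - W of an n-node cycle graph (a discretized circle)."""
    eye = np.eye(n)
    adjacency = np.roll(eye, 1, axis=1) + np.roll(eye, -1, axis=1)
    return 2.0 * eye - adjacency

def spectral_distances(evals, evecs, k, weight=lambda lam: lam ** -2.0):
    """All-pairs spectral distances from the k smallest non-zero eigenpairs.

    The default weight w(lambda) = lambda^-2 gives a biharmonic-style distance.
    """
    lam = evals[1:k + 1]                            # skip lambda_0 = 0 (constant eigenfunction)
    emb = evecs[:, 1:k + 1] * np.sqrt(weight(lam))  # weighted spectral embedding, shape (n, k)
    diff = emb[:, None, :] - emb[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))

n = 64
evals, evecs = np.linalg.eigh(cycle_laplacian(n))  # one-time preprocessing
D = spectral_distances(evals, evecs, k=2)
off_diag = D[~np.eye(n, dtype=bool)]
print(off_diag.min() > 0)                      # Positive: d(x, y) > 0 whenever x != y
print(D[0, 1] < D[0, n // 4] < D[0, n // 2])   # grows with arc length up to the antipode
```

On the circle the two lowest non-constant eigenvectors span $\{\cos\theta,\sin\theta\}$, so the $k=2$ embedding is itself a circle; larger $k$ would only refine the distance.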

For comparison, Riemannian score-based generative models (De Bortoli et al., 2022) suggest a heat kernel approximation

$$\tilde{p}_t(x_t|x_0)=\sum_{i=0}^{k}e^{-\lambda_i t}\varphi_i(x_0)\varphi_i(x_t), \tag{25}$$

resulting in the approximated conditional score function

$$\nabla_{x_t}\log\tilde{p}_t(x_t|x_0)=\nabla_{x_t}\log\sum_{i=0}^{k}e^{-\lambda_i t}\varphi_i(x_0)\varphi_i(x_t). \tag{26}$$

However, equation 25 is only exact as $k\rightarrow\infty$. In practice, this manifests as large errors even when using hundreds of eigenfunctions. In Figure 10, we visualize the heat kernel approximation in equation 25 as well as the relative error in the score function, taken in expectation w.r.t. $p_t(x|x_0)$; i.e., the relative error is weighted more heavily in regions where the conditional score function is evaluated during training. At small time values close to the data distribution, the conditional score function may not even have its first significant digit correct (in expectation).
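To see why truncating equation 25 hurts most at small $t$, consider a case where everything is analytic (our own illustration, not the setting of Figure 10): on the unit circle $S^1$ the Laplacian eigenfunctions are $1/\sqrt{2\pi}$, $\cos(j\theta)/\sqrt{\pi}$ and $\sin(j\theta)/\sqrt{\pi}$ with eigenvalues $j^2$, so keeping frequencies up to $m$ uses $2m+1$ eigenfunctions. The sketch estimates the same relative score error as in Figure 10 by Monte Carlo, using a large-$m$ expansion as the reference. Since each neglected term decays like $e^{-j^2 t}$, the number of eigenfunctions needed grows as $t$ shrinks.

```python
import numpy as np

def truncated_heat_kernel(theta, theta0, t, m):
    """Heat kernel on S^1 truncated to frequencies 1..m (2m + 1 eigenfunctions),
    together with its derivative in theta."""
    j = np.arange(1, m + 1)[:, None]      # frequencies, shape (m, 1)
    decay = np.exp(-(j ** 2) * t)         # e^{-lambda_j t} with lambda_j = j^2
    delta = theta[None, :] - theta0
    p = 1.0 / (2 * np.pi) + np.sum(decay * np.cos(j * delta), axis=0) / np.pi
    dp = -np.sum(decay * j * np.sin(j * delta), axis=0) / np.pi
    return p, dp

def relative_score_error(t, m, m_ref=200, n_samples=4000, seed=0):
    """Monte Carlo estimate of E_{p_t(x_t|x_0)} |s_m - s_ref| / |s_ref| with x_0 at theta = 0.

    m_ref is assumed large enough that neglected terms vanish in double precision
    for the t values used here.
    """
    rng = np.random.default_rng(seed)
    # The exact heat kernel on S^1 is a wrapped Gaussian with variance 2t, so sample from it.
    theta = np.sqrt(2.0 * t) * rng.standard_normal(n_samples)
    theta = np.mod(theta + np.pi, 2 * np.pi) - np.pi
    p_ref, dp_ref = truncated_heat_kernel(theta, 0.0, t, m_ref)
    p_m, dp_m = truncated_heat_kernel(theta, 0.0, t, m)
    score_ref, score_m = dp_ref / p_ref, dp_m / p_m  # d/dtheta log p
    return np.mean(np.abs(score_m - score_ref) / np.abs(score_ref))

for m in (1, 10, 100):
    print(f"t=0.01, m={m:3d}: relative score error {relative_score_error(0.01, m):.2e}")
```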

In contrast, Figure 11 visualizes the spectral distances using the biharmonic formulation, showing that we already obtain an extremely accurate premetric with very few eigenfunctions. In particular, the required properties of a premetric are satisfied even for just $k=3$, roughly corresponding to the manifold dimension. Higher values of $k$ simply refine the spectral distance but are not necessary.

G.2 Manifolds with Boundary

For general geometries, we also consider the case where $\mathcal{M}$ has a boundary, denoted $\partial\mathcal{M}$. In this case, we need to add another condition to our premetric to make sure $u_t(x|x_1)$ does not flow particles outside the manifold. Let $n(x)\in T_x\mathcal{M}$ denote the interior-pointing normal direction at a boundary point $x\in\partial\mathcal{M}$. We add the following condition to our premetric:

  4. Boundary: $\left\langle\nabla\textrm{d}(x,y),n(x)\right\rangle_g\leq 0$ for all $y\in\mathcal{M}$, $x\in\partial\mathcal{M}$.

If the premetric satisfies this condition, then the conditional VF in equation 13 satisfies

$$\left\langle u_t(x|x_1),n(x)\right\rangle_g\geq 0,$$

implying that the conditional vector field does not point outwards on the boundary of the manifold.
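The implication follows in one line. Since equation 13 is not restated in this section, we assume it has the form $u_t(x|x_1)=\frac{d\log\kappa(t)}{dt}\,\textrm{d}(x,x_1)\frac{\nabla\textrm{d}(x,x_1)}{\lVert\nabla\textrm{d}(x,x_1)\rVert_g^2}$ with a decreasing scheduler $\kappa$, so that $\frac{d\log\kappa(t)}{dt}\leq 0$. Then, wherever $\nabla\textrm{d}(x,x_1)\neq 0$ and $x\in\partial\mathcal{M}$,

$$\left\langle u_t(x|x_1),n(x)\right\rangle_g=\underbrace{\frac{d\log\kappa(t)}{dt}}_{\leq 0}\;\underbrace{\frac{\textrm{d}(x,x_1)}{\lVert\nabla\textrm{d}(x,x_1)\rVert_g^2}}_{\geq 0}\;\underbrace{\left\langle\nabla\textrm{d}(x,x_1),n(x)\right\rangle_g}_{\leq 0\text{ by the Boundary condition}}\;\geq\;0.$$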

Spectral distances at boundary points. When $\mathcal{M}$ has a boundary, we want the spectral distances in equation 17 to satisfy the boundary condition. To ensure this, we can simply solve for the eigenfunctions $\varphi_i$ using natural (Neumann) boundary conditions, i.e., their normal derivatives vanish at boundary points: $\left\langle\nabla_g\varphi_i(x),n(x)\right\rangle_g=0$ for all $x\in\partial\mathcal{M}$. This property implies that $\left\langle\nabla_x\textrm{d}_w(x,y)^2,n(x)\right\rangle_g=0$, satisfying the boundary condition of the premetric.
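To spell out the last step, assume the spectral distance of equation 17 has the form $\textrm{d}_w(x,y)^2=\sum_i w(\lambda_i)\big(\varphi_i(x)-\varphi_i(y)\big)^2$ (our restatement). For $x\in\partial\mathcal{M}$, the chain rule and the Neumann condition give

$$\left\langle\nabla_x\textrm{d}_w(x,y)^2,n(x)\right\rangle_g=2\sum_i w(\lambda_i)\big(\varphi_i(x)-\varphi_i(y)\big)\left\langle\nabla_g\varphi_i(x),n(x)\right\rangle_g=0,$$

and wherever $\textrm{d}_w(x,y)>0$ we have $\nabla_x\textrm{d}_w=\nabla_x\textrm{d}_w^2/(2\,\textrm{d}_w)$, so $\left\langle\nabla_x\textrm{d}_w(x,y),n(x)\right\rangle_g=0$, which satisfies the Boundary condition with equality.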

Appendix H Experiment Details

Training setup. All experiments are run on a single NVIDIA V100 GPU with 32GB of memory. We kept our training setup as close as possible to that of prior works (Mathieu & Nickel, 2020; De Bortoli et al., 2022; Huang et al., 2022); however, as their exact data splits were not available, we used our own random splits. Following their procedure, we split the data into 80% train, 10% validation, and 10% test. We used seed values 0 to 4 for our five runs. We used the validation set for early stopping based on the validation NLL, and computed the test NLL only for the checkpoint that achieved the best validation NLL. We used a standard multilayer perceptron (MLP) to parameterize vector fields, with time concatenated as an input to the network. We generally used 512 hidden units and tuned the number of layers for each type of experiment, ranging from 6 to 12 layers. We used the Swish activation function (Ramachandran et al., 2017) with a learnable parameter. For all experiments, we used Adam with a learning rate of 1e-4 and an exponential moving average of the weights (Polyak & Juditsky, 1992) with a decay of 0.999.
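The architecture and weight averaging described above can be sketched in NumPy as follows. This is a hypothetical forward pass, not the released code: widths and depth are placeholders, time is concatenated to the input, Swish is taken as $x\cdot\mathrm{sigmoid}(\beta x)$ with a learnable $\beta$ per layer (the trainable-slope form from Ramachandran et al.), and the EMA update uses the stated decay of 0.999.

```python
import numpy as np

def swish(x, beta):
    """Swish with a learnable slope: x * sigmoid(beta * x)."""
    return x / (1.0 + np.exp(-beta * x))

def init_mlp(rng, d_in, d_out, hidden=512, n_layers=6):
    """Hypothetical initializer: n_layers linear layers, each with its own beta."""
    sizes = [d_in + 1] + [hidden] * (n_layers - 1) + [d_out]  # +1 input for time t
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        w = rng.standard_normal((fan_in, fan_out)) / np.sqrt(fan_in)
        # beta is stored for every layer for uniform EMA handling; unused in the output layer
        params.append({"w": w, "b": np.zeros(fan_out), "beta": np.ones(1)})
    return params

def mlp_forward(params, t, x):
    """v_theta(t, x): concatenate time to the input, apply Swish between layers."""
    h = np.concatenate([x, np.broadcast_to(t, (x.shape[0], 1))], axis=-1)
    for layer in params[:-1]:
        h = swish(h @ layer["w"] + layer["b"], layer["beta"])
    return h @ params[-1]["w"] + params[-1]["b"]

def ema_update(ema_params, params, decay=0.999):
    """Exponential moving average of the weights (Polyak averaging)."""
    return [{k: decay * e[k] + (1.0 - decay) * p[k] for k in p}
            for e, p in zip(ema_params, params)]

rng = np.random.default_rng(0)
params = init_mlp(rng, d_in=3, d_out=3, hidden=64, n_layers=4)
ema = [{k: v.copy() for k, v in layer.items()} for layer in params]
x = rng.standard_normal((8, 3))
out = mlp_forward(params, 0.5, x)  # shape (8, 3)
ema = ema_update(ema, params)      # called once per optimizer step in training
```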

High-dimensional tori. We use the same setup as De Bortoli et al. (2022). The data distribution is a wrapped Gaussian on the high-dimensional torus with a uniformly sampled mean and a scale of 0.2. We use an MLP with 3 hidden layers of size 512 to parameterize $v_t$ and train for 50000 iterations with a batch size of 512. We then report the log-likelihood per dimension (in bits) on 20000 newly sampled data points.
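A sketch of this synthetic target under our reading of the setup: a wrapped Gaussian on $[0,2\pi)^N$ with a uniformly sampled mean and scale 0.2 in every coordinate (we assume an isotropic, coordinate-wise independent Gaussian). The log-density sums each coordinate over a finite number of periods, a truncation we choose, and converting nats per dimension to bits divides by $\log 2$. Names are hypothetical.

```python
import numpy as np

TWO_PI = 2.0 * np.pi

def sample_wrapped_gaussian(rng, mean, scale, n):
    """Draw n points from a coordinate-wise wrapped Gaussian on [0, 2*pi)^N."""
    x = mean + scale * rng.standard_normal((n, mean.shape[0]))
    return np.mod(x, TWO_PI)

def wrapped_gaussian_logpdf(x, mean, scale, n_wraps=3):
    """Log-density w.r.t. the flat volume on [0, 2*pi)^N, summing over 2*n_wraps + 1 periods."""
    shifts = TWO_PI * np.arange(-n_wraps, n_wraps + 1)      # (K,)
    z = (x[..., None] - mean[..., None] + shifts) / scale   # (n, N, K)
    log_terms = -0.5 * z ** 2 - np.log(scale * np.sqrt(TWO_PI))
    log_1d = np.logaddexp.reduce(log_terms, axis=-1)         # per-coordinate log-density
    return log_1d.sum(axis=-1)

def bits_per_dim(log_p, dim):
    """Convert an average log-likelihood in nats to bits per dimension."""
    return np.mean(log_p) / (dim * np.log(2.0))

rng = np.random.default_rng(0)
dim = 10
mean = rng.uniform(0.0, TWO_PI, size=dim)
x = sample_wrapped_gaussian(rng, mean, 0.2, n=20000)
print(bits_per_dim(wrapped_gaussian_logpdf(x, mean, 0.2), dim))
```

For scale 0.2 the wrapping is negligible, so the printed value should sit near the negative differential entropy of a 1D Gaussian with standard deviation 0.2, about 0.27 bits per dimension.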

Triangular meshes. We use the exact open-source mesh for Spot the Cow. For the Stanford Bunny, we downsample the mesh to 5000 triangles and work with this downsampled mesh. To construct target distributions, we compute the eigenfunctions on a 3-times upsampled version of the mesh, threshold the $k$-th eigenfunction at zero, and then normalize to obtain the target distribution. This target distribution is uniform on each upsampled triangle and further weighted by the area of each triangle. On the actual mesh we work with, this creates complex non-uniform distributions on each triangle. We also normalize the mesh so that points always lie in the range (-1, 1).
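One way to realize such a target, under assumptions we make explicit: each triangle $f$ receives the weight $\max(\varphi_k(f),0)\cdot\mathrm{area}(f)$, where $\varphi_k(f)$ is the $k$-th eigenfunction evaluated on that triangle (for example at its centroid; our choice), and points are then drawn uniformly inside the selected triangle with the standard square-root barycentric trick. This is one possible reading of "threshold at zero"; the helper names and the weighting rule are hypothetical.

```python
import numpy as np

def triangle_areas(vertices, faces):
    """Area of each triangle of a 3D mesh."""
    a, b, c = (vertices[faces[:, i]] for i in range(3))
    return 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=-1)

def target_face_probs(phi_k, vertices, faces):
    """Per-face probabilities: thresholded eigenfunction times area (hypothetical rule)."""
    weights = np.maximum(phi_k, 0.0) * triangle_areas(vertices, faces)
    return weights / weights.sum()

def sample_on_mesh(rng, probs, vertices, faces, n):
    """Pick faces by probability, then sample uniformly inside each chosen triangle."""
    idx = rng.choice(len(faces), size=n, p=probs)
    a, b, c = (vertices[faces[idx, i]] for i in range(3))
    r1 = np.sqrt(rng.uniform(size=(n, 1)))
    r2 = rng.uniform(size=(n, 1))
    return (1.0 - r1) * a + r1 * (1.0 - r2) * b + r1 * r2 * c

# Toy mesh: a unit square split into two triangles, with made-up per-face eigenfunction values.
vertices = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2], [0, 2, 3]])
phi_k = np.array([0.8, -0.3])  # the negative value is thresholded away
probs = target_face_probs(phi_k, vertices, faces)
samples = sample_on_mesh(np.random.default_rng(0), probs, vertices, faces, n=1000)
```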

Maze manifolds. We represent each maze manifold using triangular meshes in 2D. Each cell is represented using a mesh of 8x8 squares, with each square represented as two triangles. If two neighboring cells are connected, we join them using either 2x8 or 8x2 squares; if two neighboring cells are not connected (i.e., there is a wall), we simply do not connect their meshes, resulting in boundaries on the manifold. This produces manifolds represented by triangular meshes in which all triangles are the same size. We randomly create maze structures using a breadth-first-search algorithm, represent them using meshes, and then normalize the mesh so that points lie in the range (0, 1).
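The maze topology can be generated, for example, as a random spanning tree of the cell grid found by a breadth-first search that visits neighbors in a random order. This sketch produces only the set of open passages between cells (the mesh construction is omitted); the function name and the exact randomization are our assumptions, not the authors' generator.

```python
import random
from collections import deque

def random_bfs_maze(n_rows, n_cols, seed=0):
    """Return the open passages of a random spanning-tree maze.

    Cells are (row, col) pairs and each passage is a frozenset of two adjacent cells.
    A passage means the two cell meshes are joined (a 2x8 or 8x2 strip in the text
    above); a missing passage between neighbors is a wall, i.e., a manifold boundary.
    """
    rng = random.Random(seed)
    start = (0, 0)
    visited = {start}
    queue = deque([start])
    passages = set()
    while queue:
        r, c = queue.popleft()
        neighbors = [(r + dr, c + dc) for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1))]
        rng.shuffle(neighbors)  # randomize the BFS expansion order
        for nr, nc in neighbors:
            if 0 <= nr < n_rows and 0 <= nc < n_cols and (nr, nc) not in visited:
                visited.add((nr, nc))
                passages.add(frozenset({(r, c), (nr, nc)}))
                queue.append((nr, nc))
    return passages

maze = random_bfs_maze(4, 5)
print(len(maze))  # a spanning tree over 20 cells has 19 passages
```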

Table 5: Riemannian manifolds with known geodesics that we use in our experiments. The operator $\oplus$ denotes Möbius addition. Exponential maps, logarithm maps, and inner products are used during training with Riemannian Conditional Flow Matching. Note $x,y\in\mathcal{M}$, $u,v\in T_x\mathcal{M}$, and $\lVert\cdot\rVert^2_g=\langle\cdot,\cdot\rangle_g$. The last column $\log\lvert g(x)\rvert$ denotes the logarithm of the absolute value of the determinant of the metric tensor at $x\in\mathcal{M}$; this is used during log-likelihood computation (see equations 38 and 34).

| Manifold $\mathcal{M}$ | $\exp_x(u)$ | $\log_x(y)$ | $\langle u,v\rangle_g$ | $\log\lvert g(x)\rvert$ |
|---|---|---|---|---|
| $N$-D sphere $\{x\in\mathbb{R}^{N+1}:\lVert x\rVert_2=1\}$ | $x\cos(\lVert u\rVert_2)+\frac{u}{\lVert u\rVert_2}\sin(\lVert u\rVert_2)$ | $\arccos(\langle x,y\rangle)\frac{P_x(y-x)}{\lVert P_x(y-x)\rVert_2}$ | $\langle u,v\rangle$ | $0$ |
| $N$-D flat tori $[0,2\pi]^N$ | $(x+u)\;\%\;(2\pi)$ | $\operatorname{arctan2}(\sin(y-x),\cos(y-x))$ | $\langle u,v\rangle$ | $0$ |
| $N$-D hyperbolic $\{x\in\mathbb{R}^N:\lVert x\rVert_2<1\}$ | $x\oplus\left(\tanh\left(\frac{\lVert u\rVert_2}{1-\lVert x\rVert_2^2}\right)\frac{u}{\lVert u\rVert_2}\right)$ | $(1-\lVert x\rVert_2^2)\tanh^{-1}(\lVert -x\oplus y\rVert_2)\frac{-x\oplus y}{\lVert -x\oplus y\rVert_2}$ | $\frac{4}{(1-\lVert x\rVert_2^2)^2}\langle u,v\rangle$ | $N\log\frac{2}{1-\lVert x\rVert_2^2}$ |
| $N\times N$ SPD matrices | $X^{\frac{1}{2}}\exp\{X^{-\frac{1}{2}}UX^{-\frac{1}{2}}\}X^{\frac{1}{2}}$ | $X^{\frac{1}{2}}\log\{X^{-\frac{1}{2}}YX^{-\frac{1}{2}}\}X^{\frac{1}{2}}$ | $\operatorname{tr}(X^{-1}UX^{-1}V)$ | $\frac{N(N-1)}{2}\log(2)+(N+1)\log\det X$ |
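As a sanity check of the Table 5 formulas, the sketch below implements the Poincaré-ball and flat-torus maps in NumPy and verifies that $\exp_x(\log_x(y))$ recovers $y$. The Möbius addition is the standard curvature $-1$ formula $x\oplus y=\frac{(1+2\langle x,y\rangle+\lVert y\rVert^2)x+(1-\lVert x\rVert^2)y}{1+2\langle x,y\rangle+\lVert x\rVert^2\lVert y\rVert^2}$, which we supply as background rather than take from the table; the maps assume $u\neq 0$ and $x\neq y$.

```python
import numpy as np

def mobius_add(x, y):
    """Mobius addition on the Poincare ball (curvature -1)."""
    xy = np.sum(x * y, axis=-1, keepdims=True)
    xx = np.sum(x * x, axis=-1, keepdims=True)
    yy = np.sum(y * y, axis=-1, keepdims=True)
    num = (1 + 2 * xy + yy) * x + (1 - xx) * y
    return num / (1 + 2 * xy + xx * yy)

def poincare_exp(x, u):
    """exp_x(u) = x (+) tanh(|u| / (1 - |x|^2)) u / |u|  (Table 5; assumes u != 0)."""
    norm_u = np.linalg.norm(u, axis=-1, keepdims=True)
    xx = np.sum(x * x, axis=-1, keepdims=True)
    return mobius_add(x, np.tanh(norm_u / (1 - xx)) * u / norm_u)

def poincare_log(x, y):
    """log_x(y) = (1 - |x|^2) artanh(|-x (+) y|) (-x (+) y) / |-x (+) y|  (Table 5; x != y)."""
    w = mobius_add(-x, y)
    norm_w = np.linalg.norm(w, axis=-1, keepdims=True)
    xx = np.sum(x * x, axis=-1, keepdims=True)
    return (1 - xx) * np.arctanh(norm_w) * w / norm_w

def torus_exp(x, u):
    return np.mod(x + u, 2 * np.pi)

def torus_log(x, y):
    return np.arctan2(np.sin(y - x), np.cos(y - x))

rng = np.random.default_rng(0)
x = 0.6 * rng.uniform(-1, 1, size=(5, 2)) / np.sqrt(2)  # points strictly inside the unit ball
y = 0.6 * rng.uniform(-1, 1, size=(5, 2)) / np.sqrt(2)
print(np.max(np.abs(poincare_exp(x, poincare_log(x, y)) - y)))  # round-trip error
```

The round trip works because $\exp_x(\log_x(y))=x\oplus(-x\oplus y)$, and Möbius addition satisfies this left-cancellation law.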

Vector field parameterization. We parameterize vector fields as neural networks in the ambient space and project onto the tangent space at every $x$. That is, similarly to (Rozen et al., 2021) we model

$$v_t(x)=g(x)^{-\frac{1}{2}}P_{\pi(x)}v_\theta(t,\pi(x)) \tag{27}$$

where $\pi$ is the projection operator onto the manifold, i.e.,

$$\pi(x)=\arg\min_{y\in\mathcal{M}}\left\|x-y\right\|_g, \tag{28}$$

and $P_y$ is the orthogonal projection onto the tangent space at $y$.

We also normalize the vector field using $g(x)^{-\frac{1}{2}}$, which cancels out the effect of $g$ on the Riemannian norm and makes standard neural network parameterization more robust to changes in the metric, i.e.,

$$\left\|g^{-\frac{1}{2}}v\right\|_g^2=(g^{-\frac{1}{2}}v)^{\mathsf{T}}g\,(g^{-\frac{1}{2}}v)=v^{\mathsf{T}}v=\left\|v\right\|_2^2. \tag{29}$$

We found this bypasses the need to construct manifold-specific initialization schemes for our neural networks and leads to more stable training.
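For the unit sphere with its induced metric ($g=I$ in ambient coordinates), equation 27 reduces to projecting an unconstrained ambient output onto the tangent space. The sketch below uses a hypothetical stand-in for $v_\theta$, with $\pi(x)=x/\lVert x\rVert_2$ as the closest-point projection and $P_y=I-yy^{\mathsf{T}}$. It also checks equation 29 numerically for the Poincaré-ball metric $g(x)=\lambda(x)^2 I$ with $\lambda(x)=2/(1-\lVert x\rVert_2^2)$ from Table 5, for which $g^{-\frac{1}{2}}=I/\lambda(x)$.

```python
import numpy as np

def v_theta(t, y):
    """Hypothetical stand-in for the ambient network output v_theta(t, y) in R^3."""
    return np.stack([np.sin(t + y[..., 1]), y[..., 0] * y[..., 2], np.cos(3 * y[..., 0])], axis=-1)

def sphere_vector_field(t, x):
    """Equation 27 on the unit sphere: g = I, pi(x) = x / |x|, P_y = I - y y^T."""
    y = x / np.linalg.norm(x, axis=-1, keepdims=True)
    v = v_theta(t, y)
    return v - np.sum(v * y, axis=-1, keepdims=True) * y

def poincare_norm_sq(x, u):
    """Squared Riemannian norm on the Poincare ball: lambda(x)^2 <u, u>."""
    lam = 2.0 / (1.0 - np.sum(x * x, axis=-1, keepdims=True))
    return (lam ** 2 * np.sum(u * u, axis=-1, keepdims=True))[..., 0]

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 3))  # ambient points, not necessarily on the sphere
u = sphere_vector_field(0.3, x)
y = x / np.linalg.norm(x, axis=-1, keepdims=True)
print(np.max(np.abs(np.sum(u * y, axis=-1))))  # near 0: u is tangent at pi(x)

# Equation 29 on the Poincare ball: scaling by g^{-1/2} = I / lambda(x) yields |v|_2^2.
xb = 0.5 * rng.uniform(-1, 1, size=(6, 3)) / np.sqrt(3)
v = rng.standard_normal((6, 3))
lam = 2.0 / (1.0 - np.sum(xb * xb, axis=-1, keepdims=True))
print(np.allclose(poincare_norm_sq(xb, v / lam), np.sum(v * v, axis=-1)))
```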

Log-likelihood computation. We solve for the log-density logp1(x)subscript𝑝1𝑥\log p_{1}(x)roman_log italic_p start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( italic_x ), for an arbitrary test sample x𝑥x\in{\mathcal{M}}italic_x ∈ caligraphic_M, by using the instantaneous change of variables (Chen et al., 2018), namely solve the ODE

ddt(xtft(x))=(vt(xt)divg(vt)(xt)),𝑑𝑑𝑡matrixsubscript𝑥𝑡subscript𝑓𝑡𝑥matrixsubscript𝑣𝑡subscript𝑥𝑡subscriptdiv𝑔subscript𝑣𝑡subscript𝑥𝑡\frac{d}{dt}\begin{pmatrix}x_{t}\\ f_{t}(x)\end{pmatrix}=\begin{pmatrix}v_{t}(x_{t})\\ -\text{div}_{g}\left(v_{t}\right)(x_{t})\end{pmatrix},divide start_ARG italic_d end_ARG start_ARG italic_d italic_t end_ARG ( start_ARG start_ROW start_CELL italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_CELL end_ROW start_ROW start_CELL italic_f start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ( italic_x ) end_CELL end_ROW end_ARG ) = ( start_ARG start_ROW start_CELL italic_v start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) end_CELL end_ROW start_ROW start_CELL - div start_POSTSUBSCRIPT italic_g end_POSTSUBSCRIPT ( italic_v start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) end_CELL end_ROW end_ARG ) , (30)

where $\mathrm{div}_g$ is the Riemannian divergence. We solve backwards in time from $t=1$ to $t=0$ with the initial conditions

$$\begin{pmatrix}x_1\\ f_1(x)\end{pmatrix}=\begin{pmatrix}x\\ 0\end{pmatrix} \tag{31}$$

and compute the desired log-density at $x$ via

$$\log p_1(x)=\log p_0(x_0)-f_0(x). \tag{32}$$
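The procedure in equations 30 to 32 can be sketched as follows. This is a minimal illustration under simplifying assumptions, not the authors' implementation: the state is a scalar in 1-D Euclidean space (so $\mathrm{div}_g$ is an ordinary derivative and no projection is needed), and a fixed-step RK4 integrator stands in for an adaptive solver. The toy field $v_t(x)=-x$ maps $\mathcal{N}(0,1)$ at $t=0$ to $\mathcal{N}(0,e^{-2})$ at $t=1$, which gives a closed-form density to compare against.

```python
import math


def log_likelihood(x, v, div_v, log_p0, n_steps=1000):
    """Integrate d/dt (x_t, f_t) = (v_t(x_t), -div v_t(x_t)) backwards from t=1 to t=0
    with fixed-step RK4, starting from (x_1, f_1) = (x, 0), and return log p0(x_0) - f_0.
    Scalar (1-D Euclidean) state only, for illustration."""

    def rhs(t, state):
        xt, _ = state
        return (v(t, xt), -div_v(t, xt))

    def shifted(state, k, scale):
        return tuple(s + scale * ki for s, ki in zip(state, k))

    h = -1.0 / n_steps  # negative step size: integrate backwards in time
    t, state = 1.0, (x, 0.0)
    for _ in range(n_steps):
        k1 = rhs(t, state)
        k2 = rhs(t + h / 2, shifted(state, k1, h / 2))
        k3 = rhs(t + h / 2, shifted(state, k2, h / 2))
        k4 = rhs(t + h, shifted(state, k3, h))
        state = tuple(
            s + h / 6 * (a + 2 * b + 2 * c + d)
            for s, a, b, c, d in zip(state, k1, k2, k3, k4)
        )
        t += h
    x0, f0 = state
    return log_p0(x0) - f0


# Toy check: v_t(x) = -x pushes N(0, 1) at t=0 to N(0, e^{-2}) at t=1.
def v(t, x):
    return -x


def div_v(t, x):
    return -1.0  # d/dx of -x


def log_p0(x):
    return -0.5 * (math.log(2 * math.pi) + x * x)  # standard normal


x = 0.3
approx = log_likelihood(x, v, div_v, log_p0)
exact = -0.5 * math.log(2 * math.pi) + 1.0 - 0.5 * (x * math.e) ** 2
print(approx, exact)
```

On a manifold, the same loop would additionally use the Riemannian divergence and keep $x_t$ on $\mathcal{M}$.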

For manifolds that are embedded in an ambient Euclidean space (e.g., hypersphere, flat tori, triangular meshes), the parameterization in equation 27 allows us to compute the Riemannian divergence directly in the ambient space (Rozen et al., 2021, Lemma 2). That is,

$$\mathrm{div}_g\left(v_t\right)=\mathrm{div}_E\left(v_t\right)=\sum_i\frac{\partial v_t(x)_i}{\partial x_i}. \tag{33}$$
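Equation 33 is simply the trace of the ambient Jacobian of $v_t$. A framework would typically compute it with automatic differentiation (for example, `jnp.trace(jax.jacfwd(v)(x))` in JAX) or with a stochastic trace estimator. The dependency-free sketch below instead uses central differences and checks the result on a linear field $v(x)=Ax$, whose divergence is $\operatorname{tr}(A)$. The function names are illustrative.

```python
def euclidean_divergence(v, x, eps=1e-5):
    """Central-difference estimate of div_E v(x) = sum_i dv_i/dx_i for v: R^n -> R^n."""
    total = 0.0
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        total += (v(x_plus)[i] - v(x_minus)[i]) / (2 * eps)
    return total


# Sanity check: for a linear field v(x) = A x, the divergence is trace(A).
A = [[1.0, 2.0, 0.0], [0.0, 3.0, 1.0], [4.0, 0.0, -2.0]]


def linear_field(x):
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]


print(euclidean_divergence(linear_field, [0.2, -0.5, 0.7]))  # approximately trace(A) = 2.0
```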

For general manifolds with metric tensor $g$ (e.g., the Poincaré ball model of hyperbolic space, or the manifold of symmetric positive definite matrices), we compute the Riemannian divergence as

$$\mathrm{div}_g\left(v_t\right)=\mathrm{div}_E\left(v_t\right)+\frac{1}{2}v_t^{\mathsf{T}}\nabla_E\log\det g, \tag{34}$$

where $\mathrm{div}_E$ is the standard Euclidean divergence and $\nabla_E=\left(\frac{\partial}{\partial x_1},\ldots,\frac{\partial}{\partial x_d}\right)^{\mathsf{T}}$ is the Euclidean gradient.
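Equation 34 can be checked numerically against the coordinate definition of the Riemannian divergence, $\mathrm{div}_g(v)=\frac{1}{\sqrt{\det g}}\sum_i\partial_i\big(\sqrt{\det g}\,v_i\big)$ (a standard identity, not stated above). The sketch below does this on the 2-D Poincaré disk, assuming the conformal metric $g(x)=\lambda(x)^2 I$ with $\lambda(x)=2/(1-\|x\|^2)$, so that $\log\det g=2d\log\lambda$. Derivatives are taken by central differences purely for illustration, and the vector field `field` is arbitrary and hypothetical.

```python
import math


def euclidean_divergence(v, x, eps=1e-5):
    """Central-difference estimate of sum_i dv_i/dx_i."""
    total = 0.0
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        total += (v(x_plus)[i] - v(x_minus)[i]) / (2 * eps)
    return total


def conformal_factor(x):
    """Poincaré-ball factor lambda(x) = 2 / (1 - ||x||^2), so g(x) = lambda(x)^2 I."""
    return 2.0 / (1.0 - sum(xi * xi for xi in x))


def div_g_eq34(v, x):
    """Equation (34): div_E v + 0.5 * v^T grad_E log det g.
    Here log det g = 2 d log lambda, whose gradient is 2 d * 2 x / (1 - ||x||^2)."""
    d = len(x)
    sq = sum(xi * xi for xi in x)
    grad_logdet = [2 * d * 2 * xi / (1.0 - sq) for xi in x]
    correction = 0.5 * sum(vi * gi for vi, gi in zip(v(x), grad_logdet))
    return euclidean_divergence(v, x) + correction


def div_g_coordinates(v, x):
    """Coordinate definition: (1 / sqrt(det g)) * sum_i d/dx_i (sqrt(det g) v_i),
    with sqrt(det g) = lambda^d."""
    d = len(x)

    def weighted(y):
        w = conformal_factor(y) ** d
        return [w * vi for vi in v(y)]

    return euclidean_divergence(weighted, x) / conformal_factor(x) ** d


def field(x):
    # An arbitrary smooth vector field in ambient coordinates (illustrative).
    return [x[0] * x[1] + 0.3, math.sin(x[0]) - x[1]]


point = [0.2, -0.4]
print(div_g_eq34(field, point), div_g_coordinates(field, point))  # should agree closely
```

The correction term grows like $1/(1-\|x\|^2)$, so it dominates near the boundary of the disk.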

Wrapped distributions. An effective way to define a simple distribution $p\in\mathcal{P}$ over a manifold $\mathcal{M}$ of dimension $d$ is to push forward a simple prior $\tilde{p}$ defined on the Euclidean space $\mathbb{R}^d$ via a chart $\phi:\mathbb{R}^d\rightarrow\mathcal{M}$; for example, $\phi=\exp_x:T_x\mathcal{M}\rightarrow\mathcal{M}$, the Riemannian exponential map. A sample $x\sim p(x)$ is generated by drawing $z\sim\tilde{p}(z)$, $z\in\mathbb{R}^d$, and computing $x=\phi(z)$. To compute the probability density $p(x)$ at some point $x\in\mathcal{M}$, we integrate over an arbitrary domain $\Omega\subset\mathcal{M}$:

$$\int_{\Omega}p(x)\,d\mathrm{vol}_x=\int_{\phi^{-1}(\Omega)}p(\phi(z))\sqrt{\det g(z)}\,dz, \tag{35}$$

where $g_{ij}(z)=\left\langle\partial_i\phi(z),\partial_j\phi(z)\right\rangle_g$, $i,j\in[d]$, is the Riemannian metric tensor in local coordinates, and $\partial_i\phi(z)=\frac{\partial\phi(z)}{\partial z_i}$. Since $x=\phi(z)$ with $z\sim\tilde{p}$, the left-hand side also equals $\int_{\phi^{-1}(\Omega)}\tilde{p}(z)\,dz$; as this holds for every $\Omega$, the integrands must agree, so that

$$\tilde{p}(z)=p(\phi(z))\sqrt{\det g(z)} \tag{36}$$

and therefore

$$p(x)=\frac{\tilde{p}(\phi^{-1}(x))}{\sqrt{\det g(\phi^{-1}(x))}}, \tag{37}$$

and in log space

$$\log p(x)=\log\tilde{p}(\phi^{-1}(x))-\frac{1}{2}\log\det g(\phi^{-1}(x)). \tag{38}$$
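Equation 38 can be evaluated numerically when $\phi$ maps into an ambient Euclidean space and $\mathcal{M}$ carries the induced metric, so that $g_{ij}(z)$ is the dot product of the partial derivatives of $\phi$. For illustration, the sketch swaps the exponential map for a simpler chart, spherical coordinates on the unit sphere $S^2$, where $\det g(z)=\sin^2\theta$ is known in closed form, and pushes forward a uniform prior on $(0,\pi)\times(0,2\pi)$. The chart, the prior, and the function names are assumptions made for this example.

```python
import math


def sphere_chart(z):
    """Illustrative chart: spherical coordinates (theta, angle) -> unit sphere in R^3."""
    theta, angle = z
    return [
        math.sin(theta) * math.cos(angle),
        math.sin(theta) * math.sin(angle),
        math.cos(theta),
    ]


def metric_in_chart(phi, z, eps=1e-6):
    """g_ij(z) = <d_i phi(z), d_j phi(z)>, using the ambient Euclidean inner product
    (the induced metric) and central differences for d_i phi."""
    partials = []
    for i in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        partials.append([(a - b) / (2 * eps) for a, b in zip(phi(z_plus), phi(z_minus))])
    return [[sum(a * b for a, b in zip(pi, pj)) for pj in partials] for pi in partials]


def log_density_at_chart_point(log_p_tilde, phi, z):
    """Equation (38) evaluated at x = phi(z), i.e. z = phi^{-1}(x); 2-D charts only."""
    g = metric_in_chart(phi, z)
    det_g = g[0][0] * g[1][1] - g[0][1] * g[1][0]
    return log_p_tilde(z) - 0.5 * math.log(det_g)


def log_uniform_prior(z):
    # Uniform density on (0, pi) x (0, 2 pi): 1 / (2 pi^2).
    return -math.log(2 * math.pi ** 2)


z = (1.0, 2.0)
approx = log_density_at_chart_point(log_uniform_prior, sphere_chart, z)
exact = -math.log(2 * math.pi ** 2) - math.log(math.sin(z[0]))  # det g = sin(theta)^2
print(approx, exact)
```

Note that the pushed-forward density is not uniform on the sphere: the $1/\sin\theta$ factor concentrates mass near the poles, which is exactly the volume correction in equation 37.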

Numerical accuracy. On the hypersphere, NLL values were computed using an adaptive step size ODE solver (dopri5) with tolerances of 1e-7. On the high-dimensional flat torus and SPD manifolds, we use the same solver but with tolerances of 1e-5. We always check that the solution does not leave the manifold by verifying that the difference between the solution and its projection onto the manifold is numerically negligible.

On general geometries represented by discrete triangular meshes, we used 1000 Euler steps with a projection after every step for evaluation (NLL computation and sampling after training). During training on general geometries, we solve for the path $x_t$ using 300 Euler steps with a projection after every step. To avoid division by zero when computing the conditional vector field in equation 13, we solve for $x_t$ from $t=0$ to $t=1-\varepsilon$, with $\varepsilon$ set to 1e-5; this effectively flows the base distribution to a non-degenerate distribution around $x_1$ that approximates the Dirac distribution, similar to the role of $\sigma_{\text{min}}$ in Lipman et al. (2023).
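A minimal sketch of Euler integration with a projection after every step, as described above. Since a mesh projection is more involved, the sketch uses the unit sphere as a stand-in, where the projection is simply $x/\|x\|$; the `project` argument marks where a (hypothetical) mesh projection would go. For the training paths, one would stop at `t_end = 1 - 1e-5` rather than 1. The rotation field is an illustrative test case whose exact solution after unit time is a rotation by 1 radian.

```python
import math


def project_to_sphere(x):
    """Closest-point projection onto the unit sphere (stand-in for a mesh projection)."""
    norm = math.sqrt(sum(xi * xi for xi in x))
    return [xi / norm for xi in x]


def projected_euler(v, x0, n_steps, t_end=1.0, project=project_to_sphere):
    """Explicit Euler on t in [0, t_end], projecting back onto the manifold after every step."""
    h = t_end / n_steps
    x, t = project(x0), 0.0
    for _ in range(n_steps):
        step = v(t, x)
        x = project([xi + h * si for xi, si in zip(x, step)])
        t += h
    return x


def rotation_field(t, x):
    # Tangent to the sphere: rotation about the z-axis with unit angular speed.
    return [-x[1], x[0], 0.0]


x1 = projected_euler(rotation_field, [1.0, 0.0, 0.0], n_steps=1000)
angle = math.atan2(x1[1], x1[0])
print(angle)  # close to 1 radian; each projected step advances the angle by atan(h)
```

With 1000 steps, the accumulated angle is $1000\arctan(10^{-3})$, which differs from 1 by roughly $3\times10^{-7}$, while the projection keeps every iterate exactly on the sphere up to rounding.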

See Hairer et al. (2006) and Hairer (2011) for overviews on ODE solving on manifolds.

Appendix I Empirical Runtime Estimates

During Riemannian Flow Matching training, there are two main computational considerations: (i) solving for $x_t$, and (ii) computing the training objective. Compared with diffusion models, Conditional Flow Matching has a clear advantage in (ii), since we need neither to estimate a conditional score function through an infinite series (as in DSM) nor to estimate a divergence (as in ISM). In the following, we focus on the runtime of different ways of (i) solving for $x_t$:

- Simulation of ODE/SDE (200 steps) on a flat torus: 6.36 iterations/second
- Simulation-free on a flat torus: 104.04 iterations/second
- Simulation of ODE/SDE (200 steps) on the bunny mesh: 0.422 iterations/second

These numbers were benchmarked on a Tesla V100 GPU with batch size 64. The runtime covers the full training loop; the three settings differ mainly in how $x_t$ is solved, while everything else (e.g., the architecture) is fixed.

Generally, the bunny and other mesh manifolds are more expensive due to the projection operator applied after every step, which we implemented rather naively for meshes. Even on simple manifolds, however, simulation-free training is significantly faster than iteratively solving an ODE/SDE: on the flat torus we observe a speedup of roughly 16x (104.04 vs. 6.36 iterations/second), even when accounting for the full training loop (including gradient descent, etc.). This shows the efficiency gains of simulation-free training over simulation-based training.
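For reference, the throughput ratios implied by these measurements work out as follows (simple arithmetic on the reported numbers):

$$\frac{104.04}{6.36}\approx 16.4\ \ \text{(simulation-free vs. simulated, flat torus)},\qquad \frac{6.36}{0.422}\approx 15.1\ \ \text{(flat torus vs. bunny mesh, both simulated)}.$$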

Appendix J Additional Experiments

Here we consider manifolds with constrained domains and non-trivial metric tensors, specifically, a hyperbolic space and a manifold of symmetric positive definite matrices, equipped with their standard Riemannian metrics. See Table 5 for a summary of the geometries of these manifolds.

J.1 Hyperbolic Manifold

We use the Poincaré disk model to represent a 2-D hyperbolic space. Figure 12 visualizes geodesic paths originating from a single point on the manifold, a CNF learned with Riemannian Conditional Flow Matching, and samples from the learned CNF. The learned CNF respects the geometry of the manifold and transports samples along geodesic paths, recovering a near-optimal transport map with respect to the Riemannian metric. Moreover, due to the use of this metric, the CNF never transports samples outside the manifold.
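Because the premetric here is the geodesic distance, a closed form on the Poincaré disk is useful. The sketch below uses the standard curvature $-1$ Poincaré distance (a textbook formula, assumed here to match the metric used in these experiments) and checks it against $d(0,x)=2\operatorname{artanh}\|x\|$. Distances blow up as points approach the unit circle, which is consistent with geodesic paths staying inside the disk.

```python
import math


def poincare_distance(x, y):
    """Geodesic distance on the Poincaré ball with curvature -1:
    d(x, y) = arccosh(1 + 2 ||x - y||^2 / ((1 - ||x||^2)(1 - ||y||^2)))."""

    def sq_norm(u):
        return sum(ui * ui for ui in u)

    diff = [a - b for a, b in zip(x, y)]
    ratio = 2.0 * sq_norm(diff) / ((1.0 - sq_norm(x)) * (1.0 - sq_norm(y)))
    return math.acosh(1.0 + ratio)


origin = [0.0, 0.0]
print(poincare_distance(origin, [0.5, 0.0]))   # 2 * atanh(0.5) = ln 3, about 1.0986
print(poincare_distance(origin, [0.99, 0.0]))  # about 5.29: far away despite Euclidean norm < 1
```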

[Figure 12 panels: (a) Geodesic paths, (b) Learned flow, (c) Model samples]
Figure 12: (a) Geodesic paths on a hyperbolic manifold (represented using the Poincaré disk model) originating from a single point (blue star). (b) Learned CNF, where blue samples are from $p(x_0)$ and orange samples are from $q(x_1)$. The CNF respects the geometry of the hyperbolic manifold and learns to transport along geodesic paths.

J.2 Manifold of Symmetric Positive Definite Matrices

We use the space of symmetric positive definite (SPD) matrices with the Riemannian metric (Moakher & Batchelor, 2006). We construct datasets from electroencephalography (EEG) data collected by Blankertz et al. (2007); Brunner et al. (2008); Leeb et al. (2008) for a Brain-Computer Interface (BCI) competition. We then compute the covariance matrices of these signals, following the standard preprocessing procedure for analyzing EEG signals (Barachant et al., 2013).
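The covariance step can be sketched as below. The exact preprocessing of Barachant et al. (2013) (filtering, epoching, any regularization) is not given here, so this sketch assumes plain per-trial sample covariances of channels-by-time windows. The random array is a synthetic stand-in for EEG, not the BCI data. With more time samples than channels, such matrices are generically SPD, i.e., points on the manifold.

```python
import numpy as np


def trial_covariances(trials):
    """Sample covariance matrix of each trial.

    trials: array of shape (n_trials, n_channels, n_samples), e.g. EEG windows.
    Returns an array of shape (n_trials, n_channels, n_channels).
    """
    trials = np.asarray(trials, dtype=float)
    centered = trials - trials.mean(axis=2, keepdims=True)  # remove per-channel mean
    n_samples = trials.shape[2]
    return centered @ centered.transpose(0, 2, 1) / (n_samples - 1)


def is_spd(m, tol=0.0):
    """Check symmetry and strictly positive eigenvalues."""
    return np.allclose(m, m.T) and np.linalg.eigvalsh(m).min() > tol


# Synthetic stand-in for EEG: 5 trials, 6 channels, 250 time samples.
rng = np.random.default_rng(0)
trials = rng.standard_normal((5, 6, 250))
covs = trial_covariances(trials)
print(covs.shape, all(is_spd(c) for c in covs))
```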

In Table 6, we report estimates of the negative log-likelihood (NLL) and the percentage of simulated samples that are valid SPD matrices (i.e., samples that lie on the manifold). Our ablations highlight the importance of both the Riemannian geodesic and the Riemannian norm during training.

Figure 13: Comparison of extrapolated geodesic paths on the SPD manifold, visualized as a convex cone: Riemannian geodesic (black) and Euclidean geodesic (violet).

Riemannian geodesic.

We compare using Riemannian geodesics (i.e., setting the premetric to be the geodesic distance) with Euclidean geodesics (i.e., setting the premetric to be the $L_2$ distance). Comparing the two kinds of paths, we find that the Riemannian one generally performs better, as it respects the geometry of the underlying manifold. Figure 13 visualizes the space of $2\times 2$ SPD matrices as a convex cone. It shows that the Riemannian geodesic always becomes perpendicular to the boundary as it approaches it, and therefore does not leave the manifold, whereas the Euclidean geodesic ignores this geometry.
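The contrast in Figure 13 can be reproduced numerically. The sketch assumes the affine-invariant metric on SPD matrices, whose geodesic from $P$ to $Q$ is $P^{1/2}(P^{-1/2}QP^{-1/2})^tP^{1/2}$; it compares this with the straight line $(1-t)P+tQ$ and extrapolates both to $t=2$. Under these assumptions, the Riemannian path stays SPD (all eigenvalues positive) while the Euclidean path leaves the cone. The matrices are arbitrary examples.

```python
import numpy as np


def _spd_apply(m, fn):
    """Apply a scalar function to a symmetric matrix through its eigendecomposition."""
    m = 0.5 * (m + m.T)  # symmetrize to guard against round-off
    w, u = np.linalg.eigh(m)
    return (u * fn(w)) @ u.T


def riemannian_spd_geodesic(p, q, t):
    """Affine-invariant geodesic: P^{1/2} (P^{-1/2} Q P^{-1/2})^t P^{1/2}."""
    p_sqrt = _spd_apply(p, np.sqrt)
    p_inv_sqrt = _spd_apply(p, lambda w: 1.0 / np.sqrt(w))
    middle = _spd_apply(p_inv_sqrt @ q @ p_inv_sqrt, lambda w: w ** t)
    return p_sqrt @ middle @ p_sqrt


def euclidean_path(p, q, t):
    """Straight line (1 - t) P + t Q, i.e. the Euclidean geodesic."""
    return (1.0 - t) * p + t * q


p = np.array([[2.0, 0.5], [0.5, 1.0]])
q = np.diag([0.2, 2.0])
for t in (0.0, 1.0, 2.0):  # t = 2 extrapolates beyond Q
    r = np.linalg.eigvalsh(riemannian_spd_geodesic(p, q, t)).min()
    e = np.linalg.eigvalsh(euclidean_path(p, q, t)).min()
    print(f"t={t}: min eig Riemannian {r:.3f}, Euclidean {e:.3f}")
```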

Riemannian norm.

We also compare using the Riemannian norm (i.e., $\left\|\cdot\right\|_g^2$) for training with using the Euclidean norm (i.e., $\left\|\cdot\right\|_2^2$). Theoretically, the choice of norm does not affect the optimal $v_t(x)$, which equals $u_t(x)$; however, when $v_t$ is modeled with a limited-capacity neural network, the choice of norm can be very important, as it determines which regions the optimization focuses on (in particular, regions with a large metric tensor). On the SPD manifold, as on the hyperbolic manifold, the metric tensor grows to infinity in regions where the matrix is close to singular (i.e., ill-conditioned). We find that, especially for larger SPD matrices, using the Riemannian norm is important to ensure that $v_t$ does not leave the manifold during simulation (that is, the simulated result is still an SPD matrix).
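To see how the Riemannian norm reweights the regression, the sketch below again assumes the affine-invariant metric, under which a symmetric tangent vector $V$ at $P$ has $\|V\|_P^2=\operatorname{tr}(P^{-1}VP^{-1}V)$. At $P=I$ this coincides with the squared Frobenius (Euclidean) norm, but near a singular $P$ the same $V$ is weighted far more heavily, so errors in those regions dominate the loss. The helper `matching_loss_term` is hypothetical and only shows where such a norm would enter a per-sample objective.

```python
import numpy as np


def affine_invariant_sq_norm(p, v):
    """||V||_P^2 = tr(P^{-1} V P^{-1} V) for a symmetric tangent vector V at SPD point P."""
    a = np.linalg.solve(p, v)  # P^{-1} V without forming the inverse explicitly
    return float(np.trace(a @ a))


def matching_loss_term(p, v_pred, u_target, riemannian=True):
    """Per-sample squared error between predicted and target tangent vectors (illustrative)."""
    diff = v_pred - u_target
    if riemannian:
        return affine_invariant_sq_norm(p, diff)
    return float(np.sum(diff ** 2))  # Euclidean (Frobenius) norm


v = np.array([[1.0, 0.5], [0.5, -1.0]])
print(affine_invariant_sq_norm(np.eye(2), v))             # 2.5, the squared Frobenius norm
print(affine_invariant_sq_norm(np.diag([1e-3, 1.0]), v))  # about 1.0005e6 near the boundary
```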

Table 6: Test set evaluation on EEG datasets. Standard deviations are estimated over 3 runs.
| Geodesic | Norm | BCI-IV-2b 6×6 (21D) NLL | Valid (%) | BCI-IV-2a 25×25 (325D) NLL | Valid (%) | BCI-IV-1 59×59 (1770D) NLL | Valid (%) |
|---|---|---|---|---|---|---|---|
| Euclidean | Euclidean | -61.58 ± 0.26 | 100 ± 0.00 | -276.07 ± 0.66 | 81.23 ± 5.12 | N/A | 0 ± 0.00 |
| Riemannian | Euclidean | -61.64 ± 0.22 | 100 ± 0.00 | -277.06 ± 0.87 | 91.47 ± 5.91 | N/A | 0 ± 0.00 |
| Euclidean | Riemannian | -52.22 ± 9.67 | 100 ± 0.00 | -267.42 ± 5.87 | 100 ± 0.00 | -1167.63 ± 40.53 | 100 ± 0.00 |
| Riemannian | Riemannian | -61.76 ± 0.24 | 100 ± 0.00 | -271.54 ± 1.17 | 100 ± 0.00 | -1209.88 ± 53.55 | 100 ± 0.00 |