Last Update: July 1, 2024

On Statistical Rates and Provably Efficient Criteria
of Latent Diffusion Transformers (DiTs)

Jerry Yao-Chieh Hu†∗  Weimin Wu†∗  Zhuoru Li  Zhao Song  Han Liu†§

∗ These authors contributed equally to this work.
† Department of Computer Science, Northwestern University, Evanston, IL 60208 USA
‡ School of Mathematical Science, Fudan University, Yangpu, Shanghai 200433, China
♭ Adobe Research, Seattle, WA 98103, USA
§ Department of Statistics and Data Science, Northwestern University, Evanston, IL 60208 USA

We investigate the statistical and computational limits of latent Diffusion Transformers (DiTs) under the low-dimensional linear latent space assumption. Statistically, we study the universal approximation and sample complexity of the DiTs score function, as well as the distribution recovery property of the initial data. Specifically, under mild data assumptions, we derive an approximation error bound for the score network of latent DiTs, which is sub-linear in the latent space dimension. Additionally, we derive the corresponding sample complexity bound and show that the data distribution generated from the estimated score function converges toward a proximate area of the original one. Computationally, we characterize the hardness of both forward inference and backward computation of latent DiTs, assuming the Strong Exponential Time Hypothesis (SETH). For forward inference, we identify efficient criteria for all possible latent DiTs inference algorithms and showcase our theory by pushing the efficiency toward almost-linear time inference. For backward computation, we leverage the low-rank structure within the gradient computation of DiTs training for possible algorithmic speedup. Specifically, we show that such speedup achieves almost-linear time latent DiTs training by casting the DiTs gradient as a series of chained low-rank approximations with bounded error. Under the low-dimensional assumption, we show that the convergence rate and the computational efficiency are both dominated by the dimension of the subspace, suggesting that latent DiTs have the potential to bypass the challenges associated with the high dimensionality of initial data.

1 Introduction

We investigate the statistical and computational limits of latent diffusion transformers (DiTs), assuming the data is supported on an unknown low-dimensional linear subspace. This analysis is both practical and timely. On one hand, DiTs have demonstrated revolutionary success in generative AI and digital creation by using Transformers as score networks (Esser et al., 2024; Ma et al., 2024; Chen et al., 2024; Mo et al., 2023; Peebles and Xie, 2023). On the other hand, they require significant computational resources (Liu et al., 2024), making them challenging to train outside of specialized industrial labs. Therefore, it is natural to ask whether it is possible to make them lighter and faster without sacrificing performance. Answering this question requires a fundamental understanding of the DiT architecture. This work provides a timely theoretical analysis of the fundamental limits of the DiT architecture, aided by the analytical feasibility provided by the low-dimensional data assumption.

Empirically, latent diffusion is a go-to design for effectiveness and computational efficiency (Rombach et al., 2022; Liu et al., 2021; Pope et al., 2021; Su and Wu, 2018). Theoretically, it accommodates the assumption of a low-dimensional data structure (see Assumption 2.1 for the formal definition), which enables detailed analytical characterization (Chen et al., 2023a; Bortoli, 2022). In essence, diffusion models with low-dimensional data structures manifest a natural lower-dimensional diffusion process through an encoder/decoder within a robust and informative latent representation feature space (Rombach et al., 2022; Pope et al., 2021). Such lower-dimensional diffusion improves computational efficiency by reducing data complexity without sacrificing essential information (Liu et al., 2021). With this assumption, Chen et al. (2023a) decompose the score function of U-Net based diffusion models into on-support and orthogonal components. This decomposition allows for the characterization of the distinct behaviors of the two components: the on-support component facilitates latent distribution learning, while the orthogonal component facilitates subspace recovery.

In our work, we use the low-dimensional data structure assumption to explore the statistical and computational limits of latent DiTs. Our analysis includes characterizations of statistical rates and provably efficient criteria. Statistically, we pose two questions and provide a theory to characterize the statistical rates of latent DiTs under the low-dimensional data assumption:

Question 1.

What is the approximation limit of using transformers to approximate the DiT score function, particularly in the low-dimensional data subspace?

Question 2.

How accurate is the estimation limit for such a score estimator in practical training scenarios? With the score estimator, how well can diffusion transformers recover the data distribution?

Computationally, the primary challenge of DiT lies in the transformer blocks’ quadratic complexity. This computational burden applies to both inference and training, even with latent diffusion. Thus, it is essential to design algorithms and methods to circumvent this $\Omega(L^2)$ cost, where $L$ is the latent DiT sequence length. However, there are no formal results to support and characterize such algorithms. To address this gap, we pose the following question and provide a fundamental theory to fully characterize the complexity of latent DiT under the low-dimensional linear subspace data assumption:

Question 3.

Is it possible to improve the $\Omega(L^2)$ time complexity with a bounded approximation error for both forward and backward passes? What is the computational limit for such an improvement?

Contributions.

We study the fundamental limits of latent DiT. Our contributions are threefold:

  • Score Approximation. We address Question 1 by characterizing the approximation limit of matching the DiT score function with a transformer-based score estimator. Specifically, under mild data assumptions, we derive an approximation error bound for the score network, sub-linear in the latent space dimension (Theorem 3.1). These results not only explain the expressiveness of latent DiT (under mild assumptions) but also provide guidance for the structural configuration of the score network for practical implementations (Theorem 3.1).

  • Score and Distribution Estimation. We address Question 2 by exploring the limitations of score and distribution estimations of latent DiTs in practical training scenarios. Specifically, we provide a sample complexity bound for score estimation (Corollary 3.1.1), using a norm-based covering number bound for the transformer architecture. Additionally, we show that the learned score estimator is able to recover the initial data distribution (Corollary 3.1.2).

  • Provably Efficient Criteria and Existence of Almost Linear Time Algorithms. We address Question 3 by providing provably efficient criteria for latent DiTs in both forward inference and backward computation/training. For forward inference, we characterize all possible efficient DiT algorithms using a norm-based efficiency threshold for both conditional and unconditional generation (Proposition 4.1). Efficient algorithms, including almost-linear time algorithms (Proposition 4.2), are possible only below this threshold. For backward computation, we prove the existence of almost-linear time DiT training algorithms (Theorem 4.1) by utilizing the inherent low-rank structure in DiT gradients through a chained low-rank approximation.

Interestingly, both our statistical and computational results (C1-3) are dominated by the subspace dimension under the low-dimensional assumption, suggesting that latent DiT can potentially bypass the challenges associated with the high dimensionality of initial data.

Organization.

Section 2 includes background on score decomposition and Transformer-based score networks. Section 3 presents the statistical rates of DiTs. Section 4 provides provably efficient criteria. We defer discussions of related works to Appendix C due to space constraints.

Notations.

We use lower case letters to denote vectors, e.g., $z\in\mathbb{R}^{D}$. $\|z\|_{2}$ and $\|z\|_{\infty}$ denote its Euclidean norm and infinity norm respectively. We use upper case letters to denote matrices, e.g., $Z\in\mathbb{R}^{d\times L}$. $\|Z\|_{2}$, $\|Z\|_{\rm op}$, and $\|Z\|_{F}$ denote the $2$-norm, operator norm and Frobenius norm respectively. $\|Z\|_{p,q}$ denotes the $p,q$-norm where the $p$-norm is over columns and the $q$-norm is over rows.
Given a function $f$, let $\|f(x)\|_{L^{2}}\coloneqq(\int\|f(x)\|_{2}^{2}\,\mathrm{d}x)^{1/2}$, and $\|f(\cdot)\|_{Lip}=\sup_{x\neq y}(\|f(x)-f(y)\|_{2}/\|x-y\|_{2})$. With a distribution $P$, we denote $\|f\|_{L^{2}(P)}=(\int_{P}\|f(x)\|_{2}^{2}\,\mathrm{d}x)^{1/2}$ as the $L^{2}(P)$ norm.
Let $f_{\sharp}P$ be a pushforward measure, i.e., for any measurable $\Omega$, $(f_{\sharp}P)(\Omega)=P(f^{-1}(\Omega))$. We use $\psi$ for (conditional) Gaussian density functions.

2 Background

This section reviews the ideas we build on, including an overview of diffusion models (Section 2.1), the score decomposition under the linear latent space assumption (Section 2.2), and the transformer backbone in DiT (Section 2.3).

2.1 Score-Matching Denoising Diffusion Models

We briefly review the forward process, the backward process, and score matching in diffusion models.

Forward and Backward Process.

In the forward process, diffusion models gradually add noise to the original data $x_{0}\in\mathbb{R}^{D}$, where $x_{0}\sim P_{0}$. Let $x_{t}$ denote the noisy data at time stamp $t$, with marginal distribution $P_{t}$ and density $p_{t}$. The conditional distribution $P(x_{t}|x_{0})$ follows $N(\beta(t)x_{0},\sigma(t)I_{D})$, where $\beta(t)=\exp(-\int_{0}^{t}w(s)\,\mathrm{d}s/2)$, $\sigma(t)=1-\beta^{2}(t)$, and $w(t)>0$ is a nondecreasing weighting function. In practice, the forward process terminates at a large enough $T$ such that $P_{T}$ is close to $N(0,I_{D})$.
In the backward process, we obtain $y_{t}$ by reversing the forward process. The generation of $y_{t}$ depends on the score function $\nabla\log p_{t}(\cdot)$. Since the score function is unknown in practice, we replace $\nabla\log p_{t}(\cdot)$ with a score estimator $s_{W}(\cdot,t)$, which is usually a neural network with parameters $W$. See Section D.1 for the details.
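The forward noising step can be illustrated with a minimal sketch. It assumes a constant weighting function $w(s)=1$ by default and approximates the integral in $\beta(t)$ with a midpoint rule; the helper names `beta`, `sigma`, and `forward_sample` are hypothetical and chosen only for this illustration.

```python
import math
import random

def beta(t, w=lambda s: 1.0, steps=1000):
    """beta(t) = exp(-0.5 * integral_0^t w(s) ds), via the midpoint rule."""
    h = t / steps
    integral = sum(w((k + 0.5) * h) for k in range(steps)) * h
    return math.exp(-0.5 * integral)

def sigma(t, w=lambda s: 1.0):
    """sigma(t) = 1 - beta(t)^2 is the variance of the added noise."""
    return 1.0 - beta(t, w) ** 2

def forward_sample(x0, t, rng, w=lambda s: 1.0):
    """Draw x_t ~ N(beta(t) x0, sigma(t) I_D) given a clean point x0."""
    b, s = beta(t, w), sigma(t, w)
    return [b * xi + math.sqrt(s) * rng.gauss(0.0, 1.0) for xi in x0]

rng = random.Random(0)
x0 = [1.0, -2.0, 0.5]
x_small = forward_sample(x0, 0.01, rng)   # barely perturbed copy of x0
x_large = forward_sample(x0, 20.0, rng)   # close to a draw from N(0, I_D)
```

For small $t$ the sample stays near $x_0$, while for large $t$ the factor $\beta(t)$ is nearly zero and the sample is dominated by Gaussian noise, matching the statement that $P_T$ approaches $N(0,I_D)$.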

Score Matching.

To estimate the score function, we use the following loss

$$\min_{W}\int_{T_{0}}^{T}\gamma(t)\,\mathbb{E}_{x_{t}\sim P_{t}}\left[\left\|s_{W}(x_{t},t)-\nabla\log p_{t}(x_{t})\right\|_{2}^{2}\right]\mathrm{d}t,$$

where $\gamma(t)$ is the weight function, and $T_{0}$ is a small value to stabilize training and prevent the score function from blowing up (Vahdat et al., 2021). However, it is hard to compute $\nabla\log p_{t}(\cdot)$ with available data samples. Therefore, we minimize the equivalent denoising score matching objective

$$\min_{W}\int_{T_{0}}^{T}\gamma(t)\,\mathbb{E}_{x_{0}\sim P_{0}}\left[\mathbb{E}_{x_{t}|x_{0}}\left[\left\|s_{W}(x_{t},t)-\nabla_{x_{t}}\log\psi_{t}(x_{t}\mid x_{0})\right\|_{2}^{2}\right]\right]\mathrm{d}t,\tag{2.1}$$

where $\psi_{t}(x_{t}|x_{0})$ is the transition kernel, so that $\nabla_{x_{t}}\log\psi_{t}(x_{t}|x_{0})=\left(\beta(t)x_{0}-x_{t}\right)/\sigma(t)$.
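The closed form of the conditional score can be checked numerically. The sketch below compares $(\beta(t)x_0-x_t)/\sigma(t)$ against a central finite-difference gradient of the Gaussian log-density; it assumes $w(s)=1$, so $\beta(t)=e^{-t/2}$ and $\sigma(t)=1-e^{-t}$, and all function names are hypothetical.

```python
import math

def log_psi(xt, x0, b, s):
    """Log density of N(b * x0, s * I_D) evaluated at xt."""
    d = len(xt)
    sq = sum((a - b * c) ** 2 for a, c in zip(xt, x0))
    return -0.5 * d * math.log(2 * math.pi * s) - sq / (2 * s)

def conditional_score(xt, x0, b, s):
    """Closed form (b * x0 - xt) / s stated in the text."""
    return [(b * c - a) / s for a, c in zip(xt, x0)]

def numeric_grad(xt, x0, b, s, eps=1e-6):
    """Central finite-difference gradient of log_psi with respect to xt."""
    grad = []
    for i in range(len(xt)):
        up = list(xt)
        up[i] += eps
        dn = list(xt)
        dn[i] -= eps
        grad.append((log_psi(up, x0, b, s) - log_psi(dn, x0, b, s)) / (2 * eps))
    return grad

t = 0.7
b, s = math.exp(-t / 2), 1 - math.exp(-t)   # assumes w(s) = 1
x0, xt = [1.0, -2.0], [0.3, 0.4]
closed = conditional_score(xt, x0, b, s)
numeric = numeric_grad(xt, x0, b, s)
max_gap = max(abs(a - c) for a, c in zip(closed, numeric))
```

Because the log-density is quadratic in $x_t$, the finite-difference gradient agrees with the closed form up to floating-point roundoff.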

To train the parameters $W$ in the score estimator $s_{W}(\cdot,t)$, we use the empirical version of (2.1). We select $n$ i.i.d. data samples $\{x_{0,i}\}_{i=1}^{n}\sim P_{0}$, and sample time $t_{i}$ $(1\leq i\leq n)$ uniformly from the interval $[T_{0},T]$. Given $x_{0,i}$, we sample $x_{t_{i}}$ from $N(\beta(t_{i})x_{0,i},\sigma(t_{i})I_{D})$. The empirical loss is

$$\min_{W}\ \widehat{\mathcal{L}}(W)=\frac{1}{n}\sum_{i=1}^{n}\left\|s_{W}(x_{t_{i}},t_{i})-x_{0,i}\right\|_{2}^{2}.\tag{2.2}$$

For convenience of notation, we denote the population loss by $\mathcal{L}(W)=\mathbb{E}_{P_{0}}[\widehat{\mathcal{L}}(W)]$.
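The sampling procedure described above (draw $x_{0,i}$, draw $t_i$ uniformly on $[T_0,T]$, then draw $x_{t_i}$) can be sketched as a Monte Carlo estimate of the denoising objective (2.1) with $\gamma(t)=1$. This is only an illustration under stated assumptions: the regression target is taken to be the conditional score from (2.1), $w(s)=1$, the data are a toy standard Gaussian, and `empirical_dsm_loss`, `true_score`, and `zero_score` are hypothetical names.

```python
import math
import random

def beta(t):
    return math.exp(-t / 2)          # assumes w(s) = 1

def sigma(t):
    return 1.0 - math.exp(-t)

def empirical_dsm_loss(score_fn, data, T0, T, rng):
    """Average squared error between score_fn and the conditional score."""
    total = 0.0
    for x0 in data:
        t = rng.uniform(T0, T)                        # t_i ~ Unif[T0, T]
        b, s = beta(t), sigma(t)
        xt = [b * c + math.sqrt(s) * rng.gauss(0.0, 1.0) for c in x0]
        target = [(b * c - a) / s for a, c in zip(xt, x0)]
        pred = score_fn(xt, t)
        total += sum((p - q) ** 2 for p, q in zip(pred, target))
    return total / len(data)

def true_score(x, t):
    # For x0 ~ N(0, I), x_t is also N(0, I), so the marginal score is -x.
    return [-v for v in x]

def zero_score(x, t):
    return [0.0 for _ in x]

data_rng = random.Random(0)
data = [[data_rng.gauss(0.0, 1.0), data_rng.gauss(0.0, 1.0)] for _ in range(2000)]

# Reusing the same seed gives both estimators identical (t_i, x_{t_i}) draws.
loss_true = empirical_dsm_loss(true_score, data, 0.05, 2.0, random.Random(1))
loss_zero = empirical_dsm_loss(zero_score, data, 0.05, 2.0, random.Random(1))
```

In this toy setting the marginal score attains a lower empirical loss than the trivial zero estimator, which is the behavior one expects from a denoising score matching objective.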

2.2 Score Decomposition in Linear Latent Space

In this part, we review the score decomposition in (Chen et al., 2023a). We consider $D$-dimensional input data $x$ supported on a $d_{0}$-dimensional subspace, where $d_{0}\leq D$.

Assumption 2.1 (Low-Dimensional Linear Latent Space).

Data point $x$ can be written as $x=Bh$, where $B\in\mathbb{R}^{D\times d_{0}}$ is an unknown matrix with orthonormal columns. The latent variable $h\in\mathbb{R}^{d_{0}}$ follows the distribution $P_{h}$ with a density function $p_{h}$.

Remark 2.1.

By “Linear Latent Space,” we mean that each entry of a given latent vector is a linear combination of the entries of the corresponding input, i.e., $h=B^{\top}x$. This is also known as the “low-dimensional data” assumption in the literature (Chen et al., 2023a).

Based on the low-dimensional data structure assumption, we have the following score decomposition theory: the on-support score $s_{+}(B^{\top}x,t)$ and the orthogonal score $s_{-}(x,t)$.

Lemma 2.1 (Score Decomposition, Lemma 1 of (Chen et al., 2023a)).

Let data $x=Bh$ follow Assumption 2.1. The decomposition of the score function $\nabla\log p_{t}(x)$ is

$$\nabla\log p_{t}(x)=\underbrace{B\nabla\log p_{t}^{h}(\bar{h})}_{s_{+}(\bar{h},t)}\ \underbrace{-\left(I_{D}-BB^{\top}\right)x/\sigma(t)}_{s_{-}(x,t)},\quad\bar{h}=B^{\top}x,\tag{2.3}$$

where $p_{t}^{h}(\bar{h})\coloneqq\int\psi_{t}(\bar{h}|h)p_{h}(h)\,\mathrm{d}h$, $\psi_{t}(\cdot|h)$ is the Gaussian density function of $N(\beta(t)h,\sigma(t)I_{d_{0}})$, $\beta(t)=e^{-t/2}$ and $\sigma(t)=1-e^{-t}$. We restate the proof in Section D.2 for completeness.
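As a sanity check of the decomposition, one can compare it against a brute-force score in a toy case. The sketch below assumes $D=3$, $d_0=1$, a standard Gaussian latent $h\sim N(0,1)$, and a fixed unit vector `u` as the single column of $B$; these choices, and all function names, are assumptions of the illustration. The marginal density $p_t(x)=\int\psi_t(x|Bh)p_h(h)\,\mathrm{d}h$ is integrated numerically on a grid and differentiated by finite differences, then compared with $s_+ + s_-$.

```python
import math

D = 3
u = [2 / 3, -1 / 3, 2 / 3]           # B has a single orthonormal column (d0 = 1)
t = 0.7
b, s = math.exp(-t / 2), 1 - math.exp(-t)

def log_p_t(x, n=4001, lim=10.0):
    """log of p_t(x) = integral of psi_t(x | B h) p_h(h) dh with h ~ N(0, 1)."""
    step = 2 * lim / (n - 1)
    total = 0.0
    for k in range(n):
        h = -lim + k * step
        sq = sum((xi - b * h * ui) ** 2 for xi, ui in zip(x, u))
        psi = (2 * math.pi * s) ** (-D / 2) * math.exp(-sq / (2 * s))
        p_h = math.exp(-h * h / 2) / math.sqrt(2 * math.pi)
        total += psi * p_h * step
    return math.log(total)

def numeric_score(x, eps=1e-4):
    """Finite-difference gradient of log p_t at x."""
    g = []
    for i in range(D):
        up = list(x)
        up[i] += eps
        dn = list(x)
        dn[i] -= eps
        g.append((log_p_t(up) - log_p_t(dn)) / (2 * eps))
    return g

def decomposed_score(x):
    """s_plus + s_minus as in the score decomposition."""
    h_bar = sum(ui * xi for ui, xi in zip(u, x))         # h_bar = B^T x
    # Here p_t^h = N(0, b^2 + s) = N(0, 1), so grad log p_t^h(h_bar) = -h_bar.
    s_plus = [ui * (-h_bar) for ui in u]
    s_minus = [-(xi - ui * h_bar) / s for xi, ui in zip(x, u)]
    return [p + m for p, m in zip(s_plus, s_minus)]

x = [0.4, 1.1, -0.6]
gap = max(abs(a - c) for a, c in zip(numeric_score(x), decomposed_score(x)))
```

The two scores agree up to numerical error in this example. The on-support term only depends on the projection $B^{\top}x$, while the orthogonal term pulls the off-subspace part of $x$ toward zero at rate $1/\sigma(t)$.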

Additionally, our theoretical analysis is based on the following two assumptions, as in (Chen et al., 2023a).

Assumption 2.2 (Tail Behavior of $P_{h}$).

The density function $p_{h}>0$ is twice continuously differentiable. Moreover, there exist positive constants $A_{0},A_{1},A_{2}$ such that when $\|h\|_{2}\geq A_{0}$, the density function satisfies $p_{h}(h)\leq(2\pi)^{-d_{0}/2}A_{1}\exp(-A_{2}\|h\|_{2}^{2}/2)$.
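As a simple illustration (our own example, not a case singled out in the source), a standard Gaussian latent $P_h = N(0, I_{d_0})$ satisfies this tail condition: its density is strictly positive and smooth, and the bound holds with equality for the constants shown below.

```latex
p_h(h) = (2\pi)^{-d_0/2} \exp\left(-\|h\|_2^2/2\right)
\le (2\pi)^{-d_0/2} A_1 \exp\left(-A_2 \|h\|_2^2/2\right)
\quad \text{with } A_1 = A_2 = 1 \text{ and any } A_0 > 0 .
```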

Assumption 2.3 ($L_{s_{+}}$-Lipschitz of $s_{+}(\bar{h},t)$).

The on-support score function $s_{+}(\bar{h},t)$ is $L_{s_{+}}$-Lipschitz in $\bar{h}\in\mathbb{R}^{d_{0}}$ for any $t\in[0,T]$.

2.3 Score Network and Transformers

In this part, we introduce the score network architecture and Transformers. Transformers are the backbone of the score network in DiT. By Assumption 2.1, $\bar{h}=B^{\top}x\in\mathbb{R}^{d_{0}}$ with $d_{0}<D$.

(Latent) Score Network.

Following (Chen et al., 2023a), we rearrange (2.3) into

$$\nabla\log p_{t}(x)=B\big(\underbrace{\sigma(t)\nabla\log p_{t}^{h}(B^{\top}x)+B^{\top}x}_{\coloneqq q(B^{\top}x,t):\ \mathbb{R}^{d_{0}}\times[T_{0},T]\to\mathbb{R}^{d_{0}}}\big)/\sigma(t)-x/\sigma(t).\tag{2.4}$$
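The rearrangement is a short algebraic step. A sketch of the intermediate lines, obtained by expanding the orthogonal term of (2.3) and factoring out $B$ and $1/\sigma(t)$:

```latex
\begin{align*}
\nabla \log p_t(x)
&= B \nabla \log p_t^h(B^\top x) - \left(I_D - BB^\top\right) x / \sigma(t) \\
&= B \nabla \log p_t^h(B^\top x) + B B^\top x / \sigma(t) - x / \sigma(t) \\
&= B \left( \sigma(t) \nabla \log p_t^h(B^\top x) + B^\top x \right) / \sigma(t) - x / \sigma(t).
\end{align*}
```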

We use WBD×d0subscript𝑊𝐵superscript𝐷subscript𝑑0W_{B}\in\mathbb{R}^{D\times d_{0}}italic_W start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT italic_D × italic_d start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT end_POSTSUPERSCRIPT to approximate BD×d0𝐵superscript𝐷subscript𝑑0B\in\mathbb{R}^{D\times d_{0}}italic_B ∈ blackboard_R start_POSTSUPERSCRIPT italic_D × italic_d start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT end_POSTSUPERSCRIPT, and a neural network f(WBx,t)𝑓superscriptsubscript𝑊𝐵top𝑥𝑡f(W_{B}^{\top}x,t)italic_f ( italic_W start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_x , italic_t ) to approximate q(Bx,t)𝑞superscript𝐵top𝑥𝑡q(B^{\top}x,t)italic_q ( italic_B start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_x , italic_t ). We adopt the following score network class for diffusion in latent space (i.e., in hd0superscriptsubscript𝑑0h\in\mathbb{R}^{d_{0}}italic_h ∈ blackboard_R start_POSTSUPERSCRIPT italic_d start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT end_POSTSUPERSCRIPT)

$$\mathcal{S} = \left\{ s_W(x,t) = W_B f(W_B^\top x,t)/\sigma(t) - x/\sigma(t),\; W = \{W_B, f\} \right\}, \tag{2.5}$$

where the columns of $W_B$ are orthogonal and $f:\mathbb{R}^{d_0}\times[T_0,T]\to\mathbb{R}^{d_0}$ is a neural network. In our work, we focus on diffusion transformers (DiTs), i.e., using a Transformer for $f$ (Peebles and Xie, 2023).
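As a minimal sketch of how a member of the class (2.5) is evaluated, the NumPy snippet below assumes the schedule $\sigma(t)=1-e^{-t}$ that appears in Theorem 3.1. The names `score_network` and `toy_f`, and the toy sizes $D=6$, $d_0=2$, are hypothetical and chosen only for illustration; `toy_f` is a stand-in for the Transformer $f$, not the architecture analyzed here.

```python
import numpy as np

def sigma(t):
    # Schedule sigma(t) = 1 - exp(-t), the form stated in Theorem 3.1.
    return 1.0 - np.exp(-t)

def score_network(x, t, W_B, f):
    """Evaluate s_W(x, t) = W_B f(W_B^T x, t) / sigma(t) - x / sigma(t)."""
    h = W_B.T @ x  # project x onto the d0-dimensional latent coordinates
    return (W_B @ f(h, t) - x) / sigma(t)

# Hypothetical toy setup: D = 6 ambient and d0 = 2 latent dimensions.
rng = np.random.default_rng(0)
D, d0 = 6, 2
W_B, _ = np.linalg.qr(rng.standard_normal((D, d0)))  # orthonormal columns

def toy_f(h, t):
    # Placeholder for f; any map R^{d0} x [T0, T] -> R^{d0} fits the class.
    return np.exp(-t) * h

x = rng.standard_normal(D)
s = score_network(x, 0.5, W_B, toy_f)
```

Whatever $f$ is, the component of $s_W(x,t)$ orthogonal to the column span of $W_B$ equals $-x_\perp/\sigma(t)$, so $f$ only has to model the on-subspace part of the score.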

Transformers.

A Transformer block consists of a self-attention layer and a feed-forward layer, both with skip connections. We use $\tau^{r,m,l}:\mathbb{R}^{d\times L}\to\mathbb{R}^{d\times L}$ to denote a Transformer block. Here $r$ and $m$ are the number of heads and the head size in the self-attention layer, and $l$ is the hidden dimension in the feed-forward layer. Let $X\in\mathbb{R}^{d\times L}$ be the model input; then the model output is

\begin{align}
{\rm Attn}(X) &= X + \sum\nolimits_{i=1}^{r} W_O^i W_V^i X \cdot {\rm Softmax}\left(\left(W_K^i X\right)^\top W_Q^i X\right), \tag{2.6}\\
{\rm FF}\circ{\rm Attn}(X) &= {\rm Attn}(X) + W_2\cdot{\rm ReLU}\left(W_1\cdot{\rm Attn}(X) + b_1\mathbb{1}_L^\top\right) + b_2\mathbb{1}_L^\top, \tag{2.7}
\end{align}

where $W_K^i, W_Q^i, W_V^i\in\mathbb{R}^{m\times d}$, $W_O^i\in\mathbb{R}^{d\times m}$, $W_1\in\mathbb{R}^{l\times d}$, $W_2\in\mathbb{R}^{d\times l}$, $b_1\in\mathbb{R}^{l}$, and $b_2\in\mathbb{R}^{d}$.
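The block in (2.6) and (2.7) can be sketched in NumPy as follows. This is an illustration under stated assumptions, not a reference implementation: the softmax is taken column-wise over the $L\times L$ score matrix (the convention of Yun et al., 2020, which the text does not restate), and the helper names and toy sizes are hypothetical, with $r=2$, $m=1$, $l=4$ chosen to match the class used later in Theorem 3.1.

```python
import numpy as np

def softmax_columns(Z):
    # Column-wise softmax (assumed convention); shift by the max for stability.
    E = np.exp(Z - Z.max(axis=0, keepdims=True))
    return E / E.sum(axis=0, keepdims=True)

def attn(X, W_O, W_V, W_K, W_Q):
    # X: (d, L); W_O: (r, d, m); W_V, W_K, W_Q: (r, m, d). Implements (2.6).
    out = X.copy()
    for i in range(W_O.shape[0]):
        scores = (W_K[i] @ X).T @ (W_Q[i] @ X)          # (L, L)
        out = out + W_O[i] @ W_V[i] @ X @ softmax_columns(scores)
    return out

def ff(Y, W_1, b_1, W_2, b_2):
    # Implements (2.7); b[:, None] broadcasts like b 1_L^T across columns.
    hidden = np.maximum(W_1 @ Y + b_1[:, None], 0.0)
    return Y + W_2 @ hidden + b_2[:, None]

def transformer_block(X, params):
    return ff(attn(X, *params["attn"]), *params["ff"])

# Hypothetical toy sizes: d = 4, L = 3, r = 2 heads, head size m = 1, l = 4.
rng = np.random.default_rng(0)
d, L, r, m, l = 4, 3, 2, 1, 4
W_O = rng.standard_normal((r, d, m))
W_V, W_K, W_Q = (rng.standard_normal((r, m, d)) for _ in range(3))
W_1, b_1 = rng.standard_normal((l, d)), rng.standard_normal(l)
W_2, b_2 = rng.standard_normal((d, l)), rng.standard_normal(d)
params = {"attn": (W_O, W_V, W_K, W_Q), "ff": (W_1, b_1, W_2, b_2)}
X = rng.standard_normal((d, L))
Y = transformer_block(X, params)
```

Because both layers carry skip connections, setting $W_O^i=0$, $W_2=0$ and $b_2=0$ turns the block into the identity map, which is a quick sanity check on any implementation.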

In our work, we use Transformer networks with positional encoding $E\in\mathbb{R}^{d\times L}$. We define the Transformer networks as the composition of Transformer blocks

$$\mathcal{T}_P^{r,m,l} = \left\{ f_{\mathcal{T}}:\mathbb{R}^{d\times L}\to\mathbb{R}^{d\times L} \mid f_{\mathcal{T}} \text{ is a composition of blocks } \tau^{r,m,l}\text{'s} \right\}.$$

For example, the following is a Transformer network consisting of $K$ blocks and positional encoding

$$f_{\mathcal{T}}(X) = {\rm FF}^{(K)}\circ{\rm Attn}^{(K)}\circ\cdots\circ{\rm FF}^{(1)}\circ{\rm Attn}^{(1)}(X+E). \tag{2.8}$$

3 Statistical Rates of Latent DiTs with Subspace Data Assumption

In this section, we analyze the statistical rates of latent DiTs. Section 3.1 introduces the class of latent DiT score networks. In Section 3.2, we prove the approximation limit of matching the DiT score function with the score network class, and characterize the structural configuration of the score network when a specified approximation error is required. Then, in Section 3.3, using this structural configuration, we derive score estimation and distribution estimation results for latent DiTs.

3.1 DiT Score Network Class

In this part, we give the details of the DiT score network class used in our analysis. In (2.5), $f$ is a network with a Transformer backbone, and $(h,t)\in\mathbb{R}^{d_0}\times[T_0,T]$ denotes the input data. Following Peebles and Xie (2023), DiT uses the time point $t$ to calculate the scale and shift values in the Transformer backbone, and it transforms an input image into a sequence. To achieve this transformation, we introduce a reshape layer.

Definition 3.1 (DiT Reshape Layer $R(\cdot)$).

Let $R(\cdot):\mathbb{R}^{d_0}\to\mathbb{R}^{d\times L}$ be a reshape layer that transforms the $d_0$-dimensional input into a $d\times L$ matrix. Specifically, for any $d_0=i\times i$ image input, $R(\cdot)$ converts it into a sequence representation with feature dimension $d\coloneqq p^2$ (where $p\geq 2$) and sequence length $L\coloneqq(i/p)^2$. Besides, we define the corresponding reverse reshape (flatten) layer $R^{-1}(\cdot):\mathbb{R}^{d\times L}\to\mathbb{R}^{d_0}$ as the inverse of $R(\cdot)$. By $d_0=dL$, $R,R^{-1}$ are associative w.r.t. their input.
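One way to realize Definition 3.1 is a patchify operation, sketched below in NumPy. The definition fixes only the shapes, so the ordering used here (row-major patches, each flattened row-major into one column) is an assumption, and the function names are hypothetical.

```python
import numpy as np

def reshape_layer(h, i, p):
    """R: R^{d0} -> R^{d x L} with d0 = i*i, d = p*p, L = (i/p)**2."""
    assert h.shape == (i * i,) and i % p == 0
    g = i // p                                    # patches per side
    img = h.reshape(i, i)
    # Axes (g, p, g, p) -> (g, g, p, p): one p x p patch per (row, col) index.
    patches = img.reshape(g, p, g, p).transpose(0, 2, 1, 3)
    return patches.reshape(g * g, p * p).T        # (d, L): one column per patch

def inverse_reshape_layer(X, i, p):
    """R^{-1}: R^{d x L} -> R^{d0}, undoing reshape_layer."""
    g = i // p
    assert X.shape == (p * p, g * g)
    patches = X.T.reshape(g, g, p, p).transpose(0, 2, 1, 3)
    return patches.reshape(i * i)

# Toy example: a 4 x 4 "image" with 2 x 2 patches gives d = 4 and L = 4.
h = np.arange(16.0)
X = reshape_layer(h, i=4, p=2)
h_back = inverse_reshape_layer(X, i=4, p=2)
```

In this example the first column of `X` holds the top-left patch, and `h_back` recovers `h` exactly, reflecting that $R^{-1}$ is the inverse of $R$ and that $d_0=dL$.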

To simplify the self-attention block in (2.6), let $W_{OV}^i = W_O^i W_V^i$ and $W_{KQ}^i = (W_K^i)^\top W_Q^i$.

Definition 3.2 (Transformer Network Class $\mathcal{T}_p^{r,m,l}$).

We define the Transformer network class as

$$\mathcal{T}_p^{r,m,l}\left(K, C_{\mathcal{T}}, C_{OV}^{2,\infty}, C_{OV}, C_{KQ}^{2,\infty}, C_{KQ}, C_F^{2,\infty}, C_F, C_E, L_{\mathcal{T}}\right), \quad \text{satisfying the constraints}$$

  • Model architecture with $K$ blocks: $f_{\mathcal{T}}(X) = {\rm FF}^{(K)}\circ{\rm Attn}^{(K)}\circ\cdots\circ{\rm FF}^{(1)}\circ{\rm Attn}^{(1)}(X)$;

  • Model output bound: $\sup_X \|f_{\mathcal{T}}(X)\|_2 \leq C_{\mathcal{T}}$;

  • Parameter bound in ${\rm Attn}^{(i)}$: $\|(W_{OV}^i)^\top\|_{2,\infty} \leq C_{OV}^{2,\infty}$, $\|(W_{OV}^i)^\top\|_2 \leq C_{OV}$, $\|W_{KQ}^i\|_{2,\infty} \leq C_{KQ}^{2,\infty}$, $\|W_{KQ}^i\|_2 \leq C_{KQ}$, $\|E^\top\|_{2,\infty} \leq C_E$, $\forall i\in[K]$;

  • Parameter bound in ${\rm FF}^{(i)}$: $\|W_j^i\|_{2,\infty} \leq C_F^{2,\infty}$, $\|W_j^i\|_2 \leq C_F$, $\forall j\in[2], i\in[K]$;

  • Lipschitz of $f_{\mathcal{T}}$: $\|f_{\mathcal{T}}(X_1) - f_{\mathcal{T}}(X_2)\|_F \leq L_{\mathcal{T}}\|X_1 - X_2\|_F$, $\forall X_1, X_2\in\mathbb{R}^{d\times L}$.

Definition 3.3 (DiT Score Network Class $\mathcal{S}_{\mathcal{T}_p^{r,m,l}}$).

We denote by $\mathcal{S}_{\mathcal{T}_p^{r,m,l}}$ the DiT score network class obtained from (2.5) by replacing $f$ with $R^{-1}\circ f_{\mathcal{T}}\circ R$, where $f_{\mathcal{T}}$ is from the Transformer class $\mathcal{T}_p^{r,m,l}$.

3.2 Score Approximation of DiT

Here, we explore the approximation limit of the latent DiT score network class $\mathcal{S}_{\mathcal{T}_p^{r,m,l}}$ under the linear latent space assumption. Recall that $P_t$ is the distribution of $x_t$, $\sigma(t)$ is the variance of $P(x_t|x_0)$, $d_0$ is the dimension of the latent space, $L$ is the sequence length of the Transformer input, $T$ is the stopping time in the forward process, $T_0$ is the early stopping time in the backward process, and $L_{s_+}$ is the Lipschitz coefficient of the on-support score function. Then we have the following Theorem 3.1.

Theorem 3.1 (Score Approximation of DiT).

For any approximation error $\epsilon>0$ and any data distribution $P_0$ under Assumptions 2.1, 2.2 and 2.3, there exists a DiT score network $s_{\widehat{W}}$ from $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$ (defined in Definition 3.2), where $\widehat{W}=\{\widehat{W}_B,\widehat{f}_{\mathcal{T}}\}$, such that for any $t\in[T_0,T]$, we have:

$$\left\|s_{\widehat{W}}(\cdot,t) - \nabla\log p_t(\cdot)\right\|_{L^2(P_t)} \leq \epsilon\cdot\sqrt{d_0}/\sigma(t),$$

where $\sigma(t)=1-e^{-t}$, and the upper bounds of the hyperparameters in $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$ are

\begin{align*}
& K = \mathcal{O}(\epsilon^{-2L}),\quad C_{\mathcal{T}} = \mathcal{O}\left(d_0 L_{s_+}\sqrt{d_0\log(d_0/T_0) + \log(1/\epsilon)}\right),\\
& C_{OV}^{2,\infty} = (1/\epsilon)^{\mathcal{O}(1)},\quad C_{OV} = (1/\epsilon)^{\mathcal{O}(1)},\quad C_{KQ}^{2,\infty} = (1/\epsilon)^{\mathcal{O}(1)},\quad C_{KQ} = (1/\epsilon)^{\mathcal{O}(1)},\\
& C_E = \mathcal{O}(L^{3/2}),\quad C_F^{2,\infty} = (1/\epsilon)^{\mathcal{O}(1)},\quad C_F = (1/\epsilon)^{\mathcal{O}(1)},\quad L_{\mathcal{T}} = \mathcal{O}\left(d_0 L_{s_+}\right).
\end{align*}
Proof Sketch.

Our proof is built on the key observation that there is a tail behavior of the low-dimensional latent variable distribution $P_h$ (Assumption 2.2). Recall that $\nabla\log p_t(x) = Bq(\bar{h},t)/\sigma(t) - x/\sigma(t)$, where $\bar{h}=B^\top x$ (defined in (2.4)). By taking $\widehat{W}_B=B$, our aim reduces to constructing a Transformer network to approximate $q(\bar{h},t)$. To achieve this, we first approximate $q(\bar{h},t)$ with a compactly supported continuous function, based on the tail behavior of $P_h$. Then we construct a Transformer to approximate the compactly supported continuous function using the universal approximation capacity of Transformers (Yun et al., 2020). See Section F.1 for a detailed proof. ∎
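The reduction used in the proof sketch can be spelled out in one line. The display below is a sketch that assumes $B$ has orthonormal columns, i.e., $B^\top B = I_{d_0}$, and writes $\widehat{f}$ for the network that replaces $f$ in (2.5); both are notational assumptions made here for illustration. With $\widehat{W}_B = B$, the $-x/\sigma(t)$ terms of (2.4) and (2.5) cancel, so

$$s_{\widehat{W}}(x,t) - \nabla\log p_t(x) = \frac{B\big(\widehat{f}(B^\top x,t) - q(B^\top x,t)\big)}{\sigma(t)}, \qquad \left\|s_{\widehat{W}}(x,t) - \nabla\log p_t(x)\right\|_2 = \frac{\big\|\widehat{f}(B^\top x,t) - q(B^\top x,t)\big\|_2}{\sigma(t)}.$$

Under this assumption, the bound in Theorem 3.1 follows once the latent-space error $\widehat{f}-q$ is at most $\epsilon\sqrt{d_0}$ in $L^2(P_t)$, which is where the factor $1/\sigma(t)$ comes from.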

Intuitively, Theorem 3.1 indicates the capability of the transformer-based score network to approximate the score function with precise guarantees. Furthermore, Theorem 3.1 provides empirical guidance for the design choices of the score network when a specified approximation error is required.
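To make the depth requirement concrete, the following worked arithmetic reads $K=\mathcal{O}(\epsilon^{-2L})$ at its leading order, ignoring the constant hidden in $\mathcal{O}(\cdot)$; the numerical values are illustrative choices, not settings from the paper. Since $d_0 = dL$ and $d = p^2$ (Definition 3.1), the sequence length is $L = d_0/p^2$, and for $\epsilon = 0.1$:

$$\epsilon^{-2L}\Big|_{L=2} = 10^{4}, \qquad \epsilon^{-2L}\Big|_{L=4} = 10^{8}, \qquad \epsilon^{-2L}\Big|_{L=16} = 10^{32}.$$

The bound on the number of blocks is therefore exponential in the sequence length, so at a fixed latent dimension $d_0$ a larger patch size $p$ (hence a shorter sequence) gives a much smaller depth bound.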

Remark 3.1 (Comparing with Existing Works).

Theoretical analysis of DiTs is limited. Previous works that do not specify the model architecture assume that the score estimator is well-approximated (Benton et al., 2024; Wibisono et al., 2024). To the best of our knowledge, this work is the first to present an approximation theory for DiTs, offering the estimation theory in Corollaries 3.1.1 and 3.1.2 based on the estimated score network, rather than a perfectly trained one.

Remark 3.2 (Latent Dimension Dependency).

Theorem 3.1 suggests that the approximation capacity and Transformer network size primarily depend on the latent variable dimension $d_0 = d\times L$. This indicates that DiTs can potentially bypass the challenges associated with the high dimensionality of initial data by transforming input data into a low-dimensional latent variable.

3.3 Score Estimation and Distribution Estimation

Besides score approximation capability, Theorem 3.1 also characterizes the structural configuration of the score network for any specific precision, e.g., $K, C_E, C_F$, etc. This characterization enables further analysis of the performance of the score network in practical scenarios. In Corollary 3.1.1, we provide a sample complexity bound for score estimation. In Corollary 3.1.2, we show that the learned score estimator is able to recover the initial data distribution.

Score Estimation.

To derive a sample complexity for score estimation using $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$, we rewrite the score matching objective in (2.2) as $\widehat{W}\in\mathop{\mathrm{argmin}}_{s_W\in\mathcal{S}_{\mathcal{T}_p^{2,1,4}}}\widehat{\mathcal{L}}(s_W)$, with $\widehat{W}=\{\widehat{W}_B,\widehat{f}_{\mathcal{T}}\}$.

Corollary 3.1.1 shows that as the sample size $n\to\infty$, $s_W(\cdot,t)$ converges to $\nabla\log p_t(\cdot)$.

Corollary 3.1.1 (Score Estimation of DiT).

Under Assumptions 2.1, 2.2 and 2.3, we choose $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$ as in Theorem 3.1 using $\epsilon \in (0,1)$ and $L > 1$. With probability $1 - 1/\mathrm{poly}(n)$, we have

$$\frac{1}{T-T_0}\int_{T_0}^{T}\left\|s_{\widehat{W}}(\cdot,t)-\nabla\log p_t(\cdot)\right\|_{L^2(P_t)}\,\mathrm{d}t=\widetilde{\mathcal{O}}\left(\frac{1}{n^{1/2}}\frac{T}{T_0}\cdot 2^{(1/\epsilon)^{2L}}+\frac{1}{T_0 T}\epsilon^{2}+\frac{1}{n}\right), \tag{3.1}$$

where $\widetilde{\mathcal{O}}$ hides factors depending on $D, d_0, d, L_{s_+}, \log n$.

Proof.

See Section F.2 for a detailed proof. ∎

Intuitively, Corollary 3.1.1 shows a sample complexity bound for score estimation in practice.

Remark 3.3 (Comparing with Existing Works).

Zhu et al. (2023) provide a sample complexity for simple ReLU-based diffusion models under the assumption of an accurate score estimator. To the best of our knowledge, we are the first to provide a sample complexity for DiTs, based on the learned score network in Theorem 3.1 and the quantization (piece-wise approximation) approach for transformer universality (Yun et al., 2020).

Remark 3.4.

Corollary 3.1.1 reports an explicit sample complexity bound for score estimation of latent DiTs, featuring a double exponential factor $2^{(1/\epsilon)^{2L}}$ in the first term. This factor arises because the required depth $K$ is $\mathcal{O}(\epsilon^{-2L})$ and the norm of the required weight parameters is $(1/\epsilon)^{\mathcal{O}(1)}$, as shown in Theorem 3.1, assuming the universality of transformers requires dense layers (Yun et al., 2020). This motivates us to rethink transformer universality and explore new proof techniques for DiTs, which we leave for future work.

Definition 3.4.

For later convenience, we define $\xi(n,\epsilon,L) := \frac{1}{n^{1/2}}\frac{T}{T_0}\cdot 2^{(1/\epsilon)^{2L}} + \frac{1}{T_0 T}\epsilon^{2} + \frac{1}{n}$.
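To get a feel for how the three terms of $\xi(n,\epsilon,L)$ trade off, the following Python sketch evaluates them in log space, since the double exponential factor overflows floating point very quickly. The helper names `log2_xi_terms` and `log2_samples_for_first_term` are hypothetical, the values of $T$ and $T_0$ are illustrative, and the constants and logarithmic factors hidden by $\widetilde{\mathcal{O}}$ are ignored.

```python
import math

def log2_xi_terms(n, eps, L, T, T0):
    """Return log2 of the three terms of xi(n, eps, L) in Definition 3.4.

    Hypothetical helper: hidden constants and log factors are ignored.
    """
    # First term: n^{-1/2} * (T / T0) * 2^{(1/eps)^{2L}}
    t1 = -0.5 * math.log2(n) + math.log2(T / T0) + (1.0 / eps) ** (2 * L)
    # Second term: eps^2 / (T0 * T)
    t2 = 2.0 * math.log2(eps) - math.log2(T0 * T)
    # Third term: 1 / n
    t3 = -math.log2(n)
    return t1, t2, t3

def log2_samples_for_first_term(eps, L, T, T0, target):
    """log2 of the sample size n at which the first term equals `target`."""
    # Solve n^{-1/2} * (T / T0) * 2^{(1/eps)^{2L}} = target for n.
    return 2.0 * ((1.0 / eps) ** (2 * L) + math.log2(T / T0) - math.log2(target))

# Illustrative values only: eps = 1/2, L = 2, T = 4, T0 = 1, n = 2^20.
t1, t2, t3 = log2_xi_terms(n=2 ** 20, eps=0.5, L=2, T=4.0, T0=1.0)
log2_n_needed = log2_samples_for_first_term(eps=0.5, L=2, T=4.0, T0=1.0, target=1.0)
```

Even for this coarse precision, the first term dominates until $n$ is far larger than the other two terms would require, which is the behavior Remark 3.4 attributes to the double exponential factor.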

Distribution Estimation.

In practice, DiTs generate data using the discretized version with step size $\mu$; see Section D.1 for details. Let $\widehat{P}_{T_0}$ be the distribution generated by $s_{\widehat{W}}$ in Corollary 3.1.1. Let $P_{T_0}^{h}$ and $p_{T_0}^{h}$ be the distribution and density function of the on-support latent variable $\bar{h}$ at $T_0$. We have the following results for distribution estimation.

Corollary 3.1.2 (Distribution Estimation of DiT, Modified From Theorem 3 of (Chen et al., 2023a)).

Let $T = \mathcal{O}(\log n)$ and $T_0 = \mathcal{O}(\min\{c_0, 1/L_{s_+}\})$, where $c_0$ is the minimum eigenvalue of $\mathbb{E}_{P_h}[hh^{\top}]$. With the estimated DiT score network $s_{\widehat{W}}$ in Corollary 3.1.1, the following hold with probability $1 - 1/\mathrm{poly}(n)$.

  • (i)

    The accuracy of recovering the subspace $B$ is $\|W_B W_B^{\top} - BB^{\top}\|_F^{2} = \widetilde{\mathcal{O}}\left(T_0\,\xi(n,\epsilon,L)/c_0\right)$.

  • (ii)

    Let $(W_B U)^{\top}_{\sharp}\widehat{P}_{T_0}$ denote the pushforward distribution. Under the conditions ${\sf KL}(P_h \,\|\, N(0, I_{d_0})) < \infty$ and step size $\mu \leq \xi(n,\epsilon,L)\cdot T_0^{2}/(d_0\sqrt{\log d_0})$, there exists an orthogonal matrix $U \in \mathbb{R}^{d\times d}$ such that the total variation distance satisfies

    $${\sf TV}\big(P_{T_0}^{h}, (W_B U)^{\top}_{\sharp}\widehat{P}_{T_0}\big) = \widetilde{\mathcal{O}}\big(\sqrt{\xi(n,\epsilon,L)}\big), \tag{3.2}$$

    where $\widetilde{\mathcal{O}}$ hides factors depending on $D, d_0, d, L_{s_+}, \log n$, and $T - T_0$.

  • (iii)

    For the generated data distribution $\widehat{P}_{T_0}$, the orthogonal pushforward $(I - W_B W_B^{\top})_{\sharp}\widehat{P}_{T_0}$ is $N(0,\Sigma)$, where $\Sigma \preceq aT_0 I$ for a constant $a > 0$.

Proof.

See Section F.3 for a detailed proof. ∎

Intuitively, Corollary 3.1.2 presents the estimation results in three parts: (i) the accuracy of recovering the subspace $B$; (ii) the estimation error between $\widehat{P}_{T_0}$ and $P_{T_0}^{h}$; (iii) the vanishing behavior of $\widehat{P}_{T_0}$ in the orthogonal space. Together, these three parts indicate that the learned score estimator is capable of recovering the initial data distribution. Notably, Corollary 3.1.2 is agnostic to the details of $\xi(n,\epsilon,L)$.

Remark 3.5 (Comparing with Existing Works).

Oko et al. (2023) analyze distribution estimation under the assumption that the initial density is supported on $[-1,1]^{D}$ and smooth in the boundary. Our Assumption 2.2 is of greater practical relevance, which suggests that our distribution estimation analysis aligns more closely with empirical realities.

Remark 3.6 (Subspace Recovery Accuracy).

Part (i) of Corollary 3.1.2 confirms that the subspace is learned by DiTs. The error is proportional to the sample complexity for score estimation and depends on the minimum eigenvalue of the covariance of $P_h$.
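The subspace error in part (i) compares projectors rather than bases, so it does not depend on which orthonormal basis of the subspace is learned. The NumPy sketch below illustrates this under the assumption that both $B$ and $W_B$ have orthonormal columns; the helper names, the dimensions, and the random test matrices are hypothetical and chosen only for illustration.

```python
import numpy as np

def subspace_error(W_B, B):
    """Squared Frobenius distance between the projectors W_B W_B^T and B B^T."""
    return np.linalg.norm(W_B @ W_B.T - B @ B.T, ord="fro") ** 2

def random_orthonormal(rows, cols, rng):
    """Matrix with orthonormal columns, from the reduced QR of a Gaussian matrix."""
    Q, _ = np.linalg.qr(rng.standard_normal((rows, cols)))
    return Q

rng = np.random.default_rng(0)
D_amb, d0 = 8, 3                                # illustrative ambient and latent sizes
B = random_orthonormal(D_amb, d0, rng)          # stand-in for the true subspace basis
U = random_orthonormal(d0, d0, rng)             # orthogonal d0 x d0 change of basis

err_rotated = subspace_error(B @ U, B)          # same subspace, different basis
W_other = random_orthonormal(D_amb, d0, rng)    # unrelated subspace
err_other = subspace_error(W_other, B)
```

Here `err_rotated` vanishes up to rounding, while `err_other` is strictly positive and at most $2d_0$, the value attained when the two subspaces are orthogonal.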

4 Provably Efficient Criteria

Here, we analyze the computational limits of latent DiTs under the low-dimensional linear subspace data assumption (i.e., Assumption 2.1). The hardness of DiT models ties to both the forward and backward passes of the score network in Definition 3.3. We characterize them separately.

4.1 Computational Limits of Backward Computation

Following Section 2, suppose we have $n$ i.i.d. data samples $\{x_{0,i}\}_{i=1}^{n} \sim P_d$, and times $t_{i_0}$ $(1 \leq i \leq n)$ uniformly sampled from $[T_0, T]$. For each data point $x_{0,i} \in \mathbb{R}^{D}$, we sample $x_{t_{i_0}} \in \mathbb{R}^{D}$ from $N(\beta(t_{i_0})x_{0,i}, \sigma(t_{i_0})I_D)$. Let $(W_A R^{-1}(\cdot))^{\dagger}$ be the inverse transformation of $W_A R^{-1}(\cdot)$, and denote $Y_{0,i} \coloneqq (W_A R^{-1})^{\dagger}(x_{0,i}) \in \mathbb{R}^{d\times L}$. We rewrite the empirical denoising score-matching loss (2.2) as

$$\frac{1}{n}\sum_{i=1}^{n}\Big\|W_A R^{-1}\big(f_{\mathcal{T}}(R(\underbrace{W_A^{\top}x_{t_{i_0}}}_{d_0\times 1}))\big)-x_{0,i}\Big\|_F^{2}=\frac{1}{n}\sum_{i=1}^{n}\Big\|\underbrace{W_A}_{D\times d_0}\underbrace{R^{-1}\big(\overbrace{f_{\mathcal{T}}(R(W_A^{\top}x_{t_{i_0}}))}^{d\times L}-\underbrace{Y_{0,i}}_{d\times L}\big)}_{d_0\times 1}\Big\|_F^{2}. \tag{4.1}$$

For efficiency, it suffices to focus on just the transformer attention heads of the DiT score network due to their dominating quadratic time complexity in both passes. Thus, to simplify our analysis, we consider only a single-layer attention for $f_{\mathcal{T}}$. Further, we consider the following simplifications:

  • (S0)

    To prove the hardness of (4.1) for both full gradient descent and stochastic mini-batch gradient descent methods, it suffices to consider training on a single data point.

  • (S1)

    For the convenience of our analysis, we consider the following expression for the attention mechanism. Let $X, Y \in \mathbb{R}^{d\times L}$. Let $W_K, W_Q, W_V \in \mathbb{R}^{s\times d}$ be attention weights such that $Q = W_Q X \in \mathbb{R}^{s\times L}$, $K = W_K X \in \mathbb{R}^{s\times L}$ and $V = W_V X \in \mathbb{R}^{s\times L}$. We write the attention mechanism of hidden size $s$ and sequence length $L$ as

    $${\rm Att}(X)=\underbrace{(W_O W_V X)}_{V\text{ multiplication}}\underbrace{D^{-1}\exp(X^{\mathsf{T}}W_K^{\mathsf{T}}W_Q X)}_{K\text{-}Q\text{ multiplication}}\in\mathbb{R}^{d\times L}, \tag{4.2}$$

    with $D \coloneqq \mathop{\rm diag}\left(\exp(X^{\mathsf{T}}W_K^{\mathsf{T}}W_Q X)\mathds{1}_L\right)$. Here, $\exp(\cdot)$ is the entry-wise exponential function, i.e., $\exp(A)_{i,j} = \exp(A_{i,j})$ for any matrix $A$, $\mathop{\rm diag}(\cdot)$ converts a vector into a diagonal matrix with the vector's entries on the diagonal, and $\mathds{1}_L$ is the length-$L$ all-ones vector.

  • (S2)

    Since the $V$ multiplication is linear in the weights while the $K$-$Q$ multiplication is exponential in the weights, we only need to focus on the gradient update of the $K$-$Q$ multiplication. Therefore, for the efficiency analysis of the gradient, it is equivalent to analyze a reduced problem with fixed $W_O W_V X = \text{const}$.

  • (S3)

    To focus on the DiT, we consider the low-dimensional linear encoder $W_A$ to be pretrained and to not participate in gradient computation. This aligns with common practice (Rombach et al., 2022) and is justified by the trivial computation cost due to the linearity of $W_A$. (Footnote 1: The gradient computation is linear in $W_A$, and hence the computation w.r.t. $W_A$ is cheap and upper-bounded by $L\cdot\mathrm{poly}(d)$ time in a straightforward way.)

  • (S4)

    To further simplify, we introduce $A_1, A_2, A_3 \in \mathbb{R}^{d\times L}$ and $W_{OV}, W \in \mathbb{R}^{d\times d}$ via

    $$\begin{aligned} &\Big\|W_A R^{-1}\big(f_{\mathcal{T}}(\underbrace{R(W_A^{\top}x_{t_{i_0}})}_{\coloneqq X\in\mathbb{R}^{d\times L}})-\underbrace{Y_{0,i}}_{\coloneqq Y\in\mathbb{R}^{d\times L}}\big)\Big\|_F^{2} && \text{(By (S0), (S1) and (S2))} \\ ={}&\Big\|W_A R^{-1}\big(\underbrace{W_O W_V}_{\coloneqq W_{OV}\in\mathbb{R}^{d\times d}}\underbrace{X}_{\coloneqq A_3\in\mathbb{R}^{d\times L}}D^{-1}\exp\big(\underbrace{X^{\mathsf{T}}}_{\coloneqq A_1^{\top}\in\mathbb{R}^{L\times d}}\underbrace{W_K^{\mathsf{T}}W_Q}_{\coloneqq W\in\mathbb{R}^{d\times d}}\underbrace{X}_{\coloneqq A_2\in\mathbb{R}^{d\times L}}\big)-Y\big)\Big\|_F^{2}. \end{aligned} \tag{4.3}$$

    Notably, $A_1, A_2, A_3, X, Y$ are constants w.r.t. training the above loss with gradient updates.

Therefore, we simplify the objective of training DiT into

Definition 4.1 (Training Generic DiT Loss).

Given $A_1, A_2, A_3, Y \in \mathbb{R}^{d\times L}$ and $W_{OV}, W \in \mathbb{R}^{d\times d}$ following (S4), training a DiT with $\ell_2$ loss on a single data point $X, Y \in \mathbb{R}^{d\times L}$ is formulated as

$$\min_{W}\ \mathcal{L}_0(W)=\min_{W}\ \frac{1}{2}\Big\|W_A R^{-1}\big(W_{OV}A_3 D^{-1}\exp(A_1^{\top}WA_2)-Y\big)\Big\|_F^{2}. \tag{4.4}$$

Here D:=diag(exp(A1WA2)𝟙n)L×Lassign𝐷diagsuperscriptsubscript𝐴1top𝑊subscript𝐴2subscript1𝑛superscript𝐿𝐿D:=\mathop{\rm{diag}}(\exp(A_{1}^{\top}WA_{2}){\mathds{1}}_{n})\in\mathbb{R}^{% L\times L}italic_D := roman_diag ( roman_exp ( start_ARG italic_A start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_W italic_A start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT end_ARG ) blackboard_1 start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_L × italic_L end_POSTSUPERSCRIPT.

Remark 4.1 (Conditional and Unconditional Generation).

$\mathcal{L}_0$ is generic. If $A_1 \neq A_2 \in \mathbb{R}^{d\times L}$, Definition 4.1 reduces to cross-attention in the DiT score net (for conditional generation). If $A_1 = A_2 \in \mathbb{R}^{d\times L}$, Definition 4.1 reduces to self-attention in the DiT score net (for unconditional vanilla generation).

We introduce the next problem to characterize all possible gradient computations of optimizing (4.4).

Problem 1 (Approximate DiT Gradient Computation, $\textsc{ADiTGC}(L, d, \Gamma, \epsilon)$).

Given $A_1, A_2, A_3, Y \in \mathbb{R}^{d\times L}$. Let $\epsilon > 0$. Assume all numerical values are in $\mathcal{O}(\log(L))$-bits encoding. Let the loss function $\mathcal{L}_0$ follow Definition 4.1. The problem of approximating the gradient computation of optimizing the empirical DiT loss (4.4) is to find an approximate gradient matrix $\tilde{G}^{(W)} \in \mathbb{R}^{d\times d}$ such that $\big\|\underline{\tilde{G}}^{(W)} - \frac{\partial \mathcal{L}}{\partial \underline{W}}\big\|_{\max} \le 1/\mathrm{poly}(L)$. Here, $\|A\|_{\max} \coloneqq \max_{i,j} |A_{ij}|$ for any matrix $A$.

In this work, we aim to investigate the computational limits of all possible efficient algorithms of ADiTGC with $\epsilon = 1/\mathrm{poly}(L)$. Yet, the explicit gradient of the DiT denoising score matching loss (4.4) is too complicated to characterize ADiTGC directly. To combat this, we make the following observations.

  • (O1)

    Let $g_1(\cdot) \coloneqq W_A R^{-1}(\cdot) : \mathbb{R}^{d\times L} \to \mathbb{R}^{d_0}$, $g_2(\cdot) \coloneqq {\rm Att}(\cdot) : \mathbb{R}^{d\times L} \to \mathbb{R}^{d\times L}$, and $g_3(\cdot) \coloneqq R(W_A^{\top}\,\cdot) : \mathbb{R}^{D} \to \mathbb{R}^{d\times L}$ such that $g_3(x) = X$ for $x \in \mathbb{R}^{D}$ (with $D > d_0 = dL$).

  • (O2)

    Vectorization of $f_{\mathcal{T}}$. For ease of presentation, we use the notation $f_{\mathcal{T}}$ flexibly to denote both a matrix in $\mathbb{R}^{d\times L}$ and a vector in $\mathbb{R}^{dL}$ in the following analysis. This practice does not affect correctness. The context in which $f_{\mathcal{T}}$ is used makes clear whether it refers to a matrix or a vector. Explicit vectorization follows Definition D.1.

  • (O3)

    Linearity of $g_1$. By linearity of $W_A R^{-1}(\cdot)$, we treat $g_1$ as a matrix in $\mathbb{R}^{d_0\times dL}$ acting on the vector $f_{\mathcal{T}}(\cdot) \in \mathbb{R}^{dL}$.

Therefore, we have $\mathcal{L}_0 = \big\|g_1 \cdot [g_2(g_3) - Y]\big\|_2^2$, such that its gradient involves $\frac{\mathrm{d}\mathcal{L}_0}{\mathrm{d}W} = g_1 \frac{\mathrm{d}g_2}{\mathrm{d}W}$. From the above, we only need to focus on proving the computation time and error control of the term $\frac{\mathrm{d}g_2}{\mathrm{d}W}$ for the gradient w.r.t. $W$. Luckily, with tools from fine-grained complexity theory (Alman and Song, 2023) and the tensor trick (see Section D.3), we prove the existence of almost-linear time algorithms for Problem 1 in the next theorem. Let $\operatorname{vec}(W) \coloneqq \underline{W}$ for any matrix $W$ following Definition D.1.
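To make the structure of this gradient concrete, the following NumPy sketch evaluates the loss (4.4) and its gradient with respect to $W$ by the chain rule, then compares the result with a finite-difference reference in the max norm used in Problem 1. It is only an illustration under stated assumptions: `G1` is a hypothetical dense matrix standing in for $g_1 = W_A R^{-1}$, vectorization is taken to be row-major (Definition D.1 may use a different convention), and the sizes are toy values. This is the naive quadratic-time computation, not the almost-linear algorithm of the next theorem.

```python
import numpy as np

def forward(W, A1, A2, A3, Y, W_OV, G1):
    """Return (P, V, r): normalized attention matrix, values, transformed residual."""
    M = np.exp(A1.T @ W @ A2)                 # (L, L), entrywise exponential
    P = M / M.sum(axis=1, keepdims=True)      # D^{-1} exp(A1^T W A2)
    V = W_OV @ A3                             # (d, L)
    r = G1 @ (V @ P - Y).reshape(-1)          # g_1 acting on the vectorized residual
    return P, V, r

def loss(W, *args):
    """L_0(W) = 0.5 * ||g_1 vec(W_OV A3 D^{-1} exp(A1^T W A2) - Y)||_2^2."""
    _, _, r = forward(W, *args)
    return 0.5 * float(r @ r)

def grad(W, A1, A2, A3, Y, W_OV, G1):
    """Chain-rule gradient of L_0 with respect to W."""
    P, V, r = forward(W, A1, A2, A3, Y, W_OV, G1)
    dF = (G1.T @ r).reshape(Y.shape)          # dL/dF with F = V P
    dP = V.T @ dF                             # dL/dP
    # Backward pass through the row-wise normalization P = D^{-1} exp(S).
    dS = P * (dP - (dP * P).sum(axis=1, keepdims=True))
    return A1 @ dS @ A2.T                     # S = A1^T W A2

def finite_diff_grad(f, W, h=1e-5):
    """Central finite differences, used only as a reference."""
    G = np.zeros_like(W)
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            E = np.zeros_like(W)
            E[i, j] = h
            G[i, j] = (f(W + E) - f(W - E)) / (2 * h)
    return G

rng = np.random.default_rng(0)
d, L = 3, 5                                   # toy sizes (assumption)
A1, A2, A3, Y = (0.5 * rng.standard_normal((d, L)) for _ in range(4))
W, W_OV = (0.5 * rng.standard_normal((d, d)) for _ in range(2))
G1 = rng.standard_normal((d * L, d * L)) / np.sqrt(d * L)
args = (A1, A2, A3, Y, W_OV, G1)

G_exact = grad(W, *args)
G_fd = finite_diff_grad(lambda Wm: loss(Wm, *args), W)
max_err = float(np.max(np.abs(G_exact - G_fd)))  # the ||.||_max gap of Problem 1
print(f"max-norm gap: {max_err:.2e}")
```

The only $L \times L$ objects in `grad` are `P`, `dP`, and `dS`, all of which come from the attention term $g_2$; the factor `G1` enters through a single matrix-vector product.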

Theorem 4.1 (Existence of Almost-Linear Time Algorithms for ADiTGC).

Suppose all numerical values are in $\mathcal{O}(\log L)$-bits encoding. Let $\max\big(\|W_{OV}A_3\|_{\max}, \|W_K A_1\|_{\max}, \|W_Q A_2\|_{\max}\big) \le \Gamma$. There exists an $L^{1+o(1)}$ time algorithm to solve $\textsc{ADiTGC}(L_p, L, d = \mathcal{O}(\log L), \Gamma = o(\sqrt{\log L}))$ (i.e., Problem 1) with loss $\mathcal{L}_0$ from Definition 4.1 up to $1/\mathrm{poly}(L)$ accuracy. In particular, this algorithm outputs a gradient matrix $\tilde{G}^{(W)} \in \mathbb{R}^{d\times d}$ such that $\big\|\underline{\tilde{G}}^{(W)} - \frac{\partial \mathcal{L}}{\partial \underline{W}}\big\|_{\max} \le 1/\mathrm{poly}(L)$.

Proof Sketch.

Our proof is built on the key observation that there exist low-rank structures within the DiT training gradients. Using the tensor trick (Diao et al., 2019, 2018) and computational hardness results of attention (Hu et al., 2024c; Alman and Song, 2023), we approximate DiT training gradients with a series of low-rank approximations and carefully match the multiplication dimensions so that the computation of $\frac{\mathrm{d}g_2}{\mathrm{d}\underline{W}}$ forms a chained low-rank approximation. We complete the proof by demonstrating that this approximation is bounded by a $1/\mathrm{poly}(L)$ error and requires only almost-linear time. See Section G.2 for a detailed proof. ∎
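The low-rank structure invoked in the sketch can be observed numerically. The snippet below is an illustration only, not the construction of Section G.2: it draws hypothetical bounded matrices `K` and `Q` (standing in for $W_K A_1$ and $W_Q A_2$), forms the $L \times L$ matrix of entrywise exponentials, and inspects its singular values. The sizes, the bound `B`, and the $10^{-3}$ threshold are arbitrary choices made for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)
d, L, B = 2, 256, 0.5                      # toy sizes and entry bound (assumptions)
K = rng.uniform(-B, B, size=(d, L))        # stands in for W_K A_1
Q = rng.uniform(-B, B, size=(d, L))        # stands in for W_Q A_2
M = np.exp(K.T @ Q)                        # (L, L), entries of K^T Q lie in [-d*B^2, d*B^2]

U, s, Vt = np.linalg.svd(M)
numerical_rank = int(np.sum(s > 1e-3))     # singular values above the threshold

# Best rank-r approximation in Frobenius norm (truncated SVD).
r = 28
M_r = (U[:, :r] * s[:r]) @ Vt[:r]
max_err = float(np.max(np.abs(M - M_r)))

print(f"L = {L}, numerical rank = {numerical_rank}, max error at rank {r} = {max_err:.2e}")
```

With bounded entries, a low-degree polynomial already approximates $\exp$ well on the relevant range, and an entrywise polynomial of a rank-$d$ matrix has rank far below $L$. This is why the numerical rank stays small here even though $M$ is a dense $256 \times 256$ matrix.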

Remark 4.2.

We remark that Theorem 4.1 is dominated by the relation between $L$ and $d$, hence by the subspace dimension $d_0 = dL$ (see Assumption 2.1). A smaller $d_0$ makes Theorem 4.1 more likely to hold.

4.2 Computational Limits of Forward Inference

Since the inference of score-matching diffusion models is a forward pass of the trained score estimator $s_W$, the computational hardness of DiT inference is tied to the transformer-based score network,

$$s_W(A_1, A_2, A_3) = W_A R^{-1}\Big(\underbrace{W_{OV}A_3}_{d\times L}\ \underbrace{D^{-1}}_{L\times L}\ \exp\big(\underbrace{A_1^{\top}W_K^{\top}}_{L\times d}\ \underbrace{W_Q A_2}_{d\times L}\big)\Big), \tag{4.5}$$

following notation in Definition 4.1. For inference, we study the following approximation problem. Notably, by Remark 4.1, (4.5) subsumes both conditional and unconditional DiT inferences.

Problem 2 (Approximate DiT Inference $\textsc{ADiTI}(d, L, \Gamma, \delta_F)$).

Let $\delta_F > 0$ and $B > 0$. Given $A_1, A_2, A_3 \in \mathbb{R}^{d\times L}$, and $W_{OV}, W_K, W_Q \in \mathbb{R}^{d\times d}$ with guarantees that $\|W_{OV}A_3\|_{\infty} \le B$, $\|W_K A_1\|_{\infty} \le B$ and $\|W_Q A_2\|_{\infty} \le B$, we aim to study an approximation problem $\textsc{ADiTI}(d, L, B, \delta_F)$ that approximates $s_W(A_1, A_2, A_3)$ with a vector $\tilde{z} \in \mathbb{R}^{d_0}$ (with $d_0 = d\cdot L$) such that $\big\|\tilde{z} - W_A R^{-1}\big(W_{OV}A_3 D^{-1}\exp(A_1^{\top}W_K^{\top}W_Q A_2)\big)\big\|_{\max} \le \delta_F$. Here, $\|A\|_{\max} \coloneqq \max_{i,j}|A_{ij}|$ for any matrix $A$.

By (O2) and (O3), we observe that Problem 2 is a special case of the attention approximation problem studied in (Alman and Song, 2023). Hence, we characterize all possible efficient algorithms for ADiTI with the next proposition.

Proposition 4.1 (Norm-Based Efficiency Phase Transition).

Let $\|W_Q A_2\|_{\infty} \le B$, $\|W_K A_1\|_{\infty} \le B$ and $\|W_{OV}A_3\|_{\infty} \le B$ with $B = \mathcal{O}(\sqrt{\log L})$. Assuming SETH (Hypothesis 1), for every $q > 0$, there are constants $C, C_a, C_b > 0$ such that: there is no $O(n^{2-q})$-time (sub-quadratic) algorithm for the problem $\textsc{ADiTI}(L, d = C\log L, B = C_b\sqrt{\log L}, \delta_F = L^{-C_a})$.

Remark 4.3.

Proposition 4.1 suggests an efficiency threshold for the upper bound of $\|W_K A_1\|_{\infty}$, $\|W_Q A_2\|_{\infty}$, $\|W_{OV}A_3\|_{\infty}$. Only below this threshold are efficient algorithms for Problem 2 possible.
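One heuristic way to see why a norm threshold matters (an informal reading assumed here, not the SETH-based argument behind Proposition 4.1) is to ask how many Taylor terms of $\exp$ are needed to reach a target entrywise accuracy when the inputs are bounded. If the entries of $W_K A_1$ and $W_Q A_2$ are at most $B$ in magnitude, the entries of $A_1^{\top}W_K^{\top}W_Q A_2$ are at most $dB^2$. The hypothetical helper below returns the smallest degree for which a standard geometric tail bound falls below `delta`.

```python
import math

def min_taylor_degree(x_max, delta):
    """Smallest g for which the tail bound on sum_{k>g} x_max^k / k! is <= delta.

    The tail is bounded by term / (1 - x_max / (g + 2)), valid when x_max < g + 2,
    where term = x_max^(g+1) / (g+1)!.
    """
    g = 0
    term = x_max                      # x_max^(g+1) / (g+1)! at g = 0
    while True:
        ratio = x_max / (g + 2)
        if ratio < 1 and term / (1 - ratio) <= delta:
            return g
        g += 1
        term *= x_max / (g + 1)       # advance to x_max^(g+1) / (g+1)!

d = 2                                 # toy token dimension (assumption)
degrees = {B: min_taylor_degree(d * B * B, 1e-5) for B in (0.5, 1.0, 2.0, 4.0)}
for B, g in degrees.items():
    print(f"B = {B}: degree {g}")
```

The required degree grows with $B$, and the rank of the resulting polynomial factorization grows quickly with the degree. Small entry bounds therefore keep the factorization cheap, while large ones erase the saving.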

Moreover, there exist almost-linear time DiT inference algorithms following (Alman and Song, 2023).

Proposition 4.2 (Almost-Linear Time DiT Inference).

Assuming SETH, the DiT inference problem $\textsc{ADiTI}(L, d = \mathcal{O}(\log L), B = o(\sqrt{\log L}), \delta_F = 1/\mathrm{poly}(L))$ can be solved in $L^{1+o(1)}$ time.
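A minimal sketch of how such an inference algorithm can avoid the $L \times L$ matrix is given below. It follows the general idea of replacing $\exp$ with a truncated Taylor polynomial and factoring the result, which is in the spirit of the polynomial-method approach of (Alman and Song, 2023) but is not their exact construction. The feature map `taylor_features`, the fixed degree, and the toy sizes are assumptions for illustration, and the outer map $W_A R^{-1}$ is omitted since it is linear.

```python
import math
import numpy as np

def taylor_features(X, degree):
    """Map each column x of X (d, L) to phi(x) such that
    phi(x) . phi(y) = sum_{k <= degree} (x . y)^k / k!."""
    L = X.shape[1]
    cur = np.ones((L, 1))
    blocks = [cur]
    for k in range(1, degree + 1):
        # Row-wise Kronecker product: builds the k-fold tensor power, shape (L, d^k).
        cur = np.einsum("ip,iq->ipq", cur, X.T).reshape(L, -1)
        blocks.append(cur / math.sqrt(math.factorial(k)))
    return np.hstack(blocks)

def approx_attention(V, K, Q, degree=6):
    """Approximate V D^{-1} exp(K^T Q) using only (L, r) factors."""
    U1 = taylor_features(K, degree)                    # (L, r)
    U2 = taylor_features(Q, degree)                    # (L, r)
    row_sums = U1 @ (U2.T @ np.ones(Q.shape[1]))       # approximate diagonal of D
    return ((V / row_sums) @ U1) @ U2.T                # (d, L), never forms (L, L)

def exact_attention(V, K, Q):
    """Reference computation that forms the full (L, L) matrix."""
    M = np.exp(K.T @ Q)
    return (V / M.sum(axis=1)) @ M

rng = np.random.default_rng(0)
d, L, B = 2, 256, 0.5                                  # toy sizes and bound (assumptions)
V, K, Q = (rng.uniform(-B, B, size=(d, L)) for _ in range(3))

approx = approx_attention(V, K, Q)
exact = exact_attention(V, K, Q)
max_err = float(np.max(np.abs(approx - exact)))
print(f"feature dimension r = {taylor_features(K, 6).shape[1]}, max-norm error = {max_err:.2e}")
```

Here `K` and `Q` play the roles of $W_K A_1$ and $W_Q A_2$, and `V` that of $W_{OV}A_3$. The cost is governed by $L \cdot r \cdot d$ rather than $L^2 d$. The explicit tensor-power features used here are not rank-optimal; they are chosen because the identity $\phi(x)^{\top}\phi(y) = \sum_{k} (x^{\top}y)^k / k!$ is easy to verify.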

Remark 4.4.

Proposition 4.2 is a special case of Proposition 4.1 under the efficiency threshold.

Remark 4.5.

Propositions 4.1 and 4.2 are dominated by the relation between $L$ and $d$, hence by the subspace dimension $d_0 = dL$. A smaller $d_0$ makes Propositions 4.1 and 4.2 more likely to hold.

5 Discussion and Conclusion

We explore the fundamental limits of latent DiTs with 3 key contributions. First, we prove that transformers are universal approximators for the score functions in DiTs (Theorem 3.1), with approximation capacity and model size dependent only on the latent dimension, suggesting DiTs can handle high-dimensional data challenges. Second, we show that Transformer-based score estimators converge to the true score function (Corollary 3.1.1), ensuring the generated data distribution closely approximates the original (Corollary 3.1.2). Third, we provide provably efficient criteria (Proposition 4.1) and prove the existence of almost-linear time algorithms for forward inference (Proposition 4.2) and backward computation (Theorem 4.1). These results highlight the potential of latent DiTs to achieve both computational efficiency and robust performance in practical scenarios.

Limitations and Future Direction. As discussed in Remark 3.4, the double exponential factor in our explicit sample complexity bound (Corollary 3.1.1) suggests a possible gap in our understanding of transformer universality and its interplay with the DiT architecture. This motivates us to rethink transformer universality and explore new proof techniques for DiTs, which we leave for future work. Besides, due to its formal nature, this work does not provide immediate practical implementations. However, we expect that our findings provide valuable insights for future diffusion generative models.

Broader Impact

This theoretical work aims to shed light on the foundations of diffusion generative models and is not anticipated to have negative social impacts.

Acknowledgments

JH would like to thank Minshuo Chen, Sophia Pi, Yibo Wen, Tim Tsz-Kit Lau, Chenwei Xu, Dino Feng and Andrew Chen for enlightening discussions on related topics, and the Red Maple Family for support.

JH is partially supported by the Walter P. Murphy Fellowship. HL is partially supported by NIH R01LM1372201. The content is solely the responsibility of the authors and does not necessarily represent the official views of the funding agencies.

Appendix


Appendix A More Discussion on Low-Dimensional Linear Latent Space

Our analysis is based on the low-dimensional linear latent space assumption. Here we discuss this assumption further in light of our theoretical results.

The low-dimensional data structure in Assumption 2.1 indicates a robust and informative latent feature space. Besides, it improves computational efficiency by reducing data complexity without sacrificing essential information. This is consistent with the analysis in our work. Similar to the results under Assumption 2.1 ($d_0 < D$), it is straightforward to verify that our theoretical results hold in the other two settings: $d_0 = D$ and $d_0 > D$.

  • Statistically, for score approximation, score estimation, and distribution estimation, the upper bound depends on the dimension of the latent variable $d_0$, rather than $d$. A smaller $d_0$ allows a reduced model size to achieve a specified approximation error compared to a larger one (Theorem 3.1). Additionally, with a smaller $d_0$, both score and distribution estimation errors are reduced relative to scenarios with a larger one (Corollary 3.1.1 and Corollary 3.1.2).

  • Computationally, a smaller $d_0$ benefits the provably efficient criteria (Proposition 4.1) and the almost-linear time algorithms for forward inference (Proposition 4.2) and backward computation (Theorem 4.1).

Appendix B Nomenclature Table

We summarize our notations in the following table for easy reference.

Table 1: Mathematical Notations and Symbols

| Symbol | Description |
|---|---|
| $\lVert z\rVert_2$ | Euclidean norm, where $z$ is a vector |
| $\lVert z\rVert_{\infty}$ | Infinity norm, where $z$ is a vector |
| $\lVert Z\rVert_2$ | 2-norm, where $Z$ is a matrix |
| $\lVert Z\rVert_{\rm op}$ | Operator norm, where $Z$ is a matrix |
| $\lVert Z\rVert_F$ | Frobenius norm, where $Z$ is a matrix |
| $\lVert Z\rVert_{p,q}$ | $p,q$-norm, where $Z$ is a matrix |
| $\lVert f(x)\rVert_{L^2}$ | $L^2$-norm, where $f$ is a function |
| $\lVert f(x)\rVert_{L^2(P)}$ | $L^2(P)$-norm, where $f$ is a function and $P$ is a distribution |
| $\lVert f(\cdot)\rVert_{Lip}$ | Lipschitz-norm, where $f$ is a function |
| $f_{\sharp}P$ | Pushforward measure, where $f$ is a function and $P$ is a distribution |
| $n$ | Sample size |
| $x$ | Data point in original data space, $x \in \mathbb{R}^{D}$ |
| $h$ | Latent variable in low-dimensional subspace, $h \in \mathbb{R}^{d_0}$ |
| $p_h$ | The density function of $h$ |
| $B$ | The matrix with orthonormal columns to transform $h$ to $x$, where $B \in \mathbb{R}^{D\times d_0}$ |
| $\bar{h}$ | $\bar{h} = B^{\top}x$ |
| $T$ | Stopping time in forward process of diffusion model |
| $T_0$ | Stopping time in backward process of diffusion model |
| $\mu$ | Discretized step size in backward process |
| $p_t(\cdot)$ | The density function of $x$ at time $t$ |
| $p_t^{h}(\cdot)$ | The density function of $\bar{h}$ at time $t$ |
| $\psi$ | (Conditional) Gaussian density function |
| $d$ | Input dimension of each token in the Transformer network of DiT |
| $L$ | Token length in the Transformer network of DiT |
| $X$ | Sequence input of Transformer network in DiT, where $X \in \mathbb{R}^{d\times L}$ |
| $E$ | Position encoding, where $E \in \mathbb{R}^{d\times L}$ |
| $R(\cdot)$ | Reshape layer in DiT, $R(\cdot) : \mathbb{R}^{d_0} \to \mathbb{R}^{d\times L}$ |
| $W_B$ | The orthonormal matrix to approximate $B$, where $W_B \in \mathbb{R}^{D\times d_0}$ |

Appendix C Related Works

Diffusion (Ho et al., 2020) and score-based generative models (Song and Ermon, 2019) have been particularly successful as generative models of images, video, and biomedical data (Nichol et al., 2021; Ramesh et al., 2022; Liu et al., 2024; Zhou et al., 2024a, b; Wang et al., 2024a, b). Research in this area follows two main directions. Empirically, diffusion transformers (DiTs) (Peebles and Xie, 2023) have emerged as a significant advancement, effectively combining the strengths of transformer architectures and diffusion-based approaches. Theoretically, the development of the approximation theory for diffusion models supports their practical success, providing a theoretical framework for understanding and enhancing their effectiveness in various applications (Chen et al., 2023a).

Organization.

In the following, we first discuss recent developments in DiTs. Then, we discuss the main technique behind our statistical results: the universality (universal approximation) of transformers. Next, we discuss recent theoretical developments in diffusion generative models. Lastly, we discuss other aspects of transformers in foundation models beyond diffusion models.

Diffusion Transformers.

Recently, transformer-based diffusion models have garnered significant attention in research. The U-ViT model (Bao et al., 2022) incorporates transformer blocks into a U-net architecture, treating all inputs as tokens. In contrast, DiT (Peebles and Xie, 2023) utilizes a straightforward, non-hierarchical transformer structure. Models like MDT (Gao et al., 2023a) and MaskDiT (Zheng et al., 2023) improve the training efficiency of DiT by applying a masking strategy.

Universality and Memory Capacity of Transformers.

The universality of transformers refers to their ability to serve as universal approximators. This means that transformers can, in theory, model any sequence-to-sequence function to a desired degree of accuracy. Yun et al. (2020) establish that transformers can universally approximate sequence-to-sequence functions by stacking numerous layers of feed-forward functions and self-attention functions. In a different approach, Jiang and Li (2023) affirm the universality of transformers by utilizing the Kolmogorov-Arnold representation theorem. Most recently, Kajitsuka and Sato (2023) show that transformers with one self-attention layer are universal approximators.

The memory capacity of a transformer is a practical measure to test the theoretical results of the transformer’s universality, by ensuring the model can handle necessary context and dependencies. By memory capacity, we refer to the minimal set of parameters such that the model (i.e., transformer) approximates all input-output pairs in the training dataset with a bounded error. Several works address the memory capacity of transformers. Kim et al. (2022) show that transformers with $\tilde{O}(d + L + \sqrt{NL})$ parameters are sufficient to memorize $N$ length-$L$ and dimension-$d$ sequence-to-sequence data points by constructing a contextual mapping with $\mathcal{O}(L)$ attention layers. Mahdavi et al. (2023) show that a multi-head attention with $h$ heads is able to memorize $\mathcal{O}(hL)$ examples under a linear independence data assumption. Kajitsuka and Sato (2023) show that a single-layer transformer with $\mathcal{O}(NLd + d^2)$ parameters is able to memorize $N$ length-$L$ and dimension-$d$ sequence-to-sequence data points by utilizing the connection between the softmax function and the Boltzmann operator. Wang et al. (2023) extend the results of (Yun et al., 2020) to prompt tuning and discuss the memorization of only the last token of each data sequence. Another line of research establishes a different kind of memory capacity for transformers by connecting transformer attention with dense associative memory models (modern Hopfield models) (Hu et al., 2024a, b, c, 2023; Wu et al., 2024a, b; Ramsauer et al., 2020). Notably, they define memory capacity as the smallest number of (length-$L$ and dimension-$d$) data points the model (transformer attention) is able to store and derive exponential-in-$d$ high-probability capacity lower bounds.

Our work is motivated by and builds on (Yun et al., 2020) to bridge the transformer’s function approximation ability with data distribution estimation. While we do not address the memorization of DiTs (or diffusion models in general), recent studies on dense associative models suggest viewing pretrained diffusion generative models as associative memory models (Hoover et al., 2023; Ambrogioni, 2023). We plan to explore this aspect in future work.

Theories of Diffusion Models.

In addition to empirical success, there have been several theoretical analyses of diffusion models. Chen et al. (2023a) study score approximation, estimation, and distribution recovery of U-Net based diffusion models. Benton et al. (2024) provide convergence bounds linear in the data dimension, assuming accurate score function approximation. Zhu et al. (2023); Wibisono et al. (2024) provide statistical sample complexity bounds for score matching under similar assumptions. Oko et al. (2023) analyze distribution estimation under the assumption that the initial density is supported on $[-1,1]^D$ and smooth near the boundary.

Among these works, our work is built on and closest to (Chen et al., 2023a), as both assume the data has a low-dimensional structure. However, our work differs in three key aspects. First, beyond the simple ReLU networks considered in (Chen et al., 2023a), we provide the first score approximation analysis for DiTs with a transformer-based score estimator. Second, our work is the first to provide the statistical rates of DiTs (score and distribution estimation) based on transformer universality (Yun et al., 2020) and a norm-based covering number bound (Edelman et al., 2022), supporting the practical success of DiTs (Esser et al., 2024; Ma et al., 2024). Lastly, our work provides the first comprehensive analysis of the computational limits and all possible efficient DiT algorithms/methods for both forward inference and backward training. This offers timely insights into the empirical computational inefficiency of DiTs (Liu et al., 2024) and guidance for future DiT architectures.

Transformers in Foundation Models: Transformer-Based Pretrained Models.

Transformer-based pretrained models utilize attention mechanisms to process sequential data, enabling the learning of contextual relationships for tasks like natural language understanding and generation. These models encompass three types: encoder-based, decoder-based, and diffusion transformers. Encoder-based transformers, such as DNABERT (Zhou et al., 2024c, 2023; Ji et al., 2021), employ bidirectional attention to extract feature representations. DNABERT shows great potential for capturing complex patterns in genome sequences and improving tasks such as gene prediction. Decoder-based transformers generate output sequences from encoded information using unidirectional attention, such as ChatGPT (Lagler et al., 2013; Floridi and Chiriatti, 2020; Brown et al., 2020) for natural language. Diffusion transformers generate a sequence toward a target distribution, such as Sora (Liu et al., 2024) and Videofusion (Luo et al., 2023) for video generation and DecompDiff (Guan et al., 2024) for drug design. In our paper, we present an early exploration of the statistical and computational limits of diffusion transformer models.

Appendix D Supplementary Theoretical Background

In this section, we provide further background. Section D.1 details the forward and backward processes of diffusion models, and Section D.2 gives the proof of the score decomposition.

D.1 Diffusion Models

Forward Process.

Diffusion models gradually add noise to the original data in the forward process. We describe the forward process as the following SDE

$$
\mathrm{d} x_t = -\frac{1}{2} w(t) x_t \, \mathrm{d} t + \sqrt{w(t)} \, \mathrm{d} B_t, \quad x_t \in \mathbb{R}^D, \tag{D.1}
$$

where $x_0 \sim P_0$, $(B_t)_{t \geq 0}$ is a standard Brownian motion, and $w(t) > 0$ is a nondecreasing weighting function. Let $P_t$ and $p_t$ denote the marginal distribution and density of $x_t$, respectively. The conditional distribution $P(x_t \mid x_0)$ follows $N(\beta(t) x_0, \sigma(t) I_D)$, where $\beta(t) = \exp(-\int_0^t w(s) \, \mathrm{d} s / 2)$ and $\sigma(t) = 1 - \beta^2(t)$. In practice, (D.1) terminates at a large enough $T$ such that $P_T$ is close to $N(0, I_D)$.
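As a minimal sketch of the forward process (D.1), the Python snippet below assumes a constant schedule $w(t) \equiv 1$ (an illustrative choice, not one made in the paper), for which $\beta(t) = e^{-t/2}$ and $\sigma(t) = 1 - e^{-t}$. It compares the closed-form conditional law of $x_t$ given $x_0$ with an Euler-Maruyama simulation of the SDE. All function names are hypothetical.

```python
import math
import random


def beta(t, w=1.0):
    """beta(t) = exp(-(1/2) * int_0^t w(s) ds) for a constant schedule w(s) = w."""
    return math.exp(-0.5 * w * t)


def sigma(t, w=1.0):
    """sigma(t) = 1 - beta(t)^2, the conditional variance of x_t given x_0."""
    return 1.0 - beta(t, w) ** 2


def sample_forward_closed_form(x0, t, rng, w=1.0):
    """Draw x_t | x_0 ~ N(beta(t) x_0, sigma(t) I_D), one coordinate at a time."""
    b, s = beta(t, w), sigma(t, w)
    return [b * xi + math.sqrt(s) * rng.gauss(0.0, 1.0) for xi in x0]


def sample_forward_euler(x0, t, rng, w=1.0, n_steps=200):
    """Euler-Maruyama discretization of dx = -(1/2) w x dt + sqrt(w) dB."""
    dt = t / n_steps
    x = list(x0)
    for _ in range(n_steps):
        x = [
            xi - 0.5 * w * xi * dt + math.sqrt(w * dt) * rng.gauss(0.0, 1.0)
            for xi in x
        ]
    return x


def empirical_mean_var(samples):
    """Sample mean and (biased) sample variance of a list of floats."""
    n = len(samples)
    mean = sum(samples) / n
    var = sum((s - mean) ** 2 for s in samples) / n
    return mean, var


if __name__ == "__main__":
    rng = random.Random(0)
    x0, t = [2.0], 1.0
    closed = [sample_forward_closed_form(x0, t, rng)[0] for _ in range(20000)]
    euler = [sample_forward_euler(x0, t, rng)[0] for _ in range(2000)]
    print("target mean/var:", beta(t) * x0[0], sigma(t))
    print("closed form    :", empirical_mean_var(closed))
    print("Euler-Maruyama :", empirical_mean_var(euler))
```

Since $\beta(t) \to 0$ and $\sigma(t) \to 1$ as $t$ grows, the same functions also illustrate why $P_T$ approaches $N(0, I_D)$ for large $T$.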

Backward Process.

We obtain the backward process $y_t := x_{T-t}$ by reversing (D.1). The backward process satisfies

$$
\mathrm{d} y_t = \left[ \frac{1}{2} w(T-t) y_t + w(T-t) \nabla \log p_{T-t}(y_t) \right] \mathrm{d} t + \sqrt{w(T-t)} \, \mathrm{d} \bar{B}_t,
$$

where the score function $\nabla \log p_t(\cdot)$ is the gradient of the log probability density function of $x_t$, and $\bar{B}_t$ is a reversed Brownian motion. However, both $\nabla \log p_t(\cdot)$ and $P_T$ are unknown in the backward process above. To resolve this, we first use a score estimator $s_W(\cdot, t)$ to replace $\nabla \log p_t(\cdot)$, where $s_W(\cdot, t)$ is usually a neural network with parameters $W$. Second, we replace $P_T$ by the standard Gaussian distribution. Consequently, we obtain the following SDE

$$
\mathrm{d} y_t = \left[ \frac{1}{2} w(T-t) y_t + w(T-t) s_W(y_t, T-t) \right] \mathrm{d} t + \sqrt{w(T-t)} \, \mathrm{d} \bar{B}_t, \quad y_0 \sim N(0, I_D). \tag{D.2}
$$

In practice, we use discrete schemes of (D.2) to generate data, following (Song and Ermon, 2019). We use $\mu > 0$ to denote the discretization step size, and for $t \in [k\mu, (k+1)\mu]$, we have

$$
\mathrm{d} y_t^{\leftarrow} = \left[ \frac{1}{2} w(T-t) y_{k\mu}^{\leftarrow} + w(T-t) s_W(y_{k\mu}^{\leftarrow}, T-k\mu) \right] \mathrm{d} t + \sqrt{w(T-t)} \, \mathrm{d} \bar{B}_t. \tag{D.3}
$$
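The discretized backward process (D.3) can be sketched in Python as follows. This is an illustration under stated assumptions rather than the paper's implementation: the schedule is constant, $w(t) \equiv 1$; the data is one-dimensional with $x_0 \sim N(0, s_0^2)$, so the exact score $-y / (\beta^2(t) s_0^2 + \sigma(t))$ is available in closed form and stands in for the learned estimator $s_W$. The names `exact_score_gaussian` and `backward_sample` are hypothetical.

```python
import math
import random


def exact_score_gaussian(y, t, s0_sq, w=1.0):
    """Score of p_t when x_0 ~ N(0, s0_sq) in one dimension (toy stand-in for s_W)."""
    beta_sq = math.exp(-w * t)  # beta(t)^2 for a constant schedule w
    var_t = beta_sq * s0_sq + (1.0 - beta_sq)  # Var(x_t) = beta^2 s0^2 + sigma(t)
    return -y / var_t


def backward_sample(score, T, mu, rng, w=1.0):
    """One trajectory of the discretized backward SDE (D.3) with constant w.

    On each interval [k*mu, (k+1)*mu] the drift is frozen at y_{k mu}, so the
    update is y + mu * drift + sqrt(w * mu) * z with z ~ N(0, 1).
    """
    n_steps = int(round(T / mu))
    y = rng.gauss(0.0, 1.0)  # y_0 ~ N(0, 1) replaces the unknown P_T
    for k in range(n_steps):
        t_fwd = T - k * mu  # forward time at which the score is evaluated
        drift = 0.5 * w * y + w * score(y, t_fwd)
        y = y + mu * drift + math.sqrt(w * mu) * rng.gauss(0.0, 1.0)
    return y


if __name__ == "__main__":
    rng = random.Random(0)
    s0_sq = 0.25

    def score(y, t):
        return exact_score_gaussian(y, t, s0_sq)

    samples = [backward_sample(score, T=4.0, mu=0.02, rng=rng) for _ in range(3000)]
    mean = sum(samples) / len(samples)
    var = sum((y - mean) ** 2 for y in samples) / len(samples)
    print("target variance:", s0_sq, "generated mean/var:", mean, var)
```

With the exact score, the generated variance should be close to $s_0^2$ up to discretization and Monte Carlo error; replacing `score` with a learned network recovers the practical setting described above.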

D.2 Proof of Lemma 2.1

Here we restate the proof of (Chen et al., 2023a, Lemma 1) for completeness.

Proof.

Recall $x = Bh$ by Assumption 2.1 with $x \in \mathbb{R}^D$, $B \in \mathbb{R}^{D \times d_0}$ and $h \in \mathbb{R}^{d_0}$.

By the forward process (D.1), we have

$$
p_t(x) = \int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h, \tag{D.4}
$$

where

$$
\psi_t(x \mid Bh) = [2\pi h(t)]^{-D/2} \exp\left( -\frac{\| \beta(t) Bh - x \|_2^2}{2\sigma(t)} \right), \tag{D.5}
$$

is the Gaussian transition kernel.

Then we write the score function as

$$
\begin{aligned}
\nabla \log p_t(x) &= \frac{\nabla p_t(x)}{p_t(x)} && \text{(by the log-derivative)} \\
&= \frac{\nabla \int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h}{\int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h} && \text{(by plugging in $p_t(x)$)} \\
&= \frac{\int \nabla \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h}{\int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h}, && \text{(by interchanging $\int$ with $\nabla$)}
\end{aligned}
$$

where the last equality holds since $\psi_t(x \mid Bh)$ is continuously differentiable in $x$.

Plugging (D.5) into the last expression above, we have

$$
\nabla \log p_t(x) = \frac{[2\pi h(t)]^{-D/2}}{\int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h} \int \frac{1}{\sigma(t)} \left( \beta(t) Bh - x \right) \exp\left( -\frac{\| \beta(t) Bh - x \|_2^2}{2\sigma(t)} \right) p_h(h) \, \mathrm{d} h.
$$

We then decompose the above score function by projecting $x$ onto ${\rm Span}(B)$, i.e., replacing $-x$ with $-BB^{\top} x - (I_D - BB^{\top}) x$:

$$
\begin{aligned}
\nabla \log p_t(x) &= \frac{[2\pi h(t)]^{-D/2}}{\int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h} \\
&\quad \cdot \int \frac{1}{\sigma(t)} \Big[ \left( \beta(t) Bh - BB^{\top} x \right) - \left( I_D - BB^{\top} \right) x \Big] \exp\left( -\frac{\| \beta(t) Bh - x \|_2^2}{2\sigma(t)} \right) p_h(h) \, \mathrm{d} h.
\end{aligned}
$$

Absorbing the factor of $[2\pi h(t)]^{-D/2}$ into the Gaussian kernel $\psi_t(x \mid Bh)$, we have

$$
\begin{aligned}
\nabla \log p_t(x) &= \frac{[2\pi h(t)]^{-D/2}}{\int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h} \int \frac{1}{\sigma(t)} \left( \beta(t) Bh - BB^{\top} x \right) \exp\left( -\frac{\| \beta(t) Bh - x \|_2^2}{2\sigma(t)} \right) p_h(h) \, \mathrm{d} h \\
&\quad - \frac{1}{\int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h} \left( \frac{1}{\sigma(t)} \left( I_D - BB^{\top} \right) x \right) \int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h \\
&= \underbrace{\frac{1}{\int \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h} \int \frac{1}{\sigma(t)} \left( \beta(t) Bh - BB^{\top} x \right) \psi_t(x \mid Bh) \, p_h(h) \, \mathrm{d} h}_{\coloneqq s_{+}} \underbrace{- \frac{1}{\sigma(t)} \left( I_D - BB^{\top} \right) x}_{\coloneqq s_{-}}.
\end{aligned}
$$
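The decomposition $\nabla \log p_t(x) = s_{+} + s_{-}$ can be checked numerically. The sketch below is illustrative only and makes several assumptions not in the paper: $D = 3$, $d_0 = 2$, a specific $B$ with orthonormal columns, and a latent distribution $p_h$ supported on two hypothetical atoms so that integrals become finite sums. The kernel normalizer is written with the variance $\sigma(t)$; it cancels in the score, so it does not affect the check.

```python
import math

SQ3, SQ2 = math.sqrt(3.0), math.sqrt(2.0)
# Columns of B (D = 3, d0 = 2); they are orthonormal by construction.
B_COLS = [[1 / SQ3, 1 / SQ3, 1 / SQ3], [1 / SQ2, -1 / SQ2, 0.0]]
# Hypothetical discrete latent distribution: (atom h_i, weight) pairs.
ATOMS = [([1.0, -0.5], 0.3), ([-1.0, 2.0], 0.7)]


def lift(h):
    """Return B h in R^D."""
    return [sum(col[i] * hj for col, hj in zip(B_COLS, h)) for i in range(3)]


def project(x):
    """Return B B^T x, the orthogonal projection of x onto Span(B)."""
    return lift([sum(c * xi for c, xi in zip(col, x)) for col in B_COLS])


def kernel(x, h, beta, sigma):
    """Gaussian kernel psi_t(x | B h) with mean beta * B h and variance sigma."""
    sq = sum((beta * m - xi) ** 2 for m, xi in zip(lift(h), x))
    return (2 * math.pi * sigma) ** (-1.5) * math.exp(-sq / (2 * sigma))


def density(x, beta, sigma):
    """p_t(x) as a finite mixture over the latent atoms, mirroring (D.4)."""
    return sum(wt * kernel(x, h, beta, sigma) for h, wt in ATOMS)


def score_finite_diff(x, beta, sigma, eps=1e-5):
    """Central finite-difference approximation of grad log p_t(x)."""
    grad = []
    for i in range(3):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        diff = math.log(density(xp, beta, sigma)) - math.log(density(xm, beta, sigma))
        grad.append(diff / (2 * eps))
    return grad


def score_decomposed(x, beta, sigma):
    """Return (s_plus, s_minus) following the decomposition in the proof."""
    px = project(x)
    p = density(x, beta, sigma)
    s_plus = [0.0, 0.0, 0.0]
    for h, wt in ATOMS:
        k = kernel(x, h, beta, sigma)
        bh = lift(h)
        for i in range(3):
            s_plus[i] += wt * k * (beta * bh[i] - px[i]) / (sigma * p)
    s_minus = [-(xi - pi) / sigma for xi, pi in zip(x, px)]
    return s_plus, s_minus


if __name__ == "__main__":
    beta, sigma = math.exp(-0.25), 1.0 - math.exp(-0.5)  # w = 1, t = 0.5
    x = [0.3, -1.2, 0.8]
    s_plus, s_minus = score_decomposed(x, beta, sigma)
    print("finite difference:", score_finite_diff(x, beta, sigma))
    print("s_plus + s_minus :", [a + b for a, b in zip(s_plus, s_minus)])
```

In this toy setting, `s_plus` lies in ${\rm Span}(B)$ and `s_minus` is orthogonal to it, which matches the roles of the two terms in the decomposition.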

To further simplify $s_{+}$, we decompose $\psi_t(x \mid Bh)$ as

$$
\begin{aligned}
\psi_t(x \mid Bh) &= [2\pi h(t)]^{-D/2} \exp\left( -\frac{1}{2\sigma(t)} \| \beta(t) Bh - x \|_2^2 \right) \\
&= [2\pi h(t)]^{-D/2} \exp\left( -\frac{1}{2\sigma(t)} \left\| \beta(t) Bh - BB^{\top} x - \left( I_D - BB^{\top} \right) x \right\|_2^2 \right) \\
&= [2\pi h(t)]^{-D/2} \exp\left( -\frac{1}{2\sigma(t)} \left( \| \beta(t) Bh - BB^{\top} x \|_2^2 + \| ( I_D - BB^{\top} ) x \|_2^2 \right) \right) \\
&\qquad \text{($B(\beta(t) h - B^{\top} x)$ is in ${\rm Span}(B)$ while $(I_D - BB^{\top}) x$ is orthogonal to ${\rm Span}(B)$)} \\
&= \underbrace{[2\pi h(t)]^{-d_0/2} \exp\left( -\frac{\| \beta(t) h - B^{\top} x \|_2^2}{2\sigma(t)} \right)}_{\coloneqq \psi_t\left( B^{\top} x \mid h \right)} \cdot \underbrace{[2\pi h(t)]^{-(D-d_0)/2} \exp\left( -\frac{\| ( I_D - BB^{\top} ) x \|_2^2}{2\sigma(t)} \right)}_{\coloneqq \psi_t\left( (I_D - BB^{\top}) x \right)}, \\
&\qquad \text{(since $B$ has orthonormal columns)}
\end{aligned}
$$

where both $\psi_{t}(B^{\top}x\mid h)$ and $\psi_{t}((I_{D}-BB^{\top})x)$ are Gaussian.

Plugging $\psi_{t}(x\mid Bh)=\psi_{t}(B^{\top}x\mid h)\,\psi_{t}((I_{D}-BB^{\top})x)$ into $s_{+}$, we obtain

$$
\begin{aligned}
s_{+}(x,t) &= C\int\frac{1}{\sigma(t)}\left(\beta(t)Bh-BB^{\top}x\right)\psi_{t}(B^{\top}x\mid h)\,\psi_{t}((I_{D}-BB^{\top})x)\,p_{h}(h)\,\mathrm{d}h \\
&= C\,\psi_{t}((I_{D}-BB^{\top})x)\int\frac{1}{\sigma(t)}\left(\beta(t)Bh-BB^{\top}x\right)\psi_{t}(B^{\top}x\mid h)\,p_{h}(h)\,\mathrm{d}h \\
&= \frac{1}{\int\psi_{t}(B^{\top}x\mid h^{\prime})\,p_{h^{\prime}}(h^{\prime})\,\mathrm{d}h^{\prime}}\int\frac{1}{\sigma(t)}\left(\beta(t)Bh-BB^{\top}x\right)\psi_{t}(B^{\top}x\mid h)\,p_{h}(h)\,\mathrm{d}h,
\end{aligned}
$$

where $C\coloneqq\left[\psi_{t}((I_{D}-BB^{\top})x)\int\psi_{t}(B^{\top}x\mid h^{\prime})\,p_{h^{\prime}}(h^{\prime})\,\mathrm{d}h^{\prime}\right]^{-1}$.

Notably, $s_{+}$ depends only on the projected data $B^{\top}x$. Therefore, we are able to replace $s_{+}(x,t)$ with $s_{+}(B^{\top}x,t)$. The benefit is that the dimension $d_{0}$ of the first input in $s_{+}(B^{\top}x,t)$ is much smaller.

Lastly, by denoting $\bar{h}=B^{\top}x$ such that $\nabla_{\bar{h}}\psi_{t}(\bar{h}\mid h)=(\beta(t)h-\bar{h})\,\psi_{t}(B^{\top}x\mid h)/\sigma(t)$, we arrive at

$$
\begin{aligned}
s_{+}(B^{\top}x,t) &= B\int\frac{\nabla_{\bar{h}}\psi_{t}(\bar{h}\mid h)\,p_{h}(h)}{\int\psi_{t}(\bar{h}\mid h^{\prime})\,p_{h^{\prime}}(h^{\prime})\,\mathrm{d}h^{\prime}}\,\mathrm{d}h \\
&= B\nabla\log p_{t}^{h}(B^{\top}x). \qquad \text{($p_{t}^{h}(\bar{h})\coloneqq\int\psi_{t}(\bar{h}\mid h)\,p_{h}(h)\,\mathrm{d}h$)}
\end{aligned}
$$

This completes the proof. ∎
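As a sanity check of the final identity, the following minimal sketch evaluates the integral form of $s_{+}$ by quadrature and compares it with $B\nabla\log p_{t}^{h}(B^{\top}x)$. Everything concrete here is an illustrative assumption and not taken from the paper: $d_{0}=1$, $D=3$, a standard normal latent prior $p_{h}$, and fixed numbers standing in for $\beta(t)$ and $\sigma(t)$. For this prior, $p_{t}^{h}=N(0,\beta(t)^{2}+\sigma(t))$, so the right-hand side has a closed form. The factor $\psi_{t}((I_{D}-BB^{\top})x)$ and all normalizing constants cancel in the ratio, so they are omitted.

```python
import numpy as np

# Assumptions for this sketch (not from the paper): d0 = 1, D = 3,
# a standard normal latent prior p_h, and fixed values for beta(t), sigma(t).
beta, sigma = 0.8, 0.36                      # sigma enters as a variance
B = np.array([[2.0], [1.0], [2.0]]) / 3.0    # D x d0 with an orthonormal column
x = np.array([0.5, -1.2, 0.7])

h = np.linspace(-12.0, 12.0, 48001)          # quadrature grid for the latent h
p_h = np.exp(-0.5 * h**2)                    # unnormalized N(0, 1) prior
h_bar = (B.T @ x).item()                     # projected data B^T x
psi = np.exp(-(beta * h - h_bar) ** 2 / (2.0 * sigma))  # unnormalized psi_t(h_bar | h)

weights = psi * p_h
weights /= weights.sum()                     # normalized weights over h given h_bar

# Integral form: weighted average of (beta * B h - B B^T x) / sigma.
s_plus = (B * (beta * h - h_bar) / sigma) @ weights

# Closed form for the Gaussian prior: grad log p_t^h(h_bar) = -h_bar / (beta^2 + sigma).
closed_form = B[:, 0] * (-h_bar / (beta**2 + sigma))
print(np.allclose(s_plus, closed_form, atol=1e-8))
```

The quadrature only integrates over the $d_{0}$-dimensional latent variable, which mirrors the point of the lemma: the on-subspace score is a function of $B^{\top}x$ alone.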

D.3 Preliminaries: Strong Exponential Time Hypothesis (SETH) and Tensor Trick

Here we present the ideas we built upon for Section 4.

Strong Exponential Time Hypothesis (SETH). Impagliazzo and Paturi (2001) introduce the Strong Exponential Time Hypothesis (SETH) as a stronger form of the $\mathtt{P}\neq\mathtt{NP}$ conjecture. It suggests that our current best $\mathtt{SAT}$ algorithms are optimal and is a popular conjecture for proving fine-grained lower bounds for a wide variety of algorithmic problems (Cygan et al., 2016; Williams, 2018).

Hypothesis 1 (SETH).

For every $\epsilon>0$, there is a positive integer $k\geq 3$ such that $k$-$\mathtt{SAT}$ on formulas with $n$ variables cannot be solved in $\mathcal{O}(2^{(1-\epsilon)n})$ time, even by a randomized algorithm.
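For intuition about the $2^{n}$ baseline that SETH refers to, here is a minimal brute-force satisfiability check. It is only an illustrative sketch: the function name `brute_force_sat` and the clause encoding (signed integers, variables numbered from 1) are assumptions made for this example. It enumerates all $2^{n}$ assignments, which is the running time that SETH conjectures cannot be improved to $\mathcal{O}(2^{(1-\epsilon)n})$ for all $k$.

```python
from itertools import product

def brute_force_sat(num_vars, clauses):
    """Return a satisfying assignment, or None, by trying all 2^n assignments.

    A clause is a list of nonzero ints: literal v means variable v is True,
    literal -v means variable v is False (variables are numbered from 1).
    """
    for bits in product([False, True], repeat=num_vars):
        if all(any(bits[abs(lit) - 1] == (lit > 0) for lit in clause)
               for clause in clauses):
            return bits
    return None

# (x1 or x2 or not x3) and (not x1 or x3 or x2) and (not x2 or not x3 or x1)
formula = [[1, 2, -3], [-1, 3, 2], [-2, -3, 1]]
assignment = brute_force_sat(3, formula)
```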

Tensor Trick for Computing Gradients. The tensor trick (Diao et al., 2019, 2018) is an instrument to compute complicated gradients in a clean and tractable fashion. We start with some definitions.

Definition D.1 (Vectorization).

For any matrix $X\in\mathbb{R}^{L\times d}$, we define $\underline{X}\coloneqq\operatorname{vec}(X)\in\mathbb{R}^{Ld}$ such that $X_{i,j}=\underline{X}_{(i-1)d+j}$ for all $i\in[L]$ and $j\in[d]$.

Definition D.2 (Matrixization).

For any vector $\underline{X}\in\mathbb{R}^{Ld}$, we define $\mathrm{mat}(\underline{X})=X\in\mathbb{R}^{L\times d}$ such that $X_{i,j}=\mathrm{mat}(\underline{X})_{i,j}\coloneqq\underline{X}_{(i-1)d+j}$ for all $i\in[L]$ and $j\in[d]$, namely $\mathrm{mat}(\cdot)=\operatorname{vec}^{-1}(\cdot)$.
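The index rule $X_{i,j}=\underline{X}_{(i-1)d+j}$ is row-major flattening. A minimal NumPy sketch follows; the helper names `vec` and `mat` mirror the definitions but are written for this illustration, and NumPy's default C-order `reshape` is what realizes the row-major convention.

```python
import numpy as np

def vec(X):
    """Row-major vectorization: X[i, j] lands at index i * d + j (0-based)."""
    return X.reshape(-1)

def mat(x_vec, L, d):
    """Inverse of vec for an L x d matrix."""
    return x_vec.reshape(L, d)

L, d = 3, 4
X = np.arange(L * d, dtype=float).reshape(L, d)
x_vec = vec(X)

i, j = 2, 3                                   # 1-based indices as in Definition D.1
same_entry = X[i - 1, j - 1] == x_vec[(i - 1) * d + j - 1]
round_trip = np.array_equal(mat(vec(X), L, d), X)
```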

Definition D.3 (Kronecker Product).

Let $A\in\mathbb{R}^{L_{a}\times d_{a}}$ and $B\in\mathbb{R}^{L_{b}\times d_{b}}$. We define the Kronecker product of $A$ and $B$ as $A\otimes B\in\mathbb{R}^{L_{a}L_{b}\times d_{a}d_{b}}$ such that $(A\otimes B)_{(i_{a}-1)L_{b}+i_{b},\,(j_{a}-1)d_{b}+j_{b}}$ is equal to $A_{i_{a},j_{a}}B_{i_{b},j_{b}}$ with $i_{a}\in[L_{a}]$, $j_{a}\in[d_{a}]$, $i_{b}\in[L_{b}]$, $j_{b}\in[d_{b}]$.

Definition D.4 (Sub-Block of a Tensor).

For any $A\in\mathbb{R}^{L_{a}\times d_{a}}$ and $B\in\mathbb{R}^{L_{b}\times d_{b}}$, let $\mathsf{A}\coloneqq A\otimes B\in\mathbb{R}^{L_{a}L_{b}\times d_{a}d_{b}}$. For any $\underline{j}\in[L_{a}]$, we define $\mathsf{A}_{\underline{j}}\in\mathbb{R}^{L_{b}\times d_{a}d_{b}}$ to be the $\underline{j}$-th $L_{b}\times d_{a}d_{b}$ sub-block of $\mathsf{A}$.
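The two definitions above can be checked numerically. In this sketch, `np.kron` is NumPy's Kronecker product, while the helper `sub_block` and the chosen sizes are assumptions made for illustration. The last check uses the fact that the $\underline{j}$-th row block of $A\otimes B$ equals the Kronecker product of the $\underline{j}$-th row of $A$ with $B$.

```python
import numpy as np

rng = np.random.default_rng(0)
La, da, Lb, db = 2, 3, 4, 5
A = rng.standard_normal((La, da))
B = rng.standard_normal((Lb, db))
K = np.kron(A, B)                             # shape (La * Lb, da * db)

# Entry-wise definition with 1-based indices (ia, ja, ib, jb).
ia, ja, ib, jb = 2, 3, 4, 1
entry_ok = np.isclose(K[(ia - 1) * Lb + ib - 1, (ja - 1) * db + jb - 1],
                      A[ia - 1, ja - 1] * B[ib - 1, jb - 1])

def sub_block(K, j, Lb):
    """Return the j-th (1-based) Lb x (da * db) row block of K."""
    return K[(j - 1) * Lb : j * Lb, :]

# The j-th sub-block equals (row j of A) Kronecker B.
block_ok = np.allclose(sub_block(K, 2, Lb), np.kron(A[1:2, :], B))
```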

Lemma D.1 (Tensor Trick (Diao et al., 2019, 2018)).

For any $A\in\mathbb{R}^{L_{a}\times d_{a}}$, $B\in\mathbb{R}^{L_{b}\times d_{b}}$ and $X\in\mathbb{R}^{d_{a}\times d_{b}}$, it holds $\operatorname{vec}\left(AXB^{\top}\right)=(A\otimes B)\underline{X}\in\mathbb{R}^{L_{a}L_{b}}$.
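A quick numerical check of the tensor trick, written in its shape-consistent form $\operatorname{vec}(AXB^{\top})=(A\otimes B)\underline{X}$ for $A\in\mathbb{R}^{L_{a}\times d_{a}}$, $B\in\mathbb{R}^{L_{b}\times d_{b}}$, $X\in\mathbb{R}^{d_{a}\times d_{b}}$, and with the row-major $\operatorname{vec}$ of Definition D.1. The sizes and random inputs are arbitrary choices for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
La, da, Lb, db = 3, 4, 5, 2
A = rng.standard_normal((La, da))
B = rng.standard_normal((Lb, db))
X = rng.standard_normal((da, db))

lhs = (A @ X @ B.T).reshape(-1)               # row-major vec of an La x Lb matrix
rhs = np.kron(A, B) @ X.reshape(-1)           # (A kron B) applied to vec(X)
print(np.allclose(lhs, rhs))
```

With a column-major $\operatorname{vec}$ the roles of the two factors would swap, so the identity is tied to the indexing convention of Definition D.1.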

To showcase the tensor trick, let’s consider a (single data point) attention following (Gao et al., 2023b, c). Setting $D\coloneqq\mathop{\rm diag}\left(\exp(X^{\mathsf{T}}W_{K}^{\mathsf{T}}W_{Q}X)\mathds{1}_{L}\right)$ and $W\coloneqq W_{K}W_{Q}^{\mathsf{T}}\in\mathbb{R}^{d\times d}$, we have

$$
\mathcal{L}_{0}\coloneqq\Big\|\underbrace{W_{V}}_{\in\mathbb{R}^{d\times d}}\underbrace{X}_{\in\mathbb{R}^{d\times L}}\underbrace{D^{-1}}_{\in\mathbb{R}^{L\times L}}\underbrace{\exp(X^{\mathsf{T}}WX)}_{\in\mathbb{R}^{L\times L}}-\underbrace{Y}_{\in\mathbb{R}^{d\times L}}\Big\|_{2}^{2}. \tag{D.6}
$$
Proposition D.1 (Definition 4.7 of (Gao et al., 2023b)).

By Definition D.3 and Definition D.4, we identify $D_{\underline{j},\underline{j}}\coloneqq\langle\exp(\mathsf{A}_{\underline{j}}\underline{W}),\mathds{1}_{L}\rangle\in\mathbb{R}$ for all $\underline{j}\in[L]$, with $\mathsf{A}\coloneqq X\otimes X\in\mathbb{R}^{L^{2}\times d^{2}}$ and $\underline{W}\in\mathbb{R}^{d^{2}}$.
Therefore, summing over $\underline{j}\in[L]$ and $\underline{i}\in[d]$, it holds $\mathcal{L}_{0}=\sum_{\underline{j}=1}^{L}\sum_{\underline{i}=1}^{d}\frac{1}{2}\left(\langle D^{-1}_{\underline{j},\underline{j}}\exp(\mathsf{A}_{\underline{j}}\underline{W}),XW_{V}[\cdot,\underline{i}]\rangle-Y_{\underline{j},\underline{i}}\right)^{2}$.

The elegance of Proposition D.1 emerges when we vectorize the weights into vectors $\underline{W},\underline{W}_{V}$, making the gradient computations (e.g., $\mathrm{d}\mathcal{L}_{0}/\mathrm{d}\underline{W}$ and $\mathrm{d}\mathcal{L}_{0}/\mathrm{d}\underline{W}_{V}$) more tractable by avoiding complex matrix or tensor derivatives. This approach systematically simplifies the handling of chain-rule terms in the gradient computation of losses like $\mathcal{L}_{0}$.
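To make the identification $D_{\underline{j},\underline{j}}=\langle\exp(\mathsf{A}_{\underline{j}}\underline{W}),\mathds{1}_{L}\rangle$ concrete, the sketch below compares it with the direct row sums of the exponentiated score matrix. One assumption is made loudly here: the input is stored with tokens as rows, $X\in\mathbb{R}^{L\times d}$, so that $X\otimes X$ has the stated shape $L^{2}\times d^{2}$ and the score matrix reads $XWX^{\mathsf{T}}$. The sizes and random inputs are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d = 4, 3
X = rng.standard_normal((L, d))               # assumption: rows are tokens (L x d)
W = rng.standard_normal((d, d))

A_big = np.kron(X, X)                         # shape (L^2, d^2)
W_vec = W.reshape(-1)                         # row-major vec(W)

scores = np.exp(X @ W @ X.T)                  # L x L matrix of exponentiated scores
D_direct = scores.sum(axis=1)                 # row sums, i.e. the diagonal of D

D_tensor = np.array([
    np.exp(A_big[j * L:(j + 1) * L] @ W_vec).sum()   # <exp(A_j vec(W)), 1_L>
    for j in range(L)
])
print(np.allclose(D_direct, D_tensor))
```

Each block product `A_big[j*L:(j+1)*L] @ W_vec` reproduces one row of the score matrix, which is why the loss becomes an explicit function of the vector $\underline{W}$.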

Appendix E More Background and Auxiliary Lemmas: Universal Approximation of Transformers via Piecewise Approximation

Here, we review the universal approximation of Transformers following (Yun et al., 2020). Our goal is to reproduce the results of (Yun et al., 2020) and use or modify them as auxiliary lemmas for the proofs of Section 3 (i.e., Appendix F).

We start with their central result, and the rest of the section aims to prove it.

Lemma E.1 (Universal Approximation of Transformers, Theorem 3 of (Yun et al., 2020)).

Let $\epsilon>0$. For any given compact-supported continuous function $f:\mathbb{R}^{d\times L}\to\mathbb{R}^{d\times L}$, there exists a Transformer network $f_{\mathcal{T}}\in\mathcal{T}_{p}^{2,1,4}$ such that we have

$$
\left(\int\|f_{\mathcal{T}}(X)-f(X)\|_{F}^{2}\,\mathrm{d}X\right)^{1/2}\leq\epsilon.
$$
Proof Overview.

We use the following proof strategy:

  • Step 1. We show that piecewise-constant functions are able to approximate compact-supported continuous functions in Section E.1.

  • Step 2. We define modified self-attention and feed-forward layers to construct the modified transformer. We show that the modified transformer is able to approximate piecewise-constant functions in Section E.2.

  • Step 3. We show that the modified transformer is able to approximate the normal transformer in Section E.3.

Below, we provide details of Step 1 in Section E.1, Step 2 in Section E.2, and Step 3 in Section E.3. Then we give a summary of our results in Section E.4.

E.1 Piecewise-constant Function Approximates Compact-Supported Continuous Function

In this subsection, we show that piecewise-constant functions are able to approximate compact-supported continuous functions.

We start with the definition of the compact-supported continuous functions of interest.

Assumption E.1.

Without loss of generality, we assume that the target function in discussion is supported on $[0,1]^{d\times L}$. We denote the set of $[0,1]^{d\times L}$-supported continuous functions as $\mathcal{F}$.

We introduce the notion of grid and cube for the compact support $[0,1]^{d\times L}$.

Definition E.1 (Grid and Cube with Width $\delta$).

Given a grid width $\delta$, let $\mathcal{G}_{\delta}\coloneqq\{0,\delta,\dots,1-\delta\}^{d\times L}$ denote the set of grid points within $[0,1]^{d\times L}$. For a grid point $G=(G_{j\in[d],k\in[L]})\in\mathcal{G}_{\delta}$, we denote its associated cube as

$$
\mathcal{S}_{G}:=\bigotimes_{j=1}^{d}\bigotimes_{k=1}^{L}[G_{j,k},G_{j,k}+\delta)\subset[0,1]^{d\times L}.
$$
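Since the cubes $\mathcal{S}_{G}$ are half-open and disjoint, every $X\in[0,1)^{d\times L}$ lies in exactly one of them, and its grid point is obtained by rounding each entry down to a multiple of $\delta$. A minimal sketch of this lookup follows; the helper name `grid_point` and the requirement that $1/\delta$ is an integer are assumptions of this example.

```python
import numpy as np

def grid_point(X, delta):
    """Return the grid point G in G_delta whose cube S_G contains X.

    Assumes X has entries in [0, 1) and that 1/delta is an integer.
    """
    n_cells = round(1.0 / delta)
    idx = np.floor(X * n_cells).astype(int)   # per-entry cell index
    idx = np.clip(idx, 0, n_cells - 1)        # guard against round-off at the top edge
    return idx * delta

delta = 0.25
X = np.array([[0.10, 0.30], [0.74, 0.99]])    # d = L = 2
G = grid_point(X, delta)
inside = np.all((G <= X) & (X < G + delta))   # X lies in the cube S_G
```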

We introduce the notion of the piecewise-constant function class w.r.t. the $[0,1]^{d\times L}$-supported continuous function class $\mathcal{F}$.

Definition E.2 (Piecewise-Constant Function Class).

Let $f_{\delta}$ denote the piecewise-constant function of grid width $\delta$, and $\mathds{1}\{\cdot\}$ denote the indicator function. For each $G\in\mathcal{G}_{\delta}$, and any matrix $A_{G}\in\mathbb{R}^{d\times L}$, we define the piecewise-constant function class as

$$
\mathcal{F}(\delta)\coloneqq\left\{f_{\delta}:X\mapsto\sum\nolimits_{G\in\mathcal{G}_{\delta}}A_{G}\cdot\mathds{1}\{X\in\mathcal{S}_{G}\},\;A_{G}\in\mathbb{R}^{d\times L}\right\}. \tag{E.1}
$$

We recall that, for a given sequence-to-sequence function $f$, we have

$$\norm{f}_{L^{2}}:=\bigg(\int\norm{f(X)}_{F}^{2}\differential X\bigg)^{1/2}.$$

The next lemma approximates a compactly supported continuous function by a piecewise-constant function.

Lemma E.2.

(Lemma 8 of (Yun et al., 2020)) For any given $f\in\mathcal{F}$ and $\epsilon/3>0$, we can find a $\delta^{\star}>0$ such that there exists an $f_{\delta^{\star}}\in\mathcal{F}(\delta^{\star})$ satisfying $\norm{f-f_{\delta^{\star}}}_{L^{2}}\leq\epsilon/3$.

Proof.

See Section E.5.2 for a detailed proof. ∎

E.2 Modified Transformer Approximates Piece-Wise Constant Function

In this subsection, we define modified self-attention and feed-forward layers to construct the modified transformers. We then use the modified transformers to approximate piecewise-constant functions.

Definition E.3 (Modified Transformer Networks).

The modified transformer networks $\bar{\mathcal{T}}_{p}^{r,m,l}$ differ from the normal transformer networks $\mathcal{T}_{p}^{r,m,l}$ in two ways:

  • Modified attention layer: Replace the $\mathop{\rm Softmax}$ operator with the $\mathop{\rm Hardmax}$ operator $\sigma_{H}(\cdot)$.

  • Modified feed-forward layer: Replace ${\rm ReLU}(\cdot)$ with an activation function $\zeta\in\Psi$. Here, $\Psi$ denotes the set of all piecewise linear functions with at most three pieces, at least one of which is constant.

We approximate $\mathcal{F}(\delta)$ with the modified transformer networks $\bar{\mathcal{T}}_{p}^{r,m,l}$ as follows.

Lemma E.3 (Modified from Proposition 4 of (Yun et al., 2020)).

For each $f_{\delta}\in\mathcal{F}(\delta)$, there exists an $f_{\mathcal{T},c}\in\bar{\mathcal{T}}_{p}^{2,1,1}$ such that $\norm{f_{\delta}-f_{\mathcal{T},c}}_{L^{2}}=\mathcal{O}(\delta^{d/2})$.

Proof Sketch.

Given $\delta$, we have the grid $\mathcal{G}_{\delta}$ and the cube $\mathcal{S}_{G}$ for each $G\in\mathcal{G}_{\delta}$. Our proof follows two steps:

  • Quantization. For all $X\in\mathbb{R}^{d\times L}$, we quantize it to a finite set:

    • If $X\in\mathcal{S}_{G}\subset[0,1]^{d\times L}$, we quantize it to the element $G\in\mathcal{G}_{\delta}$.

    • If $X\notin[0,1]^{d\times L}$, we quantize it to an element outside $\mathcal{G}_{\delta}$.

  • Mapping. For any $G\in\mathcal{G}_{\delta}$, we map it to the desired output $A_{G}$.

For Quantization, we use a series of modified feed-forward layers. We show this in Section E.2.1.

For Mapping, we follow two steps:

  • For any $G\neq G^{\prime}\in\mathcal{G}_{\delta}$, we use a “contextual mapping” $q_{c}(\cdot)$ (defined in Definition E.4), under which all the entries of $q_{c}(G)$ and $q_{c}(G^{\prime})$ take different values. We then use a series of modified self-attention layers to achieve the “contextual mapping.” We show this in Section E.2.2.

    Definition E.4 (Contextual Mapping).

    Consider a finite set $\mathcal{G}_{\delta}\subset\mathbb{R}^{d\times L}$. A map $q_{c}:\mathcal{G}_{\delta}\rightarrow\mathbb{R}^{1\times L}$ defines a contextual mapping if it satisfies the following:

    • For any $G\in\mathcal{G}_{\delta}$, the entries in $q_{c}(G)$ are all distinct.

    • For any $G\neq G^{\prime}\in\mathcal{G}_{\delta}$, all entries of $q_{c}(G)$ and $q_{c}(G^{\prime})$ are distinct.

  • For any $G\in\mathcal{G}_{\delta}$, we use a series of modified feed-forward layers to map $q_{c}(G)$ to $A_{G}$. We show this in Section E.2.3.
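The two conditions of Definition E.4 can be checked mechanically on a small finite set. The sketch below is only an illustration: `is_contextual_mapping` and `toy_q` are hypothetical names, and `toy_q` is a hand-built map on binary sequences with $d=1$, not the attention-based construction of Section E.2.2.

```python
from itertools import product

def is_contextual_mapping(q_c, grid):
    """Check the two properties of a contextual mapping on a finite grid.

    q_c maps each grid element to a tuple of L real numbers.
    Property 1: the entries of q_c(G) are pairwise distinct.
    Property 2: entries of q_c(G) and q_c(G') never coincide for G != G'.
    The elements of `grid` are assumed to be pairwise distinct.
    """
    seen = set()
    for G in grid:
        entries = q_c(G)
        if len(set(entries)) != len(entries):  # violates property 1
            return False
        if seen & set(entries):                # violates property 2
            return False
        seen |= set(entries)
    return True

def toy_q(G):
    """Hypothetical contextual mapping on binary sequences (d = 1).

    The whole sequence is first turned into a unique integer ID, then each
    position gets its own offset, so every entry encodes (sequence, position).
    """
    L = len(G)
    seq_id = sum(g * 2 ** j for j, g in enumerate(G))
    return tuple(seq_id * L + j for j in range(L))

grid = list(product([0, 1], repeat=2))  # all sequences of length L = 2
ok = is_contextual_mapping(toy_q, grid)
identity_ok = is_contextual_mapping(lambda G: G, grid)  # fails on G = (0, 0)
```

The identity map fails because a sequence with repeated tokens yields repeated entries, which illustrates why each output entry must depend on the whole sequence and on the position.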

Remark E.1.

Our proof differs from (Yun et al., 2020) in one aspect: while Proposition 4 in (Yun et al., 2020) uses a transformer network without positional encoding, we add positional encoding to complete our proof.

E.2.1 Quantization by Modified Feed-forward Layers

We use a series of modified feed-forward layers in $\bar{\mathcal{T}}_{p}^{r,m,l}$ to quantize an input $X\in\mathbb{R}^{d\times L}$ to an element $G$ in a grid:

$$\{-J,0,\delta,\dots,1-\delta\}^{d\times L},$$

where $J>L>0$ is a sufficiently large number to be determined later. We achieve this via two steps.

  • Step 1: Map the elements outside $[0,1)$ to $-J$.

    We use $e_{i}$ to denote the standard unit vector whose $i$-th element is $1$. For the $i$-th row of $X$, we define the following feed-forward layer to achieve our aim.

    Definition E.5 (Feed-forward Layer 1).

    The vector $e_{i}$ acts as the weight parameters and $\zeta_{1}(\cdot)$ acts as the activation function in the feed-forward layer.

    $$X\rightarrow X+e_{i}\zeta_{1}(e_{i}^{\top}X),\quad\zeta_{1}(t)=\begin{cases}-t-J&\text{for }t<0\text{ or }t\geq 1,\\ 0&\text{otherwise}.\end{cases} \tag{E.2}$$

    We take $i=1$ as an example to give the specific calculation. We denote $X=(x_{i,j})_{d\times L}$, then we have

    $$\begin{aligned}
    {\rm FF}(X)&=X+\begin{pmatrix}1\\ 0\\ \vdots\\ 0\end{pmatrix}\begin{pmatrix}\zeta_{1}(x_{1,1})&\zeta_{1}(x_{1,2})&\cdots&\zeta_{1}(x_{1,L})\end{pmatrix}\\
    &=X+\begin{pmatrix}\zeta_{1}(x_{1,1})&\zeta_{1}(x_{1,2})&\cdots&\zeta_{1}(x_{1,L})\\ 0&0&\cdots&0\\ \vdots&\vdots&\vdots&\vdots\\ 0&0&\cdots&0\end{pmatrix}.
    \end{aligned}$$

    In the first row of $X$, the above layer transforms every element outside $[0,1)$ to $-J$, since $t+\zeta_{1}(t)=-J$ for such $t$.

    We stack the above layers together for $i=1,2,\dots,d$. If an element of $X$ is outside $[0,1)$, the series of layers maps it to $-J$.

  • Step 2: Map the elements in $[0,1)$ to $\{0,\delta,2\delta,\dots,1-\delta\}$.

    For the $i$-th row of $X$, we take $k=0,1,\dots,1/\delta-1$ respectively, and define the following layer.

    Definition E.6 (Feed-forward Layer 2).

    The vector $e_{i}$ acts as the weight parameters and $\zeta_{2}(\cdot)$ acts as the activation function in the feed-forward layer.

    $$X\rightarrow X+e_{i}\zeta_{2}(e_{i}^{\top}X-k\delta\mathds{1}_{n}^{\top}),\quad\zeta_{2}(t)=\begin{cases}0&t<0\text{ or }t\geq\delta,\\ -t&0\leq t<\delta.\end{cases} \tag{E.3}$$

    We take $i=1,k=1$ as an example, and give the specific calculation.

    $$\begin{aligned}
    {\rm FF}(X)&=X+\begin{pmatrix}1\\ 0\\ \vdots\\ 0\end{pmatrix}\begin{pmatrix}\zeta_{2}(x_{1,1}-\delta)&\zeta_{2}(x_{1,2}-\delta)&\cdots&\zeta_{2}(x_{1,L}-\delta)\end{pmatrix}\\
    &=X+\begin{pmatrix}\zeta_{2}(x_{1,1}-\delta)&\zeta_{2}(x_{1,2}-\delta)&\cdots&\zeta_{2}(x_{1,L}-\delta)\\ 0&0&\cdots&0\\ \vdots&\vdots&\vdots&\vdots\\ 0&0&\cdots&0\end{pmatrix}.
    \end{aligned}$$

    In the first row of $X$, the above layer transforms every element in $[\delta,2\delta)$ to $\delta$.

    We stack the above layers together for $i=1,2,\dots,d$ and $k=0,1,\dots,1/\delta-1$. If an element of $X$ is in $[k\delta,(k+1)\delta)$, the series of layers maps it to $k\delta$.

Combining the two steps above, we achieve our goal with $d/\delta+d$ feed-forward layers. We denote this stack of $d/\delta+d$ layers as $f_{\mathcal{T},c1}$.
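The two quantization steps can be traced numerically. The sketch below is a minimal plain-Python rendering of the layers (E.2) and (E.3) applied row by row. The names `zeta1`, `zeta2`, `quantize` and the integer parameter `cells_per_axis` (standing in for $1/\delta$) are illustrative assumptions, and the order in which the layers are stacked is one possible choice.

```python
def zeta1(t, J):
    """Activation in (E.2): nonzero only outside [0, 1)."""
    return -t - J if (t < 0 or t >= 1) else 0.0

def zeta2(t, delta):
    """Activation in (E.3): nonzero only on [0, delta)."""
    return -t if 0 <= t < delta else 0.0

def quantize(X, cells_per_axis, J):
    """Apply the stacked feed-forward layers of Steps 1 and 2 to X.

    X is a d x L matrix given as a list of rows. Returns the quantized
    matrix and the number of layers used, which equals d / delta + d.
    """
    delta = 1.0 / cells_per_axis
    Y = [list(row) for row in X]
    num_layers = 0
    # Step 1: one layer per row sends entries outside [0, 1) to -J.
    for i in range(len(Y)):
        Y[i] = [x + zeta1(x, J) for x in Y[i]]
        num_layers += 1
    # Step 2: one layer per (row, k) sends [k*delta, (k+1)*delta) to k*delta.
    for i in range(len(Y)):
        for k in range(cells_per_axis):
            Y[i] = [x + zeta2(x - k * delta, delta) for x in Y[i]]
            num_layers += 1
    return Y, num_layers

# d = 2, L = 2, delta = 1/4, J = 10.
Y, n = quantize([[0.3, 1.5], [-0.5, 0.99]], cells_per_axis=4, J=10.0)
# Y is approximately [[0.25, -10.0], [-10.0, 0.75]] and n == 10.
```

Entries already equal to $-J$ or to a grid value $k\delta$ are left unchanged by every later layer, so the stacked layers do not interfere with each other.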

E.2.2 Contextual Mapping by Modified Self-attention Layers

In our attention layers, we use the following positional encoding $E\in\mathbb{R}^{d\times L}$.

$$E=\begin{pmatrix}0&1&2&\cdots&L-1\\ 0&1&2&\cdots&L-1\\ \vdots&\vdots&\vdots&&\vdots\\ 0&1&2&\cdots&L-1\end{pmatrix}. \tag{E.4}$$

According to Section E.2.1, the output of $f_{\mathcal{T},c1}$ is in the grid $\{-J,0,\delta,\dots,1-\delta\}^{d\times L}$. For any $X$ in this grid, the first column of $X+E$ is in

$$\{-J,0,\delta,\dots,1-\delta\}^{d},$$

and the second column is in

$$\{-J+1,1,1+\delta,\dots,2-\delta\}^{d}.$$

For the other columns, the results are similar.

For $i=0,1,\dots,L-1$, we use the following notation:

$$[i:\delta:i+1-\delta]_{J}\coloneqq\{i-J,i,i+\delta,\dots,i+1-\delta\}.$$

Then we define the grid $\mathcal{G}_{\delta}^{+}$ as follows.

Definition E.7 (Grid $\mathcal{G}_{\delta}^{+}$).

$X+E$ is in the grid:

$$\mathcal{G}_{\delta}^{+}\coloneqq[0:\delta:1-\delta]_{J}^{d}\times[1:\delta:2-\delta]_{J}^{d}\times\cdots\times[L-1:\delta:L-\delta]_{J}^{d}.$$
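As a small numerical illustration of Definition E.7, the sketch below adds the positional encoding (E.4) to an already quantized input and checks column by column that the result lies in $\mathcal{G}_{\delta}^{+}$. The function names and the parameter `cells_per_axis` (standing in for $1/\delta$) are hypothetical. The membership test uses exact floating-point equality, which is reliable here only because the example takes $\delta$ to be a power of two.

```python
def positional_encoding(d, L):
    """Matrix E from (E.4): every row equals (0, 1, ..., L - 1)."""
    return [[float(j) for j in range(L)] for _ in range(d)]

def add_positional_encoding(X):
    """Return X + E for a d x L matrix X given as a list of rows."""
    d, L = len(X), len(X[0])
    E = positional_encoding(d, L)
    return [[X[i][j] + E[i][j] for j in range(L)] for i in range(d)]

def column_values(i, cells_per_axis, J):
    """The set [i : delta : i + 1 - delta]_J allowed in column i."""
    delta = 1.0 / cells_per_axis
    return {i - J} | {i + k * delta for k in range(cells_per_axis)}

def in_grid_plus(Y, cells_per_axis, J):
    """Check that every entry of column j lies in [j : delta : j + 1 - delta]_J."""
    L = len(Y[0])
    return all(
        row[j] in column_values(j, cells_per_axis, J)
        for row in Y
        for j in range(L)
    )

# A quantized input with d = 2, L = 2, delta = 1/4, J = 10.
X = [[0.25, -10.0], [-10.0, 0.75]]
Y = add_positional_encoding(X)            # [[0.25, -9.0], [-10.0, 1.75]]
shifted_ok = in_grid_plus(Y, 4, 10.0)     # True
unshifted_ok = in_grid_plus(X, 4, 10.0)   # False: column 1 was not shifted
```

After the shift, the value ranges of different columns no longer overlap on the in-range entries, so a column's value also reveals its position.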

Next, we show that the modified attention layers compute a contextual mapping (Definition E.4) for $\mathcal{G}_{\delta}^{+}$. For $i=0,1,\dots,L-1$, we use the following notation:

$$[i:\delta:i+1-\delta]\coloneqq\{i,i+\delta,i+2\delta,\dots,i+1-\delta\}.$$

Lemma E.4 (Modified from Lemma 6 of (Yun et al., 2020)).

We consider the following subset of $\mathcal{G}_{\delta}^{+}$:

$$\widetilde{\mathcal{G}}_{\delta}:=\underbrace{[0:\delta:1-\delta]^{d}\times[1:\delta:2-\delta]^{d}\times\cdots\times[L-1:\delta:L-\delta]^{d}}_{L}.$$

Assume that $L\geq 2$ and $\delta^{-1}\geq 2$. Then, there exist a function $f_{\mathcal{T},c2}:\mathbb{R}^{d\times L}\to\mathbb{R}^{d\times L}$ composed of $\delta^{-d}+1$ modified attention layers (Definition E.3), a vector $u\in\mathbb{R}^{d}$, and two constants $t_{l},t_{r}\in\mathbb{R}$ ($0<t_{l}<t_{r}$), such that $q_{c}(G)\coloneqq u^{\top}f_{\mathcal{T},c2}(G)$, $G\in\mathcal{G}_{\delta}^{+}$, satisfies the following properties:

  1. For any $G\in\widetilde{\mathcal{G}}_{\delta}$, the entries of $q_{c}(G)$ are all distinct.

  2. For any different $G,G^{\prime}\in\widetilde{\mathcal{G}}_{\delta}$, all entries of $q_{c}(G)$, $q_{c}(G^{\prime})$ are distinct.

  3. For any $G\in\widetilde{\mathcal{G}}_{\delta}$, all the entries of $q_{c}(G)$ are in $[t_{l},t_{r}]$.

  4. For any $G\in\mathcal{G}^{+}_{\delta}\setminus\widetilde{\mathcal{G}}_{\delta}$, all the entries of $q_{c}(G)$ are outside $[t_{l},t_{r}]$.

Proof.

See Section E.5.3 for a detailed proof. ∎

Remark E.2.

Our proof differs from (Yun et al., 2020) in one aspect: the original (Yun et al., 2020, Lemma 6) does not include positional encoding (E.4). We add (E.4) to the input of the attention layer.

E.2.3 Map to the Desired Output by Modified Feed-forward Layers

Next, we show that a series of feed-forward layers maps the output of the modified attention layers $f_{\mathcal{T},c2}$ to the desired output of the function $f_{\delta^{\star}}$.

Lemma E.5 (Lemma 7 of (Yun et al., 2020)).

There exists a function $f_{\mathcal{T},c3}:\mathbb{R}^{d\times L}\to\mathbb{R}^{d\times L}$ composed of $\mathcal{O}(L(1/\delta)^{dL}/L!)$ modified feed-forward layers, such that

$$f_{\mathcal{T},c3}\circ f_{\mathcal{T},c2}(G)=\begin{cases}A_{G}&\text{ if }G\in\widetilde{\mathcal{G}}_{\delta},\\ \mathbf{0}_{d\times L}&\text{ if }G\in\mathcal{G}^{+}_{\delta}\setminus\widetilde{\mathcal{G}}_{\delta}.\end{cases}$$

Proof.

See Section E.5.4 for a detailed proof. ∎

From the above conclusions, we obtain the following lemma on the number of layers required in the modified transformer.

Lemma E.6 ((Yun et al., 2020)).

From the proof of Lemma E.3, to achieve an approximation error of $\mathcal{O}(\delta^{d/2})$ with the modified transformer, we need $\mathcal{O}(\delta^{-1})$ modified feed-forward layers in $f_{\mathcal{T},c1}$, $\mathcal{O}(\delta^{-d})$ modified self-attention layers in $f_{\mathcal{T},c2}$, and $\mathcal{O}(\delta^{-dL})$ modified feed-forward layers in $f_{\mathcal{T},c3}$.

Proof.

By the proof of Lemma E.3, we complete the proof. ∎
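To give a feel for how these counts scale, the following sketch evaluates the layer counts stated in Sections E.2.1 to E.2.3 for concrete values of $d$, $L$, and $1/\delta$. It is illustrative only: `layer_counts` and `cells_per_axis` (standing in for $1/\delta$) are hypothetical names, and the third entry evaluates just the expression inside the $\mathcal{O}(\cdot)$ of Lemma E.5 with the hidden constant assumed to be one.

```python
import math

def layer_counts(d, L, cells_per_axis):
    """Layer counts of the three stages of the modified transformer.

    quantization_ff:        d / delta + d layers in f_{T,c1} (Section E.2.1)
    contextual_attention:   delta^{-d} + 1 layers in f_{T,c2} (Lemma E.4)
    value_ff_leading_term:  L * (1/delta)^{dL} / L!, the expression inside
                            the O(.) bound for f_{T,c3} (Lemma E.5); the
                            hidden constant is assumed to be 1 here.
    """
    n = cells_per_axis
    return {
        "quantization_ff": d * n + d,
        "contextual_attention": n ** d + 1,
        "value_ff_leading_term": L * n ** (d * L) / math.factorial(L),
    }

small = layer_counts(d=2, L=2, cells_per_axis=4)
longer = layer_counts(d=2, L=3, cells_per_axis=4)
```

Increasing the sequence length $L$ leaves the first two counts unchanged and grows only the third, which matches the $\delta^{-dL}$ dependence stated in Lemma E.6.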

E.3 Standard Transformers Approximate Modified Transformers

In this subsection, we show that standard neural network layers are able to approximate the modified self-attention layers and the modified feed-forward layers (Definition E.3). We have the following Lemma E.7.

Lemma E.7 (Lemma 9 of (Yun et al., 2020)).

For each $f_{\mathcal{T},c}\in\bar{\mathcal{T}}_{p}^{2,1,1}$ and any $\epsilon>0$, there exists $f_{\mathcal{T}}\in\mathcal{T}_{p}^{2,1,4}$ such that $\norm{f_{\mathcal{T}}-f_{\mathcal{T},c}}_{L^{2}}\leq\epsilon/3$.

Proof.

See Section E.5.5 for a detailed proof. ∎

E.4 All Together: Standard Transformers Approximate Compact-Supported Continuous Functions

Combining Lemmas E.2, E.3 and E.7 proves Lemma E.1. Furthermore, to achieve the approximation error $\epsilon$ in Lemma E.1, we take $\delta=\mathcal{O}(\epsilon^{2/d})$ in Lemma E.3.
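As a rough illustration of how this choice of $\delta$ propagates to depth, the sketch below evaluates the three layer counts quoted above for a target error $\epsilon$. It is only an order-of-magnitude aid: every hidden constant in the $\mathcal{O}(\cdot)$ bounds is set to 1 (an assumption), and the function name `layer_counts` is hypothetical.

```python
def layer_counts(eps, d, L):
    """Order-of-magnitude layer counts, with every hidden O(.) constant set to 1."""
    delta = eps ** (2.0 / d)  # delta = eps^{2/d}, so that delta^{d/2} = eps
    return {
        "delta": delta,
        "ff_layers_c1": delta ** -1,        # O(delta^{-1}) feed-forward layers
        "attn_layers_c2": delta ** -d,      # O(delta^{-d}) = O(eps^{-2}) attention layers
        "ff_layers_c3": delta ** (-d * L),  # O(delta^{-dL}) = O(eps^{-2L}) feed-forward layers
    }


if __name__ == "__main__":
    for name, value in layer_counts(eps=0.5, d=4, L=3).items():
        print(f"{name}: {value:.4g}")
```

Under these assumptions the attention depth scales as $\epsilon^{-2}$ and the final feed-forward depth as $\epsilon^{-2L}$, independently of $d$ once $\delta$ is tied to $\epsilon$.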

E.5 Supplementary Proofs

Here we first present two preliminaries, the selective shift operation and the bijective column ID mapping, in Section E.5.1. We then prove Lemma E.2 in Section E.5.2, Lemma E.4 in Section E.5.3, Lemma E.5 in Section E.5.4, and Lemma E.7 in Section E.5.5.

E.5.1 Preliminaries

We define two preliminaries: the selective shift operation and the bijective column ID mapping.

Selective Shift Operation.

This operation refers to shifting certain entries of the input selectively.

To achieve this, we consider the following function $\xi(\cdot;\cdot):\mathbb{R}^{d\times L}\rightarrow\mathbb{R}^{d\times L}$:

$$\xi(X;b_Q)=e_1u^\top X\sigma_H\left[(u^\top X)^\top(u^\top X-b_Q\mathds{1}_n^\top)\right], \tag{E.5}$$

where $X\in\mathbb{R}^{d\times L}$, $e_1=(1,0,0,\cdots,0)^\top\in\mathbb{R}^d$, $b_Q\in\mathbb{R}$, and $u\in\mathbb{R}^d$ is a vector to be determined.

To see the output, we consider the $j$-th column of $u^\top X\sigma_H\left[(u^\top X)^\top(u^\top X-b_Q\mathds{1}_n^\top)\right]$:

  • If $u^\top X_{:,j}>b_Q$, it calculates the $\mathop{\mathrm{argmax}}$ of $u^\top X$;

  • If $u^\top X_{:,j}<b_Q$, it calculates the $\mathop{\mathrm{argmin}}$ of $u^\top X$.

With $e_1$, all rows of $\xi(X;b_Q)$ except the first row are zero. We consider the $j$-th entry of the first row of $\xi(X;b_Q)$, denoted by $\xi(X;b_Q)_{1,j}$. Then for all $j\in[L]$, we have

$$\xi(X;b_Q)_{1,j}=u^\top X\sigma_H\left[(u^\top X)^\top(u^\top X_{:,j}-b_Q)\right]=\begin{cases}\max_k u^\top X_{:,k} & \text{if } u^\top X_{:,j}>b_Q,\\ \min_k u^\top X_{:,k} & \text{if } u^\top X_{:,j}<b_Q.\end{cases}$$

From this observation, we define a function parametrized by $b_Q$ and $b'_Q$, where $b_Q<b'_Q$:

$$\xi(X;b_Q,b'_Q):=\xi(X;b_Q)-\xi(X;b'_Q). \tag{E.6}$$

Then we have

$$\xi(X;b_Q,b'_Q)_{1,j}=\begin{cases}\max_k u^\top X_{:,k}-\min_k u^\top X_{:,k} & \text{if } b_Q<u^\top X_{:,j}<b'_Q,\\ 0 & \text{otherwise}.\end{cases}$$

We define an attention layer of the form $X\rightarrow X+\xi(X;b_Q,b'_Q)$. For any column $X_{:,j}$, if $b_Q<u^\top X_{:,j}<b'_Q$, its first coordinate $X_{1,j}$ is shifted up by $\max_k u^\top X_{:,k}-\min_k u^\top X_{:,k}$, while all the other coordinates stay untouched. We call this the selective shift operation, because we can choose $b_Q$ and $b'_Q$ to shift certain entries of the input selectively.
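The selective shift can be checked numerically on a small input. The following is a minimal sketch rather than the construction used in the proof: the helper names are hypothetical, exact rationals are used to avoid rounding at the thresholds, and the hardmax $\sigma_H$ is assumed to average over ties (which only matters in the boundary case $u^\top X_{:,j}=b_Q$ that the case analysis above leaves out).

```python
from fractions import Fraction


def hardmax_column(scores):
    """Weights that put all mass on the largest scores (ties share the mass equally)."""
    top = max(scores)
    winners = [1 if s == top else 0 for s in scores]
    total = sum(winners)
    return [Fraction(w, total) for w in winners]


def xi_first_row(X, u, b_q):
    """First row of xi(X; b_Q) from (E.5); every other row of xi is zero."""
    d, L = len(X), len(X[0])
    s = [sum(u[i] * X[i][k] for i in range(d)) for k in range(L)]  # s_k = u^T X_{:,k}
    row = []
    for j in range(L):
        # Column j of (u^T X)^T (u^T X - b_Q 1^T): entry k equals s_k * (s_j - b_Q).
        scores = [s[k] * (s[j] - b_q) for k in range(L)]
        weights = hardmax_column(scores)
        row.append(sum(w * s_k for w, s_k in zip(weights, s)))
    return row


def selective_shift(X, u, b_q, b_q_prime):
    """Apply X -> X + xi(X; b_Q, b'_Q), where xi(X; b_Q, b'_Q) = xi(X; b_Q) - xi(X; b'_Q)."""
    low = xi_first_row(X, u, b_q)
    high = xi_first_row(X, u, b_q_prime)
    shifted = [list(row) for row in X]
    shifted[0] = [x + a - b for x, a, b in zip(X[0], low, high)]
    return shifted


if __name__ == "__main__":
    X = [[1, 2, 3], [0, 0, 0]]  # d = 2, L = 3, so u^T X = (1, 2, 3) for u = (1, 2)
    print(selective_shift(X, (1, 2), Fraction(3, 2), Fraction(5, 2)))
```

In this toy input only the middle column has $u^\top X_{:,j}$ inside $(b_Q,b'_Q)=(3/2,5/2)$, so only its first coordinate moves, by $\max_k u^\top X_{:,k}-\min_k u^\top X_{:,k}=2$.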

Bijective Column ID Mapping.

We consider the input $G\in\mathcal{G}^+_\delta$ (Definition E.7). We use

$$J=L+3L\delta^{-dL},\quad\text{and}\quad u=(1,\delta^{-1},\delta^{-2},\dots,\delta^{-d+1}). \tag{E.7}$$

For any $j\in[L]$, we have the following two conclusions:

  • If $G_{i,j}\geq 0$ for all $i\in[d]$, i.e., $G_{:,j}\in[j-1:\delta:j-\delta]^d$, then we have

    $$u^\top G_{:,j}\in\left[\delta_j:\delta:\delta_j+\delta^{-d+1}-\delta\right],\quad\text{where}\quad\delta_j=(j-1)\cdot\left(\frac{\delta-\delta^{-d+1}}{\delta-1}\right). \tag{E.8}$$

    The map $G_{:,j}\rightarrow u^\top G_{:,j}$ from $[j-1:\delta:j-\delta]^d$ to $\left[\delta_j:\delta:\delta_j+\delta^{-d+1}-\delta\right]$ is a bijection.

  • If there exists $i\in[d]$ such that $G_{i,j}=-J+j$, then

    $$u^\top G_{:,j}\leq-3L\delta^{-dL}+(j-1)\cdot\left(\frac{\delta^{-d+1}-\delta}{1-\delta}\right)+\delta^{-d+1}<0. \tag{E.9}$$

We say that $u^\top G_{:,j}$ gives the “column ID” for each possible value of $G_{:,j}\in[j-1:\delta:j-\delta]^d$.
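The column ID map can be enumerated exhaustively for small parameters. The sketch below assumes that $1/\delta$ is an integer $m\geq 2$, so that the grid $[j-1:\delta:j-\delta]$ has exactly $m$ levels, and uses exact rational arithmetic. The function names are hypothetical.

```python
from fractions import Fraction
from itertools import product


def column_ids(j, d, m):
    """Map every column in [j-1 : delta : j-delta]^d (delta = 1/m) to its ID u^T G_{:,j}."""
    delta = Fraction(1, m)
    u = [delta ** (-i) for i in range(d)]           # u = (1, delta^{-1}, ..., delta^{-d+1})
    levels = [j - 1 + a * delta for a in range(m)]  # j-1, j-1+delta, ..., j-delta
    return {col: sum(ui * gi for ui, gi in zip(u, col)) for col in product(levels, repeat=d)}


def expected_ids(j, d, m):
    """The target set [delta_j : delta : delta_j + delta^{-d+1} - delta] from (E.8)."""
    delta = Fraction(1, m)
    delta_j = (j - 1) * (delta - delta ** (-d + 1)) / (delta - 1)
    return [delta_j + t * delta for t in range(m ** d)]


if __name__ == "__main__":
    ids = column_ids(j=2, d=3, m=4)
    print(len(ids), len(set(ids.values())))                       # columns vs. distinct IDs
    print(sorted(ids.values()) == expected_ids(j=2, d=3, m=4))    # IDs fill the range in (E.8)
```

The map behaves like reading the $d$ grid indices of a column as the digits of a base-$m$ number, which is why distinct columns receive distinct IDs spaced exactly $\delta$ apart.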

Remark E.3 (Illustration of the Bijection Property).

For the bijection property, we give the following illustration. Let $G_{:j}=(g_{1j},g_{2j},\cdots,g_{dj})^\top$ and $\bar{G}_{:j}=(\bar{g}_{1j},\bar{g}_{2j},\cdots,\bar{g}_{dj})^\top$. Suppose, for contradiction, that $u^\top G_{:j}=u^\top\bar{G}_{:j}$ and $G_{:j}\neq\bar{G}_{:j}$. Then we deduce

$$(g_{1j}-\bar{g}_{1j})+\delta^{-1}(g_{2j}-\bar{g}_{2j})+\cdots+\delta^{-d+1}(g_{dj}-\bar{g}_{dj})=0. \tag{E.10}$$

Because $G_{:j}\neq\bar{G}_{:j}$, there exists an index $k$ ($k\leq d$) such that $g_{kj}\neq\bar{g}_{kj}$ and $g_{ij}=\bar{g}_{ij}$ for all $i>k$. Since distinct grid values differ by at least $\delta$, we have

$$\absolutevalue{\delta^{-k+1}(g_{kj}-\bar{g}_{kj})}\geq\delta^{-k+2}.$$

However,

$$\begin{aligned}
&\absolutevalue{(g_{1j}-\bar{g}_{1j})+\cdots+\delta^{-k+2}(g_{k-1,j}-\bar{g}_{k-1,j})}\\
\leq\;&\absolutevalue{g_{1j}-\bar{g}_{1j}}+\cdots+\absolutevalue{\delta^{-k+2}(g_{k-1,j}-\bar{g}_{k-1,j})}\\
\leq\;&(1-\delta)+\cdots+\delta^{-k+2}(1-\delta)\\
<\;&\delta^{-k+2}.
\end{aligned}$$

Hence the $k$-th term cannot be cancelled by the lower-order terms, which contradicts (E.10). This proves the bijection property.
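For completeness, the last strict inequality in the chain above is a geometric-series computation. The following worked step assumes $0<\delta<1$ and $k\geq 2$ (for $k=1$ the lower-order sum is empty), and uses that two entries of the grid $[j-1:\delta:j-\delta]$ differ by at most $1-\delta$:

```latex
\begin{align*}
(1-\delta)\sum_{i=0}^{k-2}\delta^{-i}
  &= (1-\delta)\cdot\frac{\delta^{-(k-1)}-1}{\delta^{-1}-1}
   = \delta\left(\delta^{-k+1}-1\right) \\
  &= \delta^{-k+2}-\delta
   < \delta^{-k+2}.
\end{align*}
```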

E.5.2 Proof of Lemma E.2
Proof of Lemma E.2.

We restate the proof from (Yun et al., 2020) for completeness.

Since $f$ is continuous with compact support, $f$ is uniformly continuous.

Because $\norm{\cdot}_\infty$ is equivalent to $\norm{\cdot}_F$ when the number of entries is finite, the definition of uniform continuity gives the following.

For any $\epsilon>0$, there exists a $\delta^\star>0$ such that for any $X,Y\in\mathbb{R}^{d\times L}$ with $\norm{X-Y}_\infty<\delta^\star$, we have $\norm{f(X)-f(Y)}_F<\epsilon/3$.

We then proceed as follows, using Definitions E.1 and E.2:

  • We create a grid $\mathcal{G}_{\delta^\star}$ by choosing grid width $\delta^\star$, and a cube $\mathcal{S}_G$ with respect to each $G\in\mathcal{G}_{\delta^\star}$.

  • For any grid point $G\in\mathcal{G}_{\delta^\star}$, we define $C_G\in\mathcal{S}_G$ to be the center point of the cube $\mathcal{S}_G$.

  • We define a piece-wise constant function $f_{\delta^\star}(X)=\sum\nolimits_{G\in\mathcal{G}_{\delta^\star}}f(C_G)\mathds{1}\{X\in\mathcal{S}_G\}$.

Then for any $X\in\mathcal{S}_G$, we have $\norm{X-C_G}_\infty<\delta^\star$. By uniform continuity, we derive

$$\norm{f(X)-f_{\delta^\star}(X)}_F=\norm{f(X)-f(C_G)}_F<\epsilon/3.$$

This implies that $\norm{f-f_{\delta^\star}}_{L^2}<\epsilon/3$ and completes the proof. ∎
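The grid construction in this proof can be illustrated in one dimension. The sketch below is a toy analogue under stated assumptions, not the $d\times L$ setting of the lemma: the domain is $[0,1]$, the test function $f(x)=x^2$ is an arbitrary choice, the error is only sampled on a finite set of points, and the helper names are hypothetical.

```python
import math


def piecewise_constant(f, delta, upper=1.0):
    """Return f_delta on [0, upper]: constant on each cell of width delta, equal to f at the cell center."""
    n_cells = round(upper / delta)

    def f_delta(x):
        cell = min(math.floor(x / delta), n_cells - 1)  # the right endpoint joins the last cell
        return f((cell + 0.5) * delta)

    return f_delta


def sup_error(f, f_delta, samples=1000, upper=1.0):
    """Largest sampled gap |f(x) - f_delta(x)| over evenly spaced points in [0, upper]."""
    points = [upper * i / samples for i in range(samples + 1)]
    return max(abs(f(x) - f_delta(x)) for x in points)


if __name__ == "__main__":
    def f(x):
        return x * x

    for delta in (0.25, 0.125, 0.0625):
        print(delta, sup_error(f, piecewise_constant(f, delta)))
```

Halving the grid width roughly halves the sampled error here, since $f$ is Lipschitz on $[0,1]$; the proof only needs the weaker statement that some grid width $\delta^\star$ brings the error below $\epsilon/3$.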

E.5.3 Proof of Lemma E.4

We give the proof of Lemma E.4 by constructing the network to satisfy the requirements.

Proof of Lemma E.4.

Recalling the selective shift operation in Section E.5.1, the construction consists of two steps:

  • Step 1: For each $j\in[L]$, we stack $\delta^{-d}$ attention layers, each with attention part

    $$\delta^{-d}\xi(\cdot;g-\delta/2,g+\delta/2), \tag{E.11}$$

    for $g\in[\delta_j:\delta:\delta_j+\delta^{-d+1}-\delta]$ (see (E.8)) in increasing order. The total number of layers is $L\delta^{-d}$. These layers map $G\in\widetilde{\mathcal{G}}_\delta$ to $L$ distinct entries, as required by Property 1 of Lemma E.4.

  • Step 2: We add an extra single-head attention layer with attention part

    $$L\delta^{-(L+1)d-1}\xi(\cdot;0). \tag{E.12}$$

    This layer performs a global shift and maps different $G\in\widetilde{\mathcal{G}}_\delta$ to unique elements, as required by Property 2 of Lemma E.4.

The two operations together map $\widetilde{\mathcal{G}}_\delta$ and $\mathcal{G}_\delta^+\setminus\widetilde{\mathcal{G}}_\delta$ to different sets, as required by Properties 3-4 of Lemma E.4. The bounds $t_l$ and $t_r$ are computed afterwards.
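Step 1 can be traced on a tiny input. The sketch below runs only the sweep for $j=1$ (where $\delta_1=0$), uses the closed-form case expression for $\xi(X;b_Q,b'_Q)$ from Section E.5.1 instead of an actual attention layer, and assumes $1/\delta=m$ is an integer. All function names are hypothetical.

```python
from fractions import Fraction


def shift_layer(G, u, g, delta, scale):
    """One layer X -> X + scale * xi(X; g - delta/2, g + delta/2), via the closed form of xi."""
    d, L = len(G), len(G[0])
    ids = [sum(u[i] * G[i][k] for i in range(d)) for k in range(L)]  # column IDs u^T X_{:,k}
    spread = max(ids) - min(ids)
    out = [list(row) for row in G]
    for j in range(L):
        if g - delta / 2 < ids[j] < g + delta / 2:
            out[0][j] += scale * spread  # only the first coordinate of the column moves
    return out


def first_sweep(G, d, m):
    """Apply the delta^{-d} layers of Step 1 for j = 1, with g increasing over the grid."""
    delta = Fraction(1, m)
    u = [delta ** (-i) for i in range(d)]
    scale = delta ** (-d)
    for t in range(m ** d):  # g = 0, delta, ..., delta^{-d+1} - delta (delta_1 = 0)
        G = shift_layer(G, u, t * delta, delta, scale)
    return G


if __name__ == "__main__":
    # d = 2, L = 2, delta = 1/2: column 1 lies in {0, 1/2}^2 and column 2 in {1, 3/2}^2.
    G = [[Fraction(1, 2), 1], [0, Fraction(3, 2)]]
    print(first_sweep(G, d=2, m=2))
```

In this example the column IDs are $g_1=1/2$ and $g_2=4$, so exactly one layer of the sweep fires and moves $G_{1,1}$ by $\delta^{-d}(g_2-g_1)=14$, leaving the second column unchanged.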

We now give the detailed proof by showing the effect of the two steps and verifying the four properties of Lemma E.4. To do so, we divide $\mathcal{G}^+_\delta$ into two categories:

  • Category 1: $G\in\widetilde{\mathcal{G}}_\delta$, i.e., all entries of the point $G$ are between $0$ and $L-\delta$.

  • Category 2: $G\in\mathcal{G}^+_\delta\setminus\widetilde{\mathcal{G}}_\delta$, i.e., the point $G$ has at least one entry equal to $-J$.

Let $u=(1,\delta^{-1},\delta^{-2},\ldots,\delta^{-d+1})$, and recall from (E.8) that $\delta_j=(j-1)(\delta-\delta^{-d+1})/(\delta-1)$ for any $j\in[L]$.

Category 1.

We denote $g_j\coloneqq u^\top G_{:,j}$, so that $g_1<g_2<\cdots<g_L$. The Step 1 layers sweep the sets $[\delta_j:\delta:\delta_j+\delta^{-d+1}-\delta]$, $j\in[L]$, in increasing order ($\delta^{-d}$ layers per $j$) and apply the selective shift operation to each element of these sets. This means that the selective shift operation is applied to $g_1$ first, then $g_2$, then $g_3$, and so on, regardless of the specific values of the $g_j$'s.

  • First Shift Operation. In the first selective shift operation, with $g$ going through $[\delta_1:\delta:\delta_1+\delta^{-d+1}-\delta]$, the $(1,1)$-th entry of $G$ (i.e., $G_{1,1}$) is shifted by the operation, while the other entries are left untouched. The updated value $\widetilde{G}_{1,1}$ is

    $$\widetilde{G}_{1,1}=G_{1,1}+\delta^{-d}\left[\max_k\left(u^\top G_{:,k}\right)-\min_k\left(u^\top G_{:,k}\right)\right]=G_{1,1}+\delta^{-d}(g_L-g_1).$$

    Therefore, after the operation, the output of the layer is

    $$\begin{pmatrix}\widetilde{G}_{:,1}&G_{:,2}&\cdots&G_{:,L}\end{pmatrix}.$$

    We have

    \begin{align*}
    \widetilde{g}_1 &\coloneqq u^\top \widetilde{G}_{:,1} \\
    &= \widetilde{G}_{1,1} + \sum_{i=2}^{d} \delta^{-i+1} G_{i,1} \\
    &= G_{1,1} + \delta^{-d} (g_L - g_1) + \sum_{i=2}^{d} \delta^{-i+1} G_{i,1} \\
    &= g_1 + \delta^{-d} (g_L - g_1).
    \end{align*}

    Then we deduce $g_L < \widetilde{g}_1$, because

    \begin{align*}
    \widetilde{g}_1 &= g_1 + \delta^{-d} (g_L - g_1) \\
    &\geq 0 + \delta^{-d} \left[ (L-1) \cdot \frac{\delta - \delta^{-d+1}}{\delta - 1} - \delta^{-d+1} + \delta \right] && \text{(By (E.8))} \\
    &= \delta^{-d} \left[ (L-1) \frac{\delta}{1-\delta} + \delta + (L-1) \frac{\delta^{-d+1}}{1-\delta} - \delta^{-d+1} \right] \\
    &\geq \delta^{-d} \cdot \left( (L-1) \frac{\delta}{1-\delta} + \delta \right) \\
    &= (L-1) \frac{\delta^{-d+1}}{1-\delta} + \delta^{-d+1} \\
    &> g_L. && \text{(By $\delta < 1$ and (E.8))}
    \end{align*}

    Thus, after updating,

    \begin{align*}
    \max u^\top \begin{pmatrix} \widetilde{G}_{:,1} & G_{:,2} & \cdots & G_{:,L} \end{pmatrix} = \max \{ \widetilde{g}_1, g_2, \dots, g_L \} = \widetilde{g}_1,
    \end{align*}

    and the new minimum is $g_2$.

  • Second Shift Operation. In the second selective shift operation, with $g$ going through $[\delta_2 : \delta : \delta_2 + \delta^{-d+1} - \delta]$, the $(1,2)$-th entry of $G$ (i.e., $G_{1,2}$) is shifted by the operation, while the other entries are left untouched. The updated value $\widetilde{G}_{1,2}$ is

    \begin{align*}
    \widetilde{G}_{1,2} &= G_{1,2} + \delta^{-d} (\widetilde{g}_1 - g_2) \\
    &= G_{1,2} + \delta^{-d} (g_1 - g_2) + \delta^{-2d} (g_L - g_1).
    \end{align*}

    Therefore, after the operation, the output of the layer is

    \begin{align*}
    \begin{pmatrix} \widetilde{G}_{:,1} & \widetilde{G}_{:,2} & \cdots & G_{:,L} \end{pmatrix}.
    \end{align*}

    We have

    \begin{align*}
    \widetilde{g}_2 &\coloneqq u^\top \widetilde{G}_{:,2} \\
    &= g_2 + \delta^{-d} (g_1 - g_2) + \delta^{-2d} (g_L - g_1).
    \end{align*}

    Then we deduce $\widetilde{g}_1 < \widetilde{g}_2$, because

    \begin{align*}
    & g_1 + \delta^{-d} (g_L - g_1) < g_2 + \delta^{-d} (g_1 - g_2) + \delta^{-2d} (g_L - g_1) \\
    \iff\ & (\delta^{-d} - 1)(g_2 - g_1) < \delta^{-d} (\delta^{-d} - 1)(g_L - g_1). && \text{(By $\delta^{-d} > 1$ and $g_L > g_2$)}
    \end{align*}

    Thus, after updating,

    \begin{align*}
    \max u^\top \begin{pmatrix} \widetilde{G}_{:,1} & \widetilde{G}_{:,2} & \cdots & G_{:,L} \end{pmatrix} = \max \{ \widetilde{g}_1, \widetilde{g}_2, \dots, g_L \} = \widetilde{g}_2,
    \end{align*}

    and the new minimum is $g_3$.

  • Repeating The Process. By repeating this process, we show that the $j$-th shift operation shifts $G_{1,j}$ by $\delta^{-d} (\widetilde{g}_{j-1} - g_j)$, and we have

    \begin{align*}
    \widetilde{g}_j &\coloneqq u^\top \widetilde{G}_{:,j} \\
    &= g_j + \sum_{k=1}^{j-1} \delta^{-kd} (g_{j-k} - g_{j-k+1}) + \delta^{-jd} (g_L - g_1).
    \end{align*}

    We deduce that $\widetilde{g}_{j-1} < \widetilde{g}_j$ holds for all $2 \leq j \leq L$, because

    \begin{align*}
    & \widetilde{g}_{j-1} < \widetilde{g}_j \\
    \iff\ & g_{j-1} + \sum_{k=2}^{j-1} \delta^{-kd+d} (g_{j-k} - g_{j-k+1}) + \delta^{-(j-1)d} (g_L - g_1) \\
    & < g_j + \sum_{k=1}^{j-1} \delta^{-kd} (g_{j-k} - g_{j-k+1}) + \delta^{-jd} (g_L - g_1) \\
    \iff\ & \sum_{k=1}^{j-1} \delta^{-kd+d} (\delta^{-d} - 1)(g_{j-k+1} - g_{j-k}) < \delta^{-(j-1)d} (\delta^{-d} - 1)(g_L - g_1),
    \end{align*}

    where the last inequality holds because

    \begin{align*}
    & \sum_{k=1}^{j-1} \delta^{-kd+d} (g_{j-k+1} - g_{j-k}) \\
    <\ & \delta^{-(j-1)d} \sum_{k=1}^{j-1} (g_{j-k+1} - g_{j-k}) \\
    <\ & \delta^{-(j-1)d} (g_L - g_1).
    \end{align*}

    Therefore, after the $j$-th selective shift operation, $\widetilde{g}_j$ is the new maximum among $\{ \widetilde{g}_1, \dots, \widetilde{g}_j, g_{j+1}, \dots, g_L \}$ and $g_{j+1}$ is the new minimum.

  • After $L$ Shift Operations. After all $L$ shift operations, the input $G$ is mapped to a new point $\widetilde{G}$, where $u^\top \widetilde{G} = \begin{pmatrix} \widetilde{g}_1 & \widetilde{g}_2 & \dots & \widetilde{g}_L \end{pmatrix}$ and $\widetilde{g}_1 < \widetilde{g}_2 < \dots < \widetilde{g}_L$. For the lower and upper bounds on $\widetilde{g}_L$, we have the following lemma.

    Lemma E.8 (Lemma 10 of Yun et al. (2020)).

    $\widetilde{g}_L = u^\top \widetilde{G}_{:,L}$ satisfies the following bounds:

    \begin{align*}
    \delta^{-(L-1)d+1} (\delta^{-d} - 1) \leq \widetilde{g}_L \leq L \delta^{-(L+1)d}.
    \end{align*}

    Also, the mapping from $\begin{pmatrix} g_1 & g_2 & \cdots & g_L \end{pmatrix}$ to $\widetilde{g}_L$ is a one-to-one mapping.

  • Global Shifting by the Last Layer. We note that after the above $L$ shift operations, there is another attention layer with attention part $L \delta^{-(L+1)d-1} \xi(\cdot\,; 0)$. Since $0 < \widetilde{g}_1 < \cdots < \widetilde{g}_L$, this layer adds the following quantity to each entry in the first row of $\widetilde{G}$:

    \begin{align*}
    L \delta^{-(L+1)d-1} \max_k u^\top \widetilde{G}_{:,k} = L \delta^{-(L+1)d-1} \widetilde{g}_L.
    \end{align*}

    The output of this layer is defined to be the function $f_{\mathcal{T},c2}(G)$.
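
The sequence of selective shifts above can be checked numerically. The following Python sketch is not part of the construction: it replaces the attention layers by their net effect on the scalars $g_j$, assuming each shift moves the current minimum up by $\delta^{-d}$ times the current gap between maximum and minimum, as described in the bullets above. The function name `selective_shifts` and the example values ($\delta = 1/2$, $d = 2$, $L = 3$) are hypothetical choices made for illustration; exact rational arithmetic is used to avoid floating-point error.

```python
from fractions import Fraction


def selective_shifts(g, delta, d):
    """Apply len(g) selective shifts to the scalars g_1 < ... < g_L.

    Illustrative assumption: each step moves the current minimum up by
    delta^{-d} * (current max - current min), the net effect described in the
    text. `delta` must be a Fraction in (0, 1) and `d` a positive integer.
    """
    scale = 1 / delta**d  # delta^{-d} as an exact Fraction
    current = list(g)
    for _ in range(len(current)):
        lo, hi = min(current), max(current)
        current[current.index(lo)] = lo + scale * (hi - lo)
    return current


# Hypothetical example with delta = 1/2 and d = 2, so each g_j is a multiple
# of delta in [0, delta^{-d+1} - delta] = [0, 3/2].
g = [Fraction(0), Fraction(1, 2), Fraction(3, 2)]
g_tilde = selective_shifts(g, Fraction(1, 2), 2)
print([str(x) for x in g_tilde])  # → ['6', '45/2', '171/2']
assert g_tilde[0] > g[-1]                                # g_L < g~_1
assert all(a < b for a, b in zip(g_tilde, g_tilde[1:]))  # g~_1 < ... < g~_L
```

On this input the $j$-th step picks $g_j$ as the current minimum for $j = 1, 2, 3$, matching the order of operations stated in the text.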

Now, in summary, for any $G \in \widetilde{\mathcal{G}}_\delta$, $i \in [d]$, and $j \in [L]$, we have

\begin{align*}
f_{\mathcal{T},c2}(G)_{i,j} &= \begin{cases} G_{1,j} + \delta_j^{+} & \text{if } i = 1, \\ G_{i,j} & \text{if } 2 \leq i \leq d, \end{cases} \\
\text{where}\quad \delta_j^{+} &= \sum_{k=1}^{j-1} \delta^{-kd} (g_{j-k} - g_{j-k+1}) + \delta^{-jd} (g_L - g_1) + L \delta^{-(L+1)d-1} \widetilde{g}_L.
\end{align*}

For any $G \in \widetilde{\mathcal{G}}_\delta$ and $j \in [L]$,

\begin{align*}
u^\top f_{\mathcal{T},c2}(G)_{:,j} = \widetilde{g}_j + L \delta^{-(L+1)d-1} \widetilde{g}_L.
\end{align*}
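
As a sanity check on the summary above, the closed-form expression for $\delta_j^{+}$ can be compared against the step-by-step recursion $\widetilde{g}_j = g_j + \delta^{-d} (\widetilde{g}_{j-1} - g_j)$, where the convention $\widetilde{g}_0 \coloneqq g_L$ is introduced here only so that the first shift fits the same pattern. The Python sketch below is illustrative: the helper names `column_ids` and `recursive_g_tilde` and the example values are assumptions, not notation from the proof.

```python
from fractions import Fraction


def column_ids(g, delta, d):
    """Closed form: return ([g_j + delta_j^+ for j = 1..L], g~_L)."""
    L = len(g)
    inv = 1 / delta  # delta^{-1} as an exact Fraction
    g_tilde = []
    for j in range(1, L + 1):  # 1-based j as in the text; g_j is g[j - 1]
        tail = sum(inv ** (k * d) * (g[j - k - 1] - g[j - k]) for k in range(1, j))
        g_tilde.append(g[j - 1] + tail + inv ** (j * d) * (g[-1] - g[0]))
    global_shift = L * inv ** ((L + 1) * d + 1) * g_tilde[-1]
    return [gt + global_shift for gt in g_tilde], g_tilde[-1]


def recursive_g_tilde(g, delta, d):
    """Recursion g~_j = g_j + delta^{-d} (g~_{j-1} - g_j), with g~_0 := g_L."""
    scale = 1 / delta**d
    out, prev = [], g[-1]
    for gj in g:
        prev = gj + scale * (prev - gj)
        out.append(prev)
    return out


# Hypothetical example values: delta = 1/2, d = 2, L = 3.
delta, d = Fraction(1, 2), 2
g = [Fraction(0), Fraction(1, 2), Fraction(3, 2)]
ids, g_tilde_L = column_ids(g, delta, d)
shift = len(g) * (1 / delta) ** ((len(g) + 1) * d + 1) * g_tilde_L
# Closed form and recursion agree once the global shift is removed.
assert [x - shift for x in ids] == recursive_g_tilde(g, delta, d)
```

In this example the global shift equals $3 \cdot 2^{9} \cdot \tfrac{171}{2} = 131328$.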

Next, we check Properties 1, 2, and 3 of Lemma E.4.

  • Checking Property 1 of Lemma E.4. Given any $G \in \widetilde{\mathcal{G}}_\delta$, we have already proved that

    \begin{align*}
    \widetilde{g}_1 < \widetilde{g}_2 < \dots < \widetilde{g}_L,
    \end{align*}

    so they are all distinct.

  • Checking Property 2 of Lemma E.4. Note that the upper bound on $\widetilde{g}_L$ from Lemma E.8 also holds for the other $\widetilde{g}_j$'s, so for all $j \in [L]$, we have

    \begin{align*}
    L \delta^{-(L+1)d-1} \widetilde{g}_L \leq u^\top f_{\mathcal{T},c2}(G)_{:,j} < L \delta^{-(L+1)d-1} \widetilde{g}_L + L \delta^{-(L+1)d}.
    \end{align*}

    Now, from Lemma E.8, two different $G, G' \in \widetilde{\mathcal{G}}_\delta$ map to different $\widetilde{g}_L$ and $\widetilde{g}'_L$, and they differ by at least $\delta$. After scaling by $L \delta^{-(L+1)d-1}$, the left endpoints below are therefore at least $L \delta^{-(L+1)d}$ apart, which is the common length of the intervals. This means that the two intervals

    \begin{align*}
    & [L \delta^{-(L+1)d-1} \widetilde{g}_L,\ L \delta^{-(L+1)d-1} \widetilde{g}_L + L \delta^{-(L+1)d}), \\
    & [L \delta^{-(L+1)d-1} \widetilde{g}'_L,\ L \delta^{-(L+1)d-1} \widetilde{g}'_L + L \delta^{-(L+1)d}),
    \end{align*}

    are guaranteed to be disjoint, so the entries of $u^\top f_{\mathcal{T},c2}(G)$ and $u^\top f_{\mathcal{T},c2}(G')$ are all distinct.

    This completes the proof that the map $f_{\mathcal{T},c2}(\cdot)$, which we constructed using $(1/\delta)^{d}+1$ attention layers, implements a contextual mapping on $\widetilde{\mathcal{G}}_{\delta}$.
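The disjointness of the two intervals can be checked in one line. The shorthand $c$ (common scale) and $w$ (common width) is introduced only for this remark and is not notation from the paper; the sketch assumes $|\widetilde{g}_{L} - \widetilde{g}'_{L}| \geq \delta$, as stated via Lemma E.8.

```latex
\begin{align*}
c \coloneqq L\delta^{-(L+1)d-1}, \qquad w \coloneqq L\delta^{-(L+1)d}
\quad\Longrightarrow\quad
\big| c\,\widetilde{g}_{L} - c\,\widetilde{g}'_{L} \big|
\;\geq\; c\,\delta
\;=\; L\delta^{-(L+1)d-1}\cdot\delta
\;=\; w .
\end{align*}
```

Two half-open intervals of width $w$ whose left endpoints are at least $w$ apart cannot overlap, which is the claimed disjointness.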

  • Checking Property 3 of Lemma E.4. With $u^{\top} f_{\mathcal{T},c2}(G)_{:,j} \in [L\delta^{-(L+1)d-1}\widetilde{g}_{L},\; L\delta^{-(L+1)d-1}\widetilde{g}_{L} + L\delta^{-(L+1)d})$ and Lemma E.8, we show that for any $G \in \widetilde{\mathcal{G}}_{\delta}$, we have

    \begin{align*}
    u^{\top} f_{\mathcal{T},c2}(G)_{:,j} & \geq L\delta^{-2(L+1)d}(\delta^{-d}-1), \\
    u^{\top} f_{\mathcal{T},c2}(G)_{:,j} & < L^{2}\delta^{-2(L+1)d-1} + L\delta^{-(L+1)d}.
    \end{align*}

    This proves that all $u^{\top} f_{\mathcal{T},c2}(G)_{:,j}$ are between $t_{l}$ and $t_{r}$, where

    \begin{align*}
    t_{l} & = L\delta^{-2(L+1)d}(\delta^{-d}-1), \\
    t_{r} & = L^{2}\delta^{-2(L+1)d-1} + L\delta^{-(L+1)d}.
    \end{align*}

Category 2. Now we check Property 4 of Lemma E.4. For the input points $G \in \mathcal{G}^{+}_{\delta} \setminus \widetilde{\mathcal{G}}_{\delta}$, note that the point $G$ has at least one entry equal to $-J+k$, $k \in [L-1]$. Let $g_{j} \coloneqq u^{\top} G_{:,j}$, and recall that whenever a column $G_{:,j}$ has an entry equal to $-J+k$, $k \in [L-1]$, we have $g_{j} < 0$. Without loss of generality, assume that $g_{1} < 0$.

Because the selective shift operation is applied only to elements of $[0:\delta:\delta_{L}+\delta^{-d+1}-\delta]$ and not to negative values, and because $\min_{k} u^{\top} G_{:,k} = g_{1} < 0$, the value $g_{1}$ is never shifted upwards and remains the minimum throughout.

  • All $g_{j}$'s Are Negative. When all $g_{j}$'s are negative, the selective shift operation never shifts the input $G$, thus $\widetilde{G} = G$. Recall that $u^{\top} \widetilde{G}_{:,j} < 0$ for all $j \in [L]$. The last layer with attention part $L\delta^{-(L+1)d-1}\xi(\cdot;0)$ adds $L\delta^{-(L+1)d-1}\min_{k} u^{\top} \widetilde{G}_{:,k} < 0$ to each entry in the first row of $\widetilde{G}$, making $\widetilde{G}$ remain negative. Therefore, $f_{\mathcal{T},c2}(G)$ satisfies $u^{\top} f_{\mathcal{T},c2}(G)_{:,j} < 0 < t_{l}$ for all $j \in [L]$.

  • Not All $g_{j}$'s Are Negative. Now consider the case where at least one $g_{j}$ is positive. Suppose that there are $k$ positive $g_{j}$'s, satisfying $g_{i_{1}} < g_{i_{2}} < \cdots < g_{i_{k}}$. Then the selective shift operation does not affect $g_{i}$ for $i \in [L] \setminus \{i_{1},\dots,i_{k}\}$, but it shifts $g_{i_{1}}$ by

    \begin{align*}
    & \delta^{-d}\Big(\max_{k} u^{\top} G_{:,k} - \min_{k} u^{\top} G_{:,k}\Big) \\
    \geq\; & \delta^{-d}\Big(3L\delta^{-dL} - (L-1)\frac{\delta^{-d+1}-\delta}{1-\delta} - \delta^{-d+1} + (i_{k}-1)\frac{\delta^{-d+1}-\delta}{1-\delta}\Big) && \text{(By (E.9))} \\
    =\; & \delta^{-d}\Big(3L\delta^{-dL} - \delta^{-d+1} - (L-i_{k})\frac{\delta^{-d+1}-\delta}{1-\delta}\Big) \\
    \geq\; & \delta^{-d} \cdot 2L\delta^{-dL} && \text{(By $\delta^{-1} \geq 2$)} \\
    =\; & 2L\delta^{-(L+1)d}.
    \end{align*}

    The next shift operations shift $g_{i_{2}},\dots,g_{i_{k}}$ by an even larger amount, so at the end of the first $L(1/\delta)^{d}$ layers, we have $L\delta^{-(L+1)d} \leq \widetilde{g}_{i_{1}} \leq \dots \leq \widetilde{g}_{i_{k}}$, while $\widetilde{g}_{j} < 0$ for all $j \in [L] \setminus \{i_{1},\dots,i_{k}\}$.

    Then, we shift $G$ by the last layer. The last layer with attention part $L\delta^{-(L+1)d-1}\xi(\cdot;0)$ acts differently for negative and positive $\widetilde{g}_{j}$'s. (i) For negative $\widetilde{g}_{j}$'s, it adds the following to $\widetilde{g}_{j}$, $j \in [L] \setminus \{i_{1},\dots,i_{k}\}$:

    \begin{align*}
    L\delta^{-(L+1)d-1}\min_{k} u^{\top} \widetilde{G}_{:,k} = L\delta^{-(L+1)d-1} g_{1} < 0.
    \end{align*}

    This term pushes them further to the negative side. (ii) For positive $\widetilde{g}_{i}$'s, it adds

    \begin{align*}
    L\delta^{-(L+1)d-1}\max_{k} u^{\top} \widetilde{G}_{:,k} = L\delta^{-(L+1)d-1}\widetilde{g}_{i_{k}} \geq 2L^{2}\delta^{-2(L+1)d-1}.
    \end{align*}

    Thus they are all greater than or equal to $2L^{2}\delta^{-2(L+1)d-1}$.

    Note that

    \begin{align*}
    2L^{2}\delta^{-2(L+1)d-1} > t_{r}, \quad \text{where} \quad t_{r} = L^{2}\delta^{-2(L+1)d-1} + L\delta^{-(L+1)d}.
    \end{align*}

    Hence the final output $f_{\mathcal{T},c2}(G)$ satisfies $u^{\top} f_{\mathcal{T},c2}(G)_{:,j} \notin [t_{l}, t_{r}]$ for all $j \in [L]$. This completes the verification of Property 4 of Lemma E.4.
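Two of the inequalities used in this case can be unpacked explicitly. The following is a sketch under the assumptions $L \geq 1$, $d \geq 1$, $1 \leq i_{k} \leq L$, and $0 < \delta \leq 1/2$ (the last being the stated condition $\delta^{-1} \geq 2$); it is offered as a reading aid, not as the authors' own argument.

```latex
% (a) The step marked "By \delta^{-1} \geq 2": the subtracted terms are at most L\delta^{-dL}.
\begin{align*}
\delta^{-d+1} + (L-i_{k})\frac{\delta^{-d+1}-\delta}{1-\delta}
&\leq \delta^{-d+1} + 2(L-1)\,\delta^{-d+1}
   && \Big(\tfrac{1}{1-\delta} \leq 2,\; L-i_{k} \leq L-1\Big) \\
&= (2L-1)\,\delta\cdot\delta^{-d}
 \;\leq\; \tfrac{2L-1}{2}\,\delta^{-d}
 \;<\; L\delta^{-d}
 \;\leq\; L\delta^{-dL}.
\end{align*}

% (b) The comparison with t_r.
\begin{align*}
2L^{2}\delta^{-2(L+1)d-1} - t_{r}
&= L^{2}\delta^{-2(L+1)d-1} - L\delta^{-(L+1)d} \\
&= L\delta^{-(L+1)d}\big(L\delta^{-(L+1)d-1} - 1\big) \;>\; 0,
\end{align*}
% since L\delta^{-(L+1)d-1} \geq \delta^{-1} \geq 2 > 1.
```

Part (a) shows that subtracting these terms from $3L\delta^{-dL}$ leaves at least $2L\delta^{-dL}$, and part (b) confirms that the shifted positive entries land strictly above $t_{r}$.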

In conclusion, we need $\mathcal{O}(L\delta^{-d})$ modified self-attention layers to obtain our approximation. This completes the proof. ∎

E.5.4 Proof of Lemma E.5
Proof of Lemma E.5.

We restate the proof from (Yun et al., 2020) for completeness.

Note that $|\mathcal{G}^{+}_{\delta}| = (1/\delta+1)^{dL} < \infty$, so the output of $f_{\mathcal{T},c2}(\mathcal{G}^{+}_{\delta})$ has a finite number of distinct real values. Let $M$ be an upper bound of all these possible values. By construction of $f_{\mathcal{T},c2}$, $M > 0$.

Construct the Layers: $f_{\mathcal{T},c3}(f_{\mathcal{T},c2}(G)) = \mathbf{0}_{d \times L}$ if $G \in \mathcal{G}^{+}_{\delta} \setminus \widetilde{\mathcal{G}}_{\delta}$.

According to Lemma E.4, for all $j \in [L]$, we have $u^{\top} f_{\mathcal{T},c2}(G)_{:,j} \in [t_{l}, t_{r}]$ if $G \in \widetilde{\mathcal{G}}_{\delta}$, and $u^{\top} f_{\mathcal{T},c2}(G)_{:,j} \notin [t_{l}, t_{r}]$ if $G \in \mathcal{G}^{+}_{\delta} \setminus \widetilde{\mathcal{G}}_{\delta}$. Due to this property, we add the following feed-forward layer:

Definition E.8 (Feed-forward Layer 3).

The vectors $u$ and $\mathds{1}_{L}$ act as the weight parameters and $\zeta_{3}(\cdot)$ acts as the activation function in the feed-forward layer.

\begin{align*}
X \rightarrow X - (M+1)\mathds{1}_{L}\zeta_{3}(u^{\top} X), \quad
\zeta_{3}(t) = \begin{cases} 0 & \text{if } t \in [t_{l}, t_{r}] \\ 1 & \text{if } t \notin [t_{l}, t_{r}]. \end{cases} \tag{E.13}
\end{align*}
  • Case for $G \in \mathcal{G}^{+}_{\delta} \setminus \widetilde{\mathcal{G}}_{\delta}$. We have $\zeta_{3}(u^{\top} f_{\mathcal{T},c2}(G)) = \mathds{1}_{L}^{\top}$, so all the entries of the input are shifted by $-M-1$, and become strictly negative.

  • Case for $G \in \widetilde{\mathcal{G}}_{\delta}$. We have $\zeta_{3}(u^{\top} f_{\mathcal{T},c2}(G)) = \mathbf{0}_{L}^{\top}$, so the output stays the same as $f_{\mathcal{T},c2}(G)$.

With the input $f_{\mathcal{T},c2}(G)$, if $G \in \widetilde{\mathcal{G}}_{\delta}$, then $\zeta_{3}(u^{\top} f_{\mathcal{T},c2}(G)) = \mathbf{0}_{L}^{\top}$, so the output stays the same as the input. If $G \in \mathcal{G}^{+}_{\delta} \setminus \widetilde{\mathcal{G}}_{\delta}$, then $\zeta_{3}(u^{\top} f_{\mathcal{T},c2}(G)) = \mathds{1}_{L}^{\top}$, so all the entries of the input are shifted by $-M-1$, and become strictly negative.
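The gating behavior of Feed-forward Layer 3 can be illustrated with a small NumPy sketch. The function names, the toy values of `u`, `t_l`, `t_r`, and `M`, and the two example matrices are hypothetical and chosen only for illustration. One further assumption is labeled in the code: the all-ones vector is taken to have length $d$ (the number of rows of $X$) so that the update has the same $d \times L$ shape as $X$.

```python
import numpy as np

def zeta3(t, t_l, t_r):
    """Activation of (E.13): 0 inside [t_l, t_r], 1 outside."""
    return np.where((t >= t_l) & (t <= t_r), 0.0, 1.0)

def feed_forward_layer_3(X, u, M, t_l, t_r):
    """Sketch of X -> X - (M + 1) * ones * zeta3(u^T X).

    Assumption (not from the text): the ones vector has length d,
    so the outer product below matches the d x L shape of X.
    """
    d = X.shape[0]
    gate = zeta3(u @ X, t_l, t_r)                 # shape (L,)
    return X - (M + 1.0) * np.outer(np.ones(d), gate)

# Hypothetical toy values (d = 2, L = 3).
u = np.array([1.0, 2.0])
t_l, t_r, M = 5.0, 10.0, 20.0

# Every column projects inside [t_l, t_r]: projections 5, 6, 9.
X_in = np.array([[1.0, 2.0, 3.0],
                 [2.0, 2.0, 3.0]])
# Every column projects outside [t_l, t_r]: projections 0, 3, 22.
X_out = np.array([[0.0, 1.0, 20.0],
                  [0.0, 1.0, 1.0]])

kept = feed_forward_layer_3(X_in, u, M, t_l, t_r)
pushed = feed_forward_layer_3(X_out, u, M, t_l, t_r)
```

In this toy run, `kept` equals `X_in`, while every entry of `pushed` is strictly negative because `M` bounds the entries of `X_out` from above.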

Next, we map those negative entries to zero. For $i = 1, 2, \cdots, d$, we add the following layer:

Definition E.9 (Feed-forward Layer 4).

The vectors $u$ and $e_{i}$ act as the weight parameters and $\zeta_{4}(\cdot)$ acts as the activation function in the feed-forward layer.

\begin{align*}
X \rightarrow X + e_{i}\zeta_{4}((e_{i})^{\top} X), \quad
\zeta_{4}(t) = \begin{cases} -t & \text{if } t < 0 \\ 0 & \text{if } t \geq 0. \end{cases} \tag{E.14}
\end{align*}

After these $d$ layers, the output for $G \in \mathcal{G}^{+}_{\delta} \setminus \widetilde{\mathcal{G}}_{\delta}$ is a zero matrix, while the output for $G \in \widetilde{\mathcal{G}}_{\delta}$ remains $f_{\mathcal{T},c2}(G)$.
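A brief NumPy sketch of the $d$ layers of the form (E.14) shows why they zero out negative entries: each layer adds $-t$ to every negative entry $t$ of row $i$ and leaves nonnegative entries alone. The helper names and the example matrices are hypothetical; the sketch uses 0-based row indices, whereas the text indexes $i = 1, \dots, d$.

```python
import numpy as np

def zeta4(t):
    """Activation of (E.14): -t for negative t, 0 otherwise."""
    return np.where(t < 0, -t, 0.0)

def feed_forward_layer_4(X, i):
    """Sketch of X -> X + e_i * zeta4(e_i^T X) for row index i (0-based)."""
    d = X.shape[0]
    e_i = np.zeros(d)
    e_i[i] = 1.0
    return X + np.outer(e_i, zeta4(e_i @ X))

def clamp_negatives(X):
    """Apply the layer once per row, i = 0, ..., d - 1."""
    for i in range(X.shape[0]):
        X = feed_forward_layer_4(X, i)
    return X

# Hypothetical inputs: one mixed-sign matrix, one strictly negative matrix.
X_mixed = np.array([[-3.0, 2.0],
                    [1.0, -0.5]])
X_negative = np.array([[-21.0, -20.0],
                       [-21.0, -1.0]])

clamped_mixed = clamp_negatives(X_mixed)
clamped_negative = clamp_negatives(X_negative)
```

Stacked over all rows, the layers act as an entrywise clamp at zero, so a strictly negative input becomes the zero matrix and a nonnegative input passes through unchanged.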

Construct the Layers: $f_{\mathcal{T},c3}(f_{\mathcal{T},c2}(G)) = A_{G}$ if $G \in \widetilde{\mathcal{G}}_{\delta}$.

Each different $G$ is mapped to $L$ unique numbers $u^{\top} f_{\mathcal{T},c2}(G)$, which are at least $\delta$ apart from each other. We map each unique number to the corresponding output column as follows. We choose one $\bar{G} \in \widetilde{\mathcal{G}}_{\delta}$, and for each $u^{\top} f_{\mathcal{T},c2}(\bar{G})_{:,j}$, $j \in [L]$, we add the following feed-forward layer.

Definition E.10 (Feed-forward Layer 5).

The vectors $u$ and $e_{i}$ act as the weight parameters and $\zeta_{5}(\cdot)$ acts as the activation function in the feed-forward layer.

\begin{align*}
X \rightarrow\; & X + \left((A_{\bar{G}})_{:,j} - f_{\mathcal{T},c2}(\bar{G})_{:,j}\right)\zeta_{5}\big(u^{\top} X - u^{\top} f_{\mathcal{T},c2}(\bar{G})_{:,j}\mathds{1}_{L}^{\top}\big), \tag{E.15} \\
& \zeta_{5}(t) = \begin{cases} 1 & -\delta/2 \leq t < \delta/2, \\ 0 & \text{otherwise}. \end{cases} \tag{E.16}
\end{align*}
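The column-replacement mechanism of Definition E.10 can be sketched in NumPy as follows. The names `key_col` (standing in for $f_{\mathcal{T},c2}(\bar{G})_{:,j}$) and `target_col` (standing in for $(A_{\bar{G}})_{:,j}$), together with all numeric values, are hypothetical and serve only as an illustration of (E.15) and (E.16).

```python
import numpy as np

def zeta5(t, delta):
    """Activation of (E.16): 1 on [-delta/2, delta/2), 0 elsewhere."""
    return np.where((t >= -delta / 2) & (t < delta / 2), 1.0, 0.0)

def feed_forward_layer_5(X, u, key_col, target_col, delta):
    """Sketch of (E.15): columns whose projection u^T X[:, j] lies within
    delta/2 of u^T key_col receive the update (target_col - key_col)."""
    gate = zeta5(u @ X - u @ key_col, delta)      # shape (L,)
    return X + np.outer(target_col - key_col, gate)

# Hypothetical toy values (d = 2, L = 2).
delta = 0.5
u = np.array([1.0, 2.0])
X = np.array([[4.0, 6.0],
              [1.0, 1.0]])                        # projections 6 and 8
key_col = X[:, 0].copy()                          # the column to be matched
target_col = np.array([0.25, 0.75])               # its desired replacement

replaced = feed_forward_layer_5(X, u, key_col, target_col, delta)
zeros_out = feed_forward_layer_5(np.zeros((2, 2)), u, key_col, target_col, delta)
```

Here only the first column of `X` matches the key, so it is replaced by `target_col` while the second column is untouched. A zero input stays zero because its projection differs from the key projection by far more than `delta / 2`.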
  • Case for $G\in\mathcal{G}^{+}_{\delta}\setminus\widetilde{\mathcal{G}}_{\delta}$. Recall that the input $X$ of this layer is $f_{\mathcal{T},c2}(G)$. If $X$ is a zero matrix, which is the case for $G\in\mathcal{G}^{+}_{\delta}\setminus\widetilde{\mathcal{G}}_{\delta}$, we have $u^{\top}X=\mathbf{0}_{L}^{\top}$. Then $u^{\top}X-u^{\top}f_{\mathcal{T},c2}(\bar{G})_{:,j}\mathds{1}_{L}^{\top}<-t_{l}\mathds{1}_{L}^{\top}$. Since $t_{l}>\delta/2$, every entry lies outside $[-\delta/2,\delta/2)$, so $\zeta_{5}$ evaluates to zero and the output remains the same as $X$.

  • Case for $G\in\widetilde{\mathcal{G}}_{\delta}$. Consider the input $X=f_{\mathcal{T},c2}(G)$, where $G\in\widetilde{\mathcal{G}}_{\delta}$ is not equal to $\bar{G}$. According to Property 2 of Lemma E.4, given a $j\in[L]$, $u^{\top}f_{\mathcal{T},c2}(G)_{:,k}$ (for $k\in[L]$) differs from $u^{\top}f_{\mathcal{T},c2}(\bar{G})_{:,j}$ by at least $\delta$. Then we have

    $$\zeta_{5}\left(u^{\top}f_{\mathcal{T},c2}(G)-u^{\top}f_{\mathcal{T},c2}(\bar{G})_{:,j}\mathds{1}_{L}^{\top}\right)=\mathbf{0}_{L}^{\top}.$$

    Thus the input is left untouched.

    If $G=\bar{G}$, then

    $$\zeta_{5}\left(u^{\top}f_{\mathcal{T},c2}(G)-u^{\top}f_{\mathcal{T},c2}(\bar{G})_{:,j}\mathds{1}_{L}^{\top}\right)=(e_{j})^{\top}.$$

    Thus we shift the $j$-th column of $f_{\mathcal{T},c2}(G)$ to

    $$f_{\mathcal{T},c2}(G)_{:,j}+\left((A_{\bar{G}})_{:,j}-f_{\mathcal{T},c2}(\bar{G})_{:,j}\right)=f_{\mathcal{T},c2}(G)_{:,j}+\left((A_{G})_{:,j}-f_{\mathcal{T},c2}(G)_{:,j}\right)=(A_{G})_{:,j}.$$

In other words, this layer maps the column $f_{\mathcal{T},c2}(G)_{:,j}$ to $(A_{G})_{:,j}$, without affecting any other columns.

We infer from the above that we need one layer for each unique value of $u^{\top}f_{\mathcal{T},c2}(G)_{:,j}$ for each $G\in\widetilde{\mathcal{G}}_{\delta}$. Note that there are $\mathcal{O}(\delta^{-dL})$ such numbers, so we use $\mathcal{O}(\delta^{-dL})$ layers to finish our construction. ∎
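The selective column shift of Definition E.10 can be illustrated with a minimal Python sketch. This is not part of the paper's construction: the helper names (`zeta5`, `dot`, `ff_layer5`), the toy dimensions ($d=2$, $L=3$), and all numeric values are assumptions chosen for readability, and matrices are stored as lists of columns.

```python
def zeta5(t, delta):
    """Window activation from (E.16): 1 on [-delta/2, delta/2), 0 elsewhere."""
    return 1.0 if -delta / 2 <= t < delta / 2 else 0.0


def dot(u, col):
    """Inner product u^T col for plain Python lists."""
    return sum(ui * ci for ui, ci in zip(u, col))


def ff_layer5(x_cols, u, ref_col, target_col, delta):
    """Apply X -> X + (target_col - ref_col) * zeta5(u^T X - u^T ref_col 1^T).

    x_cols     : input matrix X as a list of L columns (each of length d)
    ref_col    : the column f(G_bar)[:, j] this layer is keyed on
    target_col : the desired output column (A_{G_bar})[:, j]
    """
    key = dot(u, ref_col)
    shift = [a - f for a, f in zip(target_col, ref_col)]
    out = []
    for col in x_cols:
        gate = zeta5(dot(u, col) - key, delta)  # 1 only for the keyed column
        out.append([c + gate * s for c, s in zip(col, shift)])
    return out


# Toy setup (hypothetical numbers): d = 2, L = 3, delta = 1.
delta = 1.0
u = [1.0, 10.0]
f_gbar = [[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]]   # u^T columns: 10, 21, 32
f_other = [[0.0, 4.0], [1.0, 5.0], [2.0, 6.0]]  # u^T columns: 40, 51, 62
a_col = [5.0, -1.0]                             # target for column j = 1

out_gbar = ff_layer5(f_gbar, u, f_gbar[1], a_col, delta)
out_other = ff_layer5(f_other, u, f_gbar[1], a_col, delta)
out_zero = ff_layer5([[0.0, 0.0]] * 3, u, f_gbar[1], a_col, delta)
```

In this toy setup only the keyed column of `f_gbar` is replaced by `a_col`; the other input and the zero matrix pass through unchanged, mirroring the two cases discussed above.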

E.5.5 Proof of Lemma E.7
Proof of Lemma E.7.

We restate the proof from (Yun et al., 2020) for completeness.

The proof follows two steps: (i) Approximate the modified self-attention layers. (ii) Approximate the modified feed-forward layers.

  • Step 1: Approximate the Modified Self-Attention Layers.

    We achieve this by approximating the $\mathrm{Hardmax}$ operator $\sigma_{H}$ with a scaled $\mathrm{Softmax}$ operator $\sigma_{S}$. Given a matrix $X\in\mathbb{R}^{d\times L}$, we have

    $$\sigma_{S}(\lambda X)\rightarrow\sigma_{H}(X),\quad\text{as}\quad\lambda\rightarrow\infty.$$

    The operator is the only difference between the normal and the modified self-attention layers. We approximate the modified self-attention layer in $\bar{\mathcal{T}}_{p}^{r,m,l}$ by the normal self-attention layer with the same number of heads $r$ and head size $m$.

  • Step 2: Approximate the Modified Feed-Forward Layers.

    We achieve this by approximating the activation function in $\Psi$ with four $\mathrm{ReLU}$ functions. From Definition E.3, we recall that $\Psi$ denotes the set of three-piece piecewise linear functions with at least one constant piece. We consider the following $\zeta\in\Psi$:

    $$\zeta(x)=\begin{cases}b_{1} & \text{if } x<c_{1},\\ a_{2}x+b_{2} & \text{if } c_{1}\leq x<c_{2},\\ a_{3}x+b_{3} & \text{if } c_{2}\leq x,\end{cases}$$

    where $a_{2},a_{3},b_{1},b_{2},b_{3},c_{1},c_{2}\in\mathbb{R}$, and $c_{1}<c_{2}$.

    We approximate $\zeta(x)$ by $\widetilde{\zeta}(x)$ composed of four $\mathrm{ReLU}$ functions:

    $$\begin{aligned}
    \widetilde{\zeta}(x)={}& b_{1}+\frac{a_{2}c_{1}+b_{2}-b_{1}}{\epsilon}\mathrm{ReLU}(x-c_{1}+\epsilon)+\left(a_{2}-\frac{a_{2}c_{1}+b_{2}-b_{1}}{\epsilon}\right)\mathrm{ReLU}(x-c_{1})\\
    &+\left(\frac{a_{3}c_{2}+b_{3}-a_{2}(c_{2}-\epsilon)-b_{2}}{\epsilon}-a_{2}\right)\mathrm{ReLU}(x-c_{2}+\epsilon)\\
    &+\left(a_{3}-\frac{a_{3}c_{2}+b_{3}-a_{2}(c_{2}-\epsilon)-b_{2}}{\epsilon}\right)\mathrm{ReLU}(x-c_{2})\\
    ={}&\begin{cases}
    b_{1} & \text{if } x<c_{1}-\epsilon,\\
    (a_{2}c_{1}+b_{2}-b_{1})(x-c_{1})/\epsilon+a_{2}c_{1}+b_{2} & \text{if } c_{1}-\epsilon\leq x<c_{1},\\
    a_{2}x+b_{2} & \text{if } c_{1}\leq x<c_{2}-\epsilon,\\
    (a_{3}c_{2}+b_{3}-a_{2}(c_{2}-\epsilon)-b_{2})(x-c_{2})/\epsilon+a_{3}c_{2}+b_{3} & \text{if } c_{2}-\epsilon\leq x<c_{2},\\
    a_{3}x+b_{3} & \text{if } c_{2}\leq x.
    \end{cases}
    \end{aligned}$$

    As $\epsilon\rightarrow 0$, we approximate $\zeta(x)$ using $\widetilde{\zeta}(x)$. The activation function is the only difference between the normal and modified feed-forward layers. We approximate the modified feed-forward layer in $\bar{\mathcal{T}}_{p}^{r,m,l}$ by the normal one.

    Thus, for any $f_{\mathcal{T},c}\in\bar{\mathcal{T}}_{p}^{2,1,1}$, there exists a function $f_{\mathcal{T}}\in\mathcal{T}_{p}^{2,1,4}$ to approximate $f_{\mathcal{T},c}$.

This completes the proof. ∎
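Both approximation steps in this proof can be checked numerically. The following is a minimal Python sketch under stated assumptions: the function names, the test column, and the coefficient values are hypothetical and chosen only for illustration, and `hardmax` assumes a unique maximum (ties are not handled).

```python
import math


def softmax(col, lam=1.0):
    """Softmax of lam * col, computed stably by subtracting the maximum."""
    m = max(col)
    exps = [math.exp(lam * (v - m)) for v in col]
    total = sum(exps)
    return [e / total for e in exps]


def hardmax(col):
    """One-hot vector at the maximum entry (assumes the maximum is unique)."""
    m = max(col)
    return [1.0 if v == m else 0.0 for v in col]


def relu(x):
    return max(x, 0.0)


def zeta(x, a2, a3, b1, b2, b3, c1, c2):
    """Three-piece activation with a constant first piece."""
    if x < c1:
        return b1
    if x < c2:
        return a2 * x + b2
    return a3 * x + b3


def zeta_tilde(x, a2, a3, b1, b2, b3, c1, c2, eps):
    """Four-ReLU surrogate; differs from zeta only on two windows of width eps."""
    k1 = (a2 * c1 + b2 - b1) / eps
    k2 = (a3 * c2 + b3 - a2 * (c2 - eps) - b2) / eps
    return (b1
            + k1 * relu(x - c1 + eps)
            + (a2 - k1) * relu(x - c1)
            + (k2 - a2) * relu(x - c2 + eps)
            + (a3 - k2) * relu(x - c2))


# Step 1: softmax(lam * x) approaches hardmax(x) as lam grows.
col = [1.0, 3.0, 2.0]
gaps = [max(abs(s - h) for s, h in zip(softmax(col, lam), hardmax(col)))
        for lam in (1.0, 10.0, 50.0)]

# Step 2: the surrogate matches zeta outside [c1 - eps, c1) and [c2 - eps, c2).
params = dict(a2=1.0, a3=-2.0, b1=0.0, b2=0.5, b3=4.0, c1=0.0, c2=1.0)
eps = 0.01
probe = [-1.0, -0.5, 0.0, 0.5, 0.9, 1.0, 2.0]
mismatch = max(abs(zeta(x, **params) - zeta_tilde(x, eps=eps, **params))
               for x in probe)
```

Here `gaps` shrinks as the scale grows, and `mismatch` is zero up to floating-point error because every probe point lies outside the two transition windows. Shrinking `eps` narrows those windows, which is the sense in which the surrogate approximates the original activation.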

Appendix F Proofs of Section 3

Our proof is motivated by the approximation and estimation theory of U-Net-based diffusion models in (Chen et al., 2023a). We use the universal approximation capability established in Appendix E and the covering number of transformer networks to proceed with our proof. Specifically, we derive the approximation error bound in Section F.1 and the corresponding sample complexity bound in Section F.2. Then we show that the data distribution generated from the estimated score function converges toward a proximate area of the original one in Section F.3.

F.1 Proof of Theorem 3.1

Here we present some auxiliary theoretical results in Section F.1.1 to prepare our main proof of Theorem 3.1. Then we derive the approximation error bound of DiTs (i.e., the proof of Theorem 3.1) in Section F.1.2.

F.1.1 Auxiliary Lemmas for Theorem 3.1.

We restate some auxiliary lemmas and their proofs here from (Chen et al., 2023a) for later convenience.

Lemma F.1 (Lemma 16 of (Chen et al., 2023a)).

Consider a probability density function $p_{h}(h)=\exp(-C\|h\|_{2}^{2}/2)$ for $h\in\mathbb{R}^{d_{0}}$ and constant $C>0$. Let $r_{h}>0$ be a fixed radius. Then it holds

$$\int_{\|h\|_{2}>r_{h}}p_{h}(h)\,\mathrm{d}h\leq\frac{2d_{0}\pi^{d_{0}/2}}{C\,\Gamma(d_{0}/2+1)}r_{h}^{d_{0}-2}\exp(-Cr_{h}^{2}/2),$$
$$\int_{\|h\|_{2}>r_{h}}\|h\|_{2}^{2}\,p_{h}(h)\,\mathrm{d}h\leq\frac{2d_{0}\pi^{d_{0}/2}}{C\,\Gamma(d_{0}/2+1)}r_{h}^{d_{0}}\exp(-Cr_{h}^{2}/2).$$
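As a numerical sanity check of the first inequality in Lemma F.1, the following minimal Python sketch compares a radial quadrature of the tail integral against the stated bound. It is a spot check at two hypothetical parameter choices $(d_{0},C,r_{h})$, not a proof, and says nothing about other values; the function names and the quadrature settings are assumptions made for this illustration.

```python
import math


def tail_mass(d0, c, r, upper=12.0, steps=200_000):
    """Midpoint-rule estimate of the integral of exp(-c*|h|^2/2) over |h| > r.

    Uses polar coordinates: surface area of the unit sphere in R^d0 times the
    radial integral of rho^(d0-1) * exp(-c*rho^2/2), truncated at r + upper.
    """
    area = d0 * math.pi ** (d0 / 2) / math.gamma(d0 / 2 + 1)
    h = upper / steps
    radial = 0.0
    for i in range(steps):
        rho = r + (i + 0.5) * h
        radial += rho ** (d0 - 1) * math.exp(-c * rho * rho / 2)
    return area * radial * h


def tail_bound(d0, c, r):
    """Right-hand side of the first inequality in Lemma F.1."""
    coeff = 2 * d0 * math.pi ** (d0 / 2) / (c * math.gamma(d0 / 2 + 1))
    return coeff * r ** (d0 - 2) * math.exp(-c * r * r / 2)


# Hypothetical test points (d0, C, r_h); not taken from the paper.
checks = [(2, 1.0, 1.5), (3, 1.0, 2.0)]
results = [(tail_mass(d0, c, r), tail_bound(d0, c, r)) for d0, c, r in checks]
```

For $d_{0}=2$ the tail integral has the closed form $2\pi\exp(-Cr_{h}^{2}/2)/C$, which gives an independent check on the quadrature.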
Lemma F.2 (Lemma 2 of (Chen et al., 2023a)).

Suppose Assumption 2.2 holds and $q$ is defined as:

$$q(\bar{h},t)=\int\frac{h\,\psi_{t}(\bar{h}|h)\,p_{h}(h)}{\int\psi_{t}(\bar{h}|h)\,p_{h}(h)\,\mathrm{d}h}\,\mathrm{d}h,\quad\bar{h}=B^{\top}x.$$

Given $\epsilon>0$, with $r_{h}=c\left(\sqrt{d_{0}\log(d_{0}/T_{0})+\log(1/\epsilon)}\right)$ for an absolute constant $c$, it holds

$$\left\|q(\bar{h},t)\mathds{1}\{\|\bar{h}\|_{2}\geq r_{h}\}\right\|_{L^{2}(P_{t})}\leq\epsilon,\quad\text{for}\ t\in[T_{0},T].$$
Lemma F.3 (Theorem 1 of (Chen et al., 2023a)).

We denote

$$\tau(r_{h})=\sup_{t\in[T_{0},T]}\sup_{\bar{h}\in[0,r_{h}]^{d}}\left\|\frac{\partial}{\partial t}q(\bar{h},t)\right\|_{2}.$$

With $q(\bar{h},t)=\int h\,\psi_{t}(\bar{h}|h)\,p_{h}(h)/\left(\int\psi_{t}(\bar{h}|h)\,p_{h}(h)\,\mathrm{d}h\right)\mathrm{d}h$ and $p_{h}$ satisfying Assumption 2.2, we have the following coarse upper bound for $\tau(r_{h})$:

$$\tau(r_{h})=\mathcal{O}\left(\frac{1+\beta^{2}(t)}{\beta(t)}\left(L_{s_{+}}+\frac{1}{\sigma(t)}\right)\sqrt{d_{0}}\,r_{h}\right)=\mathcal{O}\left(e^{T/2}L_{s_{+}}r_{h}\sqrt{d_{0}}\right).$$
Lemma F.4 (Lemma 10 of (Chen et al., 2020b)).

For any given $\epsilon>0$ and any $L$-Lipschitz function $g$ defined on $[0,1]^{d_{0}}$, there exists a continuous function $\bar{f}$, constructed from trapezoid functions, such that

$$\|g-\bar{f}\|_{\infty}\leq\epsilon.$$

Moreover, $\bar{f}$ is Lipschitz continuous with

$$\left|\bar{f}(x)-\bar{f}(y)\right|\leq 10d_{0}L\|x-y\|_{2}\quad\text{for any}\quad x,y\in[0,1]^{d_{0}}.$$
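To give some intuition for Lemma F.4, the following minimal Python sketch approximates a one-dimensional Lipschitz function by a sum of hat functions on a uniform grid. Several assumptions apply: hat functions are used as the degenerate case of a trapezoid with a zero-width plateau, the example is restricted to $d_{0}=1$, and the construction in the cited lemma may differ in its details. The names `hat`, `build_approx`, and the target function are hypothetical.

```python
import math


def hat(x, center, width):
    """Hat function: 1 at center, decaying linearly to 0 at center +/- width."""
    return max(0.0, 1.0 - abs(x - center) / width)


def build_approx(g, lipschitz, eps):
    """Return f_bar(x) = sum_i g(x_i) * hat_i(x) on a uniform grid of [0, 1]."""
    # Grid spacing 1/n gives interpolation error at most lipschitz / (2 * n).
    n = math.ceil(lipschitz / (2 * eps))
    nodes = [i / n for i in range(n + 1)]
    values = [g(x) for x in nodes]

    def f_bar(x):
        return sum(v * hat(x, c, 1.0 / n) for v, c in zip(values, nodes))

    return f_bar


# Hypothetical target: g(x) = |sin(3x)| is 3-Lipschitz on [0, 1].
g = lambda x: abs(math.sin(3 * x))
eps = 0.05
f_bar = build_approx(g, lipschitz=3.0, eps=eps)
grid = [k / 2000 for k in range(2001)]
sup_err = max(abs(g(x) - f_bar(x)) for x in grid)
max_step = max(abs(f_bar(grid[k + 1]) - f_bar(grid[k])) for k in range(2000))
```

In one dimension the hat-function sum is the piecewise linear interpolant of $g$ at the grid nodes, so its error on each cell is at most $L$ times half the grid spacing and its slope never exceeds $L$. That is well within the $10d_{0}L$ constant stated in the lemma.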
F.1.2 Main Proof of Theorem 3.1
Proof of Theorem 3.1.

With $\nabla\log p_{t}^{h}(\bar{h})=B^{\top}s_{+}(\bar{h},t)$, we note that in (2.4)

$$q(\bar{h},t)=\sigma(t)\nabla\log p_{t}^{h}(\bar{h})+B^{\top}x=\sigma(t)B^{\top}(s_{+}(\bar{h},t)+x). \tag{F.1}$$

We proceed as follows:

  • Step 1. Approximate $q(\bar{h},t)$ with a compact-supported continuous function $\bar{f}(\bar{h},t)$.

  • Step 2. Approximate $\bar{f}(\bar{h},t)$ with a Transformer network.

Step 1. Approximate $q(\bar{h},t)$ with a Compact-supported Continuous Function $\bar{f}(\bar{h},t)$. Here we partition $\mathbb{R}^{d_0}$ into a compact subset $H_1:=\{\bar{h} \mid \|\bar{h}\|_2\leq r_h\}$ and its complement $H_2$, where $r_h$ is to be determined later. We approximate $q(\bar{h},t)$ on the two subsets separately, and then prove the continuity of $\bar{f}$. This step achieves an approximation error of $\sqrt{d_0}\epsilon$ between $q(\bar{h},t)$ and $\bar{f}(\bar{h},t)$. We show the main proof here.

  • Approximation on $H_2\times[T_0,T]$. For any $\epsilon>0$, we take $r_h=c\left(\sqrt{d_0\log(d_0/T_0)-\log\epsilon}\right)$. We obtain from Lemma F.2 that

    $$\left\|q(\bar{h},t)\mathds{1}\{\|\bar{h}\|_2\geq r_h\}\right\|_{L^2(P_t)}\leq\epsilon\quad\text{for}\quad t\in[T_0,T].$$

    So we set $\bar{f}(\bar{h},t)=0$ on $H_2\times[T_0,T]$.

  • Approximation on $H_1\times[T_0,T]$. On $H_1\times[T_0,T]$, we approximate $q(\bar{h},t)$ coordinate by coordinate, where $q(\bar{h},t)=[q_1(\bar{h},t),q_2(\bar{h},t),\cdots,q_{d_0}(\bar{h},t)]$. We first rescale the input by $y'=(\bar{h}+r_h\mathds{1})/2r_h$ and $t'=t/T$, so that the transformed input space is $[0,1]^{d_0}\times[T_0/T,1]$. We implement this transformation with a single feed-forward layer.

    By Assumption 2.3, the on-support score $s_+(\bar{h},t)$ is $L_{s_+}$-Lipschitz in $\bar{h}$. This implies that $q(\bar{h},t)$ is $(1+L_{s_+})$-Lipschitz in $\bar{h}$. With the transformed inputs, $g(y',t')=q(2r_hy'-r_h\mathds{1},Tt')$ becomes $2r_h(1+L_{s_+})$-Lipschitz in $y'$; so is each coordinate $g_k(y',t)$. Here we take $L_h=1+L_{s_+}$.

    Besides, $g(y',t')$ is $T\tau(r_h)$-Lipschitz with respect to $t$, where

    $$\tau(r_h)=\sup_{t\in[T_0,T]}\sup_{\bar{h}\in[0,r_h]^d}\left\|\frac{\partial}{\partial t}q(\bar{h},t)\right\|_2.$$

    We have a coarse upper bound for $\tau(r_h)$ in Lemma F.3. We repeat it here for convenience:

    $$\tau(r_h)=\mathcal{O}\left(\frac{1+\beta^2(t)}{\beta(t)}\left(L_{s_+}+\frac{1}{\sigma(t)}\right)\sqrt{d_0}r_h\right)=\mathcal{O}\left(e^{T/2}L_{s_+}r_h\sqrt{d_0}\right).$$

    In conclusion, each $g_k(y',t)$ is Lipschitz continuous, so we can apply Lemma F.4 to find $\bar{f}_k(y',t)$ approximating each coordinate. We concatenate the $\bar{f}_i$'s and construct $\bar{f}=[\bar{f}_1,\dots,\bar{f}_{d_0}]^\top$. According to the construction in Lemma F.4, for any given $\epsilon$, we achieve

    $$\sup_{y',t'\in[0,1]^{d_0}\times[T_0/T,1]}\left\|\bar{f}(y',t')-g(y',t')\right\|_\infty\leq\epsilon.$$

    Considering the input rescaling (i.e., $\bar{h}\to y'$ and $t\to t'$), we obtain:

    • The constructed function is Lipschitz continuous in $\bar{h}$, i.e., for any $\bar{h}_1,\bar{h}_2\in H_1$ and $t\in[T_0,T]$, it holds

      $$\left\|\bar{f}(\bar{h}_1,t)-\bar{f}(\bar{h}_2,t)\right\|_\infty\leq 10d_0L_h\|\bar{h}_1-\bar{h}_2\|_2. \tag{F.2}$$
    • The function is also Lipschitz in $t$, i.e., for any $t_1,t_2\in[T_0,T]$ and $\|\bar{h}\|_2\leq r_h$, it holds

      $$\left\|\bar{f}(\bar{h},t_1)-\bar{f}(\bar{h},t_2)\right\|_\infty\leq 10\tau(r_h)\|t_1-t_2\|_2.$$

    Since the construction of $\bar{f}(\bar{h},t)$ is based on trapezoid functions, we have $\bar{f}(\bar{h},t)=0$ for $\|\bar{h}\|_2=r_h$ and all $t\in[T_0,T]$. So the two parts of $\bar{f}(\bar{h},t)$ can be joined together. More specifically, the above Lipschitz continuity in $\bar{h}$ extends to the whole $\mathbb{R}^{d_0}$.

  • Approximation Error Analysis under $L^2$ Norm. The $L^2$ approximation error of $\bar{f}$ can be decomposed into two terms:

    $$\left\|q(\bar{h},t)-\bar{f}(\bar{h},t)\right\|_{L^2(P_t^h)}\leq\left\|(q(\bar{h},t)-\bar{f}(\bar{h},t))\mathds{1}\{\|\bar{h}\|_2<r_h\}\right\|_{L^2(P_t^h)}+\left\|q(\bar{h},t)\mathds{1}\{\|\bar{h}\|_2>r_h\}\right\|_{L^2(P_t^h)}.$$

    The second term on the right-hand side above has already been bounded with the selection of $r_h$:

    $$\left\|q(\bar{h},t)\mathds{1}\{\|\bar{h}\|_2>r_h\}\right\|_{L^2(P_t^h)}\leq\epsilon.$$

    The first term is bounded by:

    $$\left\|(q(\bar{h},t)-\bar{f}(\bar{h},t))\mathds{1}\{\|\bar{h}\|_2<r_h\}\right\|_{L^2(P_t^h)}\leq\sqrt{d_0}\sup_{y',t'\in[0,1]^{d_0}\times[T_0/T,1]}\left\|\bar{f}(y',t')-g(y',t')\right\|_\infty\leq\sqrt{d_0}\epsilon.$$

    So we obtain

    $$\left\|q(\bar{h},t)-\bar{f}(\bar{h},t)\right\|_{L^2(P_t^h)}\leq(\sqrt{d_0}+1)\epsilon.$$

    If we substitute $\epsilon$ with $\epsilon/2$, we obtain that the approximation error of $\bar{f}(\bar{h},t)$ is at most $(\sqrt{d_0}+1)\epsilon/2\leq\sqrt{d_0}\epsilon$, since $d_0\geq 1$.
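As a numerical sanity check of the input rescaling used on $H_1\times[T_0,T]$ in Step 1, the following Python sketch applies the affine map $(\bar{h},t)\mapsto(y',t')$ and checks that it sends the ball $H_1$ into $[0,1]^{d_0}\times[T_0/T,1]$, and that composing with the inverse map multiplies a Lipschitz constant by $2r_h$. The function `q_toy` (a linear map with spectral norm `L_q`) and all numerical values are illustrative assumptions, not quantities from the proof.

```python
import numpy as np

def rescale(h_bar, t, r_h, T):
    """Affine map (h_bar, t) -> (y', t') = ((h_bar + r_h * 1) / (2 r_h), t / T)."""
    return (h_bar + r_h) / (2.0 * r_h), t / T

def unrescale(y, t_prime, r_h, T):
    """Inverse map (y', t') -> (2 r_h y' - r_h * 1, T t')."""
    return 2.0 * r_h * y - r_h, T * t_prime

rng = np.random.default_rng(0)
d0, r_h, T, T0 = 4, 3.0, 2.0, 0.1  # illustrative values only

# Sample points inside the ball H_1 = {h : ||h||_2 <= r_h} and times in [T0, T].
h = rng.normal(size=(1000, d0))
h = h / np.linalg.norm(h, axis=1, keepdims=True) * r_h * rng.uniform(size=(1000, 1))
t = rng.uniform(T0, T, size=1000)

y, t_prime = rescale(h, t, r_h, T)
in_cube = bool(np.all((y >= 0.0) & (y <= 1.0)))
in_time = bool(np.all((t_prime >= T0 / T) & (t_prime <= 1.0)))
h_back, t_back = unrescale(y, t_prime, r_h, T)
round_trip = bool(np.allclose(h_back, h) and np.allclose(t_back, t))

# Toy L_q-Lipschitz function (assumption): a linear map with spectral norm L_q.
A = rng.normal(size=(d0, d0))
L_q = np.linalg.norm(A, 2)

def q_toy(h_bar):
    return h_bar @ A.T

def g_toy(y_prime):
    # g(y') = q(2 r_h y' - r_h * 1), mirroring the definition of g in the text.
    return q_toy(2.0 * r_h * y_prime - r_h)

y1, y2 = rng.uniform(size=(2, 1000, d0))
ratios = np.linalg.norm(g_toy(y1) - g_toy(y2), axis=1) / np.linalg.norm(y1 - y2, axis=1)
max_ratio = float(ratios.max())  # should not exceed 2 * r_h * L_q
```

The empirical ratio `max_ratio` stays below `2 * r_h * L_q`, which is the same scaling that turns the $(1+L_{s_+})$-Lipschitz bound on $q$ into the $2r_h(1+L_{s_+})$-Lipschitz bound on $g$.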

Step 2. Approximate $\bar{f}(\bar{h},t)$ by a Transformer. This step is based on the universal approximation of Transformers for compact-supported continuous functions in Lemma E.1. Following (Peebles and Xie, 2023), DiT uses the time point $t$ to calculate the scale and shift values in the Transformer backbone, and it transforms an input picture into a sequential version. We omit the time point $t$ in the notation of the Transformer network in DiT. Recalling the reshape layer $R(\cdot)$ in Definition 3.1, we use $f(\cdot):=R^{-1}\circ f_{\mathcal{T}}\circ R(\cdot)$ to approximate $\bar{f}_t(\cdot):=\bar{f}(\cdot,t)$, where $f_{\mathcal{T}}\in\mathcal{T}_p^{2,1,4}$.
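The composition $R^{-1}\circ f_{\mathcal{T}}\circ R$ can be pictured with the short Python sketch below. It assumes, for illustration only, that $R$ reshapes a vector of length $d_0=d\cdot L$ into a $d\times L$ matrix (Definition 3.1 is not reproduced here), and it uses a hypothetical placeholder `f_T_toy` in place of the Transformer network. Under this assumption the reshape is a norm-preserving bijection: the Euclidean distance between two vectors equals the Frobenius distance between their reshaped matrices.

```python
import numpy as np

def R(h_bar, d, L):
    """Assumed reshape layer: vector of length d * L -> d x L matrix."""
    return h_bar.reshape(d, L)

def R_inv(H):
    """Inverse reshape: d x L matrix -> vector of length d * L."""
    return H.reshape(-1)

def f_T_toy(H):
    """Hypothetical stand-in for the Transformer f_T (any map on d x L matrices)."""
    return np.tanh(H) + H.mean(axis=1, keepdims=True)

def f(h_bar, d, L):
    """Composition R^{-1} o f_T o R acting on vectors."""
    return R_inv(f_T_toy(R(h_bar, d, L)))

rng = np.random.default_rng(0)
d, L = 3, 4  # illustrative sizes, so d_0 = 12
u = rng.normal(size=d * L)
v = rng.normal(size=d * L)

vec_dist = np.linalg.norm(u - v)                               # Euclidean norm
mat_dist = np.linalg.norm(R(u, d, L) - R(v, d, L), ord="fro")  # Frobenius norm
out = f(u, d, L)
```

Because the two distances agree, an error bound stated for the sequence-to-sequence map in Frobenius norm carries over unchanged to the vector-valued map in Euclidean norm.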

  • Overall Approximation Error. With Lemma E.1, we approximate $\bar{f}_t(\cdot)$ with $\widehat{f}(\cdot):=R^{-1}\circ\widehat{f}_{\mathcal{T}}\circ R(\cdot)$, and denote

    $$H=R(\bar{h}).$$

    We have

    $$\begin{aligned}
    \left\|\bar{f}_t(\bar{h})-\widehat{f}(\bar{h})\right\|_{L^2(P_t^h)}
    &=\left(\int_{P_t^h}\left\|\bar{f}_t(\bar{h})-\widehat{f}(\bar{h})\right\|_2^2\,\mathrm{d}h\right)^{1/2}\\
    &=\left(\int_{P_t^h}\left\|R\circ\bar{f}_t\circ R^{-1}(H)-R\circ\widehat{f}\circ R^{-1}(H)\right\|_F^2\,\mathrm{d}h\right)^{1/2}\\
    &=\left(\int_{P_t^h}\left\|R\circ\bar{f}_t\circ R^{-1}(H)-\widehat{f}_{\mathcal{T}}(H)\right\|_F^2\,\mathrm{d}h\right)^{1/2}\\
    &\leq\epsilon.
    \end{aligned}\tag{F.3}$$

    Along with Step 1, we obtain

    \begin{align*}
    \norm{q(\bar{h},t)-\widehat{f}(\bar{h})}_{L^{2}(P_{t}^{h})}
    \leq \norm{q(\bar{h},t)-\bar{f}(\bar{h},t)}_{L^{2}(P_{t}^{h})}
    + \norm{\bar{f}(\bar{h},t)-\widehat{f}(\bar{h})}_{L^{2}(P_{t}^{h})}
    \leq (1+\sqrt{d_{0}})\epsilon.
    \end{align*}

    The constructed approximator to $\nabla\log p_{t}(x)$ is $s_{\widehat{W}}=(B\widehat{f}(B^{\top}x,t)-x)/\sigma(t)$, whose approximation error is

    \begin{align*}
    \norm{\nabla\log p_{t}(\cdot)-s_{\widehat{W}}(\cdot,t)}_{L^{2}(P_{t})}
    \leq \frac{1+\sqrt{d_{0}}}{\sigma(t)}\epsilon,\quad \forall t\in[T_{0},T].
    \end{align*}
  • Settling-down of Hyperparameters. We settle down the hyperparameters to configure our network here. We refer to Section E.2 for some of the following calculations.

Then we have

\begin{align}
C_{F}^{2,\infty}
&= \mathcal{O}\left(\sqrt{\sum_{i=0}^{d-1}\delta^{-2i}}\right)=\mathcal{O}\left(\delta^{-d}\right) \tag{F.12}\\
&= (1/\epsilon)^{\mathcal{O}(1)}, \notag
\end{align}

where the last step sets $\delta=\mathcal{O}(\epsilon^{2/d})$ according to Section E.4, and

\begin{align}
C_{F}
&= \sup_{\norm{x}_{2}=1}\norm{W_{1}x}_{2}=\mathcal{O}\left(\delta^{-d}\right) \tag{F.13}\\
&= (1/\epsilon)^{\mathcal{O}(1)}, \notag
\end{align}

again by setting $\delta=\mathcal{O}(\epsilon^{2/d})$ according to Section E.4.

This completes the proof. ∎
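As a quick numerical sanity check of (F.12) and (F.13), the sketch below evaluates the geometric sum under the choice $\delta=\epsilon^{2/d}$. Taking this choice with equality is an assumption made here for illustration (the text fixes $\delta$ only up to constants), and the function names are hypothetical, not from the paper.

```python
import math


def geometric_norm_bound(delta: float, d: int) -> float:
    """sqrt(sum_{i=0}^{d-1} delta^(-2i)): the quantity inside O(.) in (F.12)."""
    return math.sqrt(sum(delta ** (-2 * i) for i in range(d)))


def delta_for(eps: float, d: int) -> float:
    """Assumed choice delta = eps^(2/d), taken with equality for illustration."""
    return eps ** (2.0 / d)


if __name__ == "__main__":
    for d in (2, 4, 8):
        for eps in (0.5, 0.1, 0.01):
            delta = delta_for(eps, d)
            lhs = geometric_norm_bound(delta, d)
            rhs = delta ** (-d)  # equals eps^(-2) under the assumed choice
            print(f"d={d}, eps={eps}: sum bound={lhs:.3e}, delta^-d={rhs:.3e}")
```

For $\delta\leq 1$ the sum is at most $\sqrt{d}\,\delta^{-d}$, and under the assumed choice $\delta^{-d}=\epsilon^{-2}$, which is one concrete instance of the $(1/\epsilon)^{\mathcal{O}(1)}$ scaling.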

F.2 Proof of Corollary 3.1.1

Here we present auxiliary results on the covering number of transformer networks in Section F.2.1 to prepare the main proof of Corollary 3.1.1. These results are based on Theorem A.17 of (Edelman et al., 2022). We then derive the sample complexity bound of DiTs (i.e., the proof of Corollary 3.1.1) in Section F.2.2.

F.2.1 Auxiliary Lemmas for Corollary 3.1.1
Lemma F.5 (Lemma 15 of (Chen et al., 2023a)).

Let $\mathcal{G}$ be a bounded function class, i.e., there exists a constant $b$ such that any $g\in\mathcal{G}:\mathbb{R}^{d_{0}}\mapsto[0,b]$. Let $z_{1},z_{2},\cdots,z_{n}\in\mathbb{R}^{d_{0}}$ be i.i.d. random variables. For any $\delta\in(0,1)$, $a\leq 1$, and $c>0$, we have

\begin{align*}
\mathbb{P}\left(\sup_{g\in\mathcal{G}}\frac{1}{n}\sum_{i=1}^{n}g(z_{i})-(1+a)\mathbb{E}\left[g(z)\right]>\frac{(1+3/a)B}{3n}\log\frac{\mathcal{N}(c,\mathcal{G},\norm{\cdot}_{\infty})}{\delta}+(2+a)c\right)&\leq\delta,\\
\mathbb{P}\left(\sup_{g\in\mathcal{G}}\mathbb{E}\left[g(z)\right]-\frac{1+a}{n}\sum_{i=1}^{n}g(z_{i})>\frac{(1+6/a)B}{3n}\log\frac{\mathcal{N}(c,\mathcal{G},\norm{\cdot}_{\infty})}{\delta}+(2+a)c\right)&\leq\delta.
\end{align*}
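To make the deviation threshold in Lemma F.5 concrete, the following minimal sketch evaluates its right-hand side for given constants. The function name and the `factor` argument are hypothetical conveniences; the restriction to $a>0$ is an assumption made here so that $3/a$ is well defined, and the log-covering number is passed directly to avoid overflow.

```python
import math


def deviation_threshold(a: float, B: float, n: int, log_cover: float,
                        delta: float, c: float, factor: float = 3.0) -> float:
    """Evaluate (1 + factor/a) * B / (3n) * log(N / delta) + (2 + a) * c.

    factor=3.0 corresponds to the first inequality of Lemma F.5 and
    factor=6.0 to the second. `log_cover` is log N(c, G, ||.||_inf).
    """
    if not (0.0 < delta < 1.0 and 0.0 < a <= 1.0 and c > 0.0 and n > 0):
        raise ValueError("need delta in (0,1), 0 < a <= 1, c > 0, n > 0")
    log_term = log_cover + math.log(1.0 / delta)  # log(N / delta)
    return (1.0 + factor / a) * B / (3.0 * n) * log_term + (2.0 + a) * c


if __name__ == "__main__":
    # Illustrative constants only; none of these values come from the paper.
    for n in (100, 1_000, 10_000):
        print(n, deviation_threshold(a=1.0, B=1.0, n=n,
                                     log_cover=math.log(10.0), delta=0.1, c=0.01))
```

The first term decays like $1/n$ for a fixed covering number, while the $(2+a)c$ term is controlled only by the covering radius $c$, which is why the proof of Corollary 3.1.1 later ties the radius to $n$.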

Next, we define the covering number.

Definition F.1 (Covering Number).

Given a function class $\mathcal{F}$ and a data distribution $P$, sample $n$ data points $\{X_{i}\}_{i=1}^{n}$ from $P$. The covering number $\mathcal{N}(\epsilon,\mathcal{F},\{X_{i}\}_{i=1}^{n},\norm{\cdot})$ is the smallest size of a collection (a cover) $\mathcal{C}\subseteq\mathcal{F}$ such that for any $f\in\mathcal{F}$, there exists $\widehat{f}\in\mathcal{C}$ satisfying

\begin{align*}
\max_{i}\norm{f(X_{i})-\widehat{f}(X_{i})}\leq\epsilon.
\end{align*}

Further, we define the covering number with respect to the data distribution as

\begin{align*}
\mathcal{N}(\epsilon,\mathcal{F},\norm{\cdot})=\sup_{\{X_{i}\}_{i=1}^{n}\sim P}\mathcal{N}(\epsilon,\mathcal{F},\{X_{i}\}_{i=1}^{n},\norm{\cdot}).
\end{align*}
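Definition F.1 can be illustrated on a finite, scalar-valued function class. The sketch below is a toy example and not part of the paper's argument: it builds a cover greedily on a fixed sample, so the size it returns is a valid upper bound on the empirical covering number, not necessarily the minimum. All names are hypothetical.

```python
from typing import Callable, List, Sequence

Func = Callable[[float], float]


def sup_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """max_i |u_i - v_i|: distance between two functions evaluated on the sample."""
    return max(abs(a - b) for a, b in zip(u, v))


def greedy_cover(functions: Sequence[Func], points: Sequence[float],
                 eps: float) -> List[int]:
    """Return indices of functions forming an eps-cover on `points`.

    Every f in `functions` has some g in the cover with
    max_i |f(X_i) - g(X_i)| <= eps. Greedy selection yields a valid cover,
    hence an upper bound on the covering number of Definition F.1.
    """
    values = [[f(x) for x in points] for f in functions]
    centers: List[int] = []
    for i, v in enumerate(values):
        if not any(sup_distance(v, values[j]) <= eps for j in centers):
            centers.append(i)  # f_i is not yet covered, so it becomes a center
    return centers


if __name__ == "__main__":
    # Toy class: linear maps x -> w * x with slopes w on a grid in [0, 1].
    funcs = [lambda x, w=k / 100: w * x for k in range(101)]
    sample = [-1.0, 0.5, 1.0]
    cover = greedy_cover(funcs, sample, eps=0.1)
    print("cover size:", len(cover))
```

Because the cover is measured only on the sampled points, its size depends on the sample; the supremum over samples in the second display removes this dependence.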

Then we give the covering number of the transformer networks.

Lemma F.6 (Modified from Theorem A.17 of (Edelman et al., 2022)).

Let $\mathcal{T}_{p}^{r,m,l}(K,C_{\mathcal{T}},C_{OV}^{2,\infty},C_{OV},C_{KQ}^{2,\infty},C_{KQ},C_{F}^{2,\infty},C_{F},C_{E},L_{\mathcal{T}})$ represent the class of functions of $K$-layer transformer blocks satisfying the norm bounds for the weight matrices and the Lipschitz property for the feed-forward layers. Then, for all data points $X$ with $\norm{X}_{2,\infty}\leq C_{X}$, we have

\begin{align*}
&\log\mathcal{N}(\epsilon_{c},\mathcal{T}_{p}^{r,m,l}(K,C_{\mathcal{T}},C_{OV}^{2,\infty},C_{OV},C_{KQ}^{2,\infty},C_{KQ},C_{F}^{2,\infty},C_{F},C_{E},L_{\mathcal{T}}),\norm{\cdot}_{2})\\
\leq\;& \frac{\log(nL)}{\epsilon_{c}^{2}}\cdot\left(\sum_{i=1}^{K}\alpha^{\frac{2}{3}}\left(d^{\frac{2}{3}}\left(C_{F}^{2,\infty}\right)^{\frac{4}{3}}+d^{\frac{2}{3}}\left(2(C_{F})^{2}C_{OV}C_{KQ}^{2,\infty}\right)^{\frac{2}{3}}+\tau m^{\frac{2}{3}}\left((C_{F})^{2}C_{OV}^{2,\infty}\right)^{\frac{2}{3}}\right)\right)^{3},
\end{align*}

where $\alpha\coloneqq\prod_{j<i}(C_{F})^{2}C_{OV}(1+4C_{KQ})(C_{X}+C_{E})$.
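To see how the bound of Lemma F.6 scales with depth, the sketch below evaluates its right-hand side for given constants. Two readings are assumptions made here, not statements from the source: that $\alpha$ depends on the layer index $i$, and that every factor in its definition sits inside the product over $j<i$, so that $\alpha_i$ is the $(i-1)$-th power of a single per-layer factor. The function and argument names are hypothetical.

```python
import math


def log_covering_bound(eps_c: float, n: float, L: float, K: int, d: int, m: int,
                       tau: int, C_F: float, C_F_2inf: float, C_OV: float,
                       C_OV_2inf: float, C_KQ: float, C_KQ_2inf: float,
                       C_X: float, C_E: float) -> float:
    """Right-hand side of the log-covering bound in Lemma F.6.

    Assumption: alpha_i = base^(i-1) with
    base = C_F^2 * C_OV * (1 + 4 C_KQ) * (C_X + C_E),
    i.e. all factors are taken inside the product over j < i.
    """
    base = (C_F ** 2) * C_OV * (1.0 + 4.0 * C_KQ) * (C_X + C_E)
    per_layer = (
        d ** (2 / 3) * C_F_2inf ** (4 / 3)
        + d ** (2 / 3) * (2.0 * C_F ** 2 * C_OV * C_KQ_2inf) ** (2 / 3)
        + tau * m ** (2 / 3) * (C_F ** 2 * C_OV_2inf) ** (2 / 3)
    )
    total = 0.0
    for i in range(1, K + 1):
        alpha_i = base ** (i - 1)  # empty product for i = 1 equals 1
        total += alpha_i ** (2 / 3) * per_layer
    return math.log(n * L) / eps_c ** 2 * total ** 3


if __name__ == "__main__":
    # Illustrative constants only; tau=2 and m=1 follow the transformer class used in this work.
    for K in (1, 2, 3):
        print(K, log_covering_bound(eps_c=0.1, n=1000, L=16, K=K, d=8, m=1, tau=2,
                                    C_F=1.0, C_F_2inf=1.0, C_OV=1.0, C_OV_2inf=1.0,
                                    C_KQ=1.0, C_KQ_2inf=1.0, C_X=1.0, C_E=1.0))
```

Under this reading the bound grows geometrically in the number of layers $K$ whenever the per-layer factor exceeds one, and scales as $1/\epsilon_c^{2}$ in the covering radius.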

Remark F.1.

We modify (Edelman et al., 2022, Theorem A.17) in seven aspects:

  1. We do not consider the last linear layer in the model, which converts each column vector of the Transformer output to a scalar. Therefore, we ignore the term related to the last linear layer in (Edelman et al., 2022, Theorem A.17).

  2. We do not consider the normalization layer in our model. Because the normalization layer enters the original proof only through the bound $\norm{\prod_{\rm norm}(X_{1})-\prod_{\rm norm}(X_{2})}_{2,\infty}\leq\norm{X_{1}-X_{2}}_{2,\infty}$, ignoring this layer does not change the result.

  3. Our activation function is ${\rm ReLU}$, so we replace the Lipschitz upper bound of the activation function with 1.

  4. We consider the positional encoding (E.4) in our work, so we replace the upper bound $C_{X}$ for the inputs with the upper bound $C_{X}+C_{E}$. In addition, for the multi-layer Transformer, the original conclusion in (Edelman et al., 2022, Theorem A.17) assumes that the $(2,\infty)$-norm of the inputs is upper bounded by 1; we add the upper bound for the inputs in Lemma F.6.

  5. We use (2.7) as the feed-forward layer, which includes two linear layers and a residual connection. Thus, in Lemma F.6, we replace the original upper bound for the norm of the weight matrix with the upper bound for the norm of $I_{d}+W_{2}W_{1}$. In the following, we use $\mathcal{O}$ to estimate the log-covering number, so we ignore the term for $I_{d}$ here for convenience. The same applies to the self-attention layer.

  6. We use multi-head attention, so we add the number of heads $\tau$ to our result, similar to (Edelman et al., 2022, Theorem A.12).

  7. In our work, we use the Transformer $\mathcal{T}_{p}^{2,1,4}$, i.e., $\tau=2,m=1$.

F.2.2 Proof of Corollary 3.1.1
Proof of Corollary 3.1.1.

Our proof is built on (Chen et al., 2023a, Appendix B.2). Firstly, for one data sample, we define the empirical score matching loss objective (2.1) as follows

\begin{align*}
\ell(x;s_{\widehat{W}})=\frac{1}{T-T_{0}}\int_{T_{0}}^{T}\mathbb{E}_{x_{t}|x_{0}=x}\left[\norm{\nabla_{x_{t}}\log\psi_{t}(x_{t}|x_{0})-s_{\widehat{W}}(x_{t},t)}_{2}^{2}\right]\differential t.
\end{align*}

Then we define $\mathcal{L}(s_{\widehat{W}})=\mathbb{E}_{x\sim P_{0}}\left[\ell(x;s_{\widehat{W}})\right]$.
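The per-sample loss $\ell(x;s_{\widehat{W}})$ can be estimated by Monte Carlo. The sketch below is a scalar toy version under an assumed forward process $x_t = \alpha(t)x_0 + \sqrt{\sigma(t)}\,z$ with $\alpha(t)=e^{-t/2}$, $\sigma(t)=1-e^{-t}$, and $z\sim N(0,1)$. This kernel is an assumption made here for illustration; the paper's $\psi_t$ is defined outside this section. All function names are hypothetical.

```python
import math
import random
from typing import Callable


def alpha(t: float) -> float:
    """Assumed mean scaling of the forward process."""
    return math.exp(-t / 2.0)


def sigma(t: float) -> float:
    """Assumed conditional variance of the forward process."""
    return 1.0 - math.exp(-t)


def conditional_score(x_t: float, x0: float, t: float) -> float:
    """Score of the assumed Gaussian kernel N(alpha(t) x0, sigma(t)) at x_t."""
    return -(x_t - alpha(t) * x0) / sigma(t)


def per_sample_loss(x0: float, score_fn: Callable[[float, float], float],
                    T0: float, T: float, num_samples: int = 100_000,
                    seed: int = 0) -> float:
    """Monte Carlo estimate of l(x0; score_fn) for scalar data.

    Drawing t uniformly on [T0, T] realizes the 1/(T - T0) time average,
    and drawing z realizes the conditional expectation over x_t given x0.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        t = rng.uniform(T0, T)
        z = rng.gauss(0.0, 1.0)
        x_t = alpha(t) * x0 + math.sqrt(sigma(t)) * z
        diff = conditional_score(x_t, x0, t) - score_fn(x_t, t)
        total += diff * diff
    return total / num_samples


if __name__ == "__main__":
    # A trivial score estimate s(x, t) = 0, used only to exercise the estimator.
    print(per_sample_loss(x0=1.0, score_fn=lambda x, t: 0.0, T0=0.5, T=2.0))
```

For the zero score function in this toy setting, the loss reduces to the time average of $1/\sigma(t)$, which grows as $T_0\to 0$; this is consistent with the role of the early-stopping time $T_0$ in the bounds that follow.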

Following (Chen et al., 2023a, Appendix B.2), for any $a\in(0,1)$, we have

\begin{align*}
\mathcal{L}(s_{\widehat{W}})\leq\underbrace{\mathcal{L}^{\rm trunc}(s_{\widehat{W}})-(1+a)\widehat{\mathcal{L}}^{\rm trunc}(s_{\widehat{W}})}_{(I)}+\underbrace{\mathcal{L}(s_{\widehat{W}})-\mathcal{L}^{\rm trunc}(s_{\widehat{W}})}_{(II)}+(1+a)\underbrace{\inf_{s_{W}\in\mathcal{S}_{\rm NN}}\widehat{\mathcal{L}}(s_{W})}_{(III)},
\end{align*}

where

\begin{align*}
\mathcal{L}^{\rm trunc}(s_{\widehat{W}})\coloneqq\mathbb{E}_{x\sim P_{0}}\left[\ell^{\rm trunc}(x;s_{\widehat{W}})\right]=\mathbb{E}_{x\sim P_{0}}\left[\ell(x;s_{\widehat{W}})\mathds{1}\{\norm{x}_{2}\leq r_{x}\}\right],\quad r_{x}>B.
\end{align*}

We denote

\begin{align*}
\eta &\coloneqq 4C_{\mathcal{T}}(C_{\mathcal{T}}+r_{x})(r_{x}/D)^{D-2}\exp(-r_{x}^{2}/\sigma(t))/(T_{0}(T-T_{0})),\\
r_{x} &\coloneqq \mathcal{O}\left(\sqrt{d_{0}\log d_{0}+\log C_{\mathcal{T}}+\log(n/\bar{\delta})}\right).
\end{align*}

For any $\bar{\delta}>0$, following (Chen et al., 2023a, Appendix B.2), the following bound on term $(I)$ holds with probability $1-\bar{\delta}$:

\begin{align*}
(I)=\mathcal{O}\left(\frac{(1+3/a)(C_{\mathcal{T}}^{2}+r_{x}^{2})}{nT_{0}(T-T_{0})}\log\frac{\mathcal{N}\left(\frac{(T-T_{0})(\iota-\eta)}{(C_{\mathcal{T}}+r_{x})\log(T/T_{0})},\mathcal{S}_{\mathcal{T}_{p}^{2,1,4}},\norm{\cdot}_{2}\right)}{\bar{\delta}}+(2+a)c\right),
\end{align*}

where $c>0$ is a constant, and $\iota>0$ will be determined later.

We set $\iota=1/(nT_{0}(T-T_{0}))$. Then we have

\begin{align*}
(I)=\mathcal{O}\left(\frac{(1+3/a)\left(C_{\mathcal{T}}^{2}+r_{x}^{2}\right)}{nT_{0}(T-T_{0})}\log\frac{\mathcal{N}\left((n(C_{\mathcal{T}}+r_{x})T_{0}\log(T/T_{0}))^{-1},\mathcal{S}_{\mathcal{T}_{p}^{2,1,4}},\norm{\cdot}_{2}\right)}{\bar{\delta}}+\frac{1}{n}\right),
\end{align*}

with probability $1-\bar{\delta}$.

Following the upper bounds on the other two terms and the proof details in (Chen et al., 2023a, Appendix B.2), we have

\begin{align*}
& \frac{1}{T - T_0} \int_{T_0}^{T} \left\| s_{\widehat{W}}(\cdot, t) - \nabla \log p_t(\cdot) \right\|_{L^2(P_t)}^{2} \mathrm{d}t \\
&= \mathcal{O}\left( \frac{C_{\mathcal{T}}^{2} + r_x^{2}}{\epsilon^{2} n T_0 (T - T_0)} \log\frac{\mathcal{N}\left( (n (C_{\mathcal{T}} + r_x) T_0 \log(T/T_0))^{-1}, \mathcal{S}_{\mathcal{T}_p^{2,1,4}}, \|\cdot\|_2 \right)}{\bar{\delta}} + \frac{1}{n} + \frac{d_0^{2}}{T_0 (T - T_0)} \epsilon^{2} \right), \tag{F.14}
\end{align*}

with probability $1 - 3\bar{\delta}$.

Covering Number of $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$.

The next step is to calculate the covering number of $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$. The class $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$ consists of two components: (i) a matrix $W_B$ with orthonormal columns; (ii) a network function $f_{\mathcal{T}}$. Suppose we have $W_{B1}, W_{B2}$ and $f_1, f_2$ such that $\|W_{B1} - W_{B2}\|_F \le \delta_1$ and $\sup_{\|x\|_2 \le 3r_x + \sqrt{D \log D},\, t \in [T_0, T]} \|f_1(x, t) - f_2(x, t)\|_2 \le \delta_2$, where $f_1 = R^{-1} \circ f_{\mathcal{T}1} \circ R$ and $f_2 = R^{-1} \circ f_{\mathcal{T}2} \circ R$. Then we evaluate

\begin{align*}
& \sup_{\|x\|_2 \le 3r_x + \sqrt{D \log D},\, t \in [T_0, T]} \left\| s_{W_{B1}, f_{\mathcal{T}1}}(x, t) - s_{W_{B2}, f_{\mathcal{T}2}}(x, t) \right\|_2 \\
&= \frac{1}{\sigma(t)} \sup_{\|x\|_2 \le 3r_x + \sqrt{D \log D},\, t \in [T_0, T]} \left\| W_{B1} f_1(W_{B1}^{\top} x, t) - W_{B2} f_2(W_{B2}^{\top} x, t) \right\|_2 \\
&\le \frac{1}{\sigma(t)} \sup_{\|x\|_2 \le 3r_x + \sqrt{D \log D},\, t \in [T_0, T]} \Big( \left\| W_{B1} f_1(W_{B1}^{\top} x, t) - W_{B1} f_1(W_{B2}^{\top} x, t) \right\|_2 \\
&\quad + \left\| W_{B1} f_1(W_{B2}^{\top} x, t) - W_{B1} f_2(W_{B2}^{\top} x, t) \right\|_2 + \left\| W_{B1} f_2(W_{B2}^{\top} x, t) - W_{B2} f_2(W_{B2}^{\top} x, t) \right\|_2 \Big) \\
&\le \frac{1}{\sigma(t)} \left( L_{\mathcal{T}} \delta_1 \sqrt{d_0} \left( 3r_x + \sqrt{D \log D} \right) + \delta_2 + \delta_1 K \right), \tag{F.15}
\end{align*}

where $L_{\mathcal{T}}$ upper bounds the Lipschitz constant of $f_{\mathcal{T}}$.
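The three-term split behind (F.15) can be checked numerically on a toy instance. The sketch below is illustrative only and assumes several things not fixed by the text: small dimensions $D = 3$, $d_0 = 2$, elementwise `tanh` maps standing in for $f_1, f_2$ (so the Lipschitz constant is 1), no time argument, and the common factor $1/\sigma(t)$ dropped. It also reads the $K$ appearing in (F.15) as a uniform bound on $\|f_2\|_2$, named `K_out` here; that reading is an assumption on our part.

```python
import math

# Tiny dense linear-algebra helpers (pure Python, lists of rows).
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

def sub(u, v):
    return [a - b for a, b in zip(u, v)]

def norm(v):
    return math.sqrt(sum(a * a for a in v))

def fro(M):
    return math.sqrt(sum(a * a for row in M for a in row))

D, d0 = 3, 2
theta = 0.05  # small rotation, so W1 and W2 are close

# Two D x d0 matrices with orthonormal columns (hypothetical toy choice).
W1 = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
W2 = [[math.cos(theta), 0.0], [0.0, 1.0], [math.sin(theta), 0.0]]

# Stand-ins for f_1, f_2: 1-Lipschitz maps differing by a constant shift.
shift = [0.01, -0.02]
def f1(z):
    return [math.tanh(a) for a in z]
def f2(z):
    return [math.tanh(a) + s for a, s in zip(z, shift)]

L_lip = 1.0                                             # Lipschitz constant of tanh
delta1 = fro([sub(r1, r2) for r1, r2 in zip(W1, W2)])   # ||W1 - W2||_F
delta2 = norm(shift)                                    # sup ||f1 - f2||_2
K_out = math.sqrt(d0) + delta2                          # assumed bound on ||f2||_2

x = [0.7, -1.2, 0.4]
R = norm(x)                                             # plays the role of the input radius
a = matvec(transpose(W1), x)                            # W1^T x
b = matvec(transpose(W2), x)                            # W2^T x

lhs = norm(sub(matvec(W1, f1(a)), matvec(W2, f2(b))))
term1 = norm(sub(matvec(W1, f1(a)), matvec(W1, f1(b))))  # change the inner projection
term2 = norm(sub(matvec(W1, f1(b)), matvec(W1, f2(b))))  # change the network
term3 = norm(sub(matvec(W1, f2(b)), matvec(W2, f2(b))))  # change the outer projection
rhs = L_lip * delta1 * math.sqrt(d0) * R + delta2 + delta1 * K_out

print(lhs, term1 + term2 + term3, rhs)
```

On this instance the printed values are non-decreasing from left to right, matching the two inequalities in the display: the triangle inequality first, then the term-by-term bounds.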

For the set $\{W_B \in \mathbb{R}^{D \times d_0} : \|W_B\|_2 \le 1\}$, its $\delta_1$-covering number is at most $\left(1 + 2\sqrt{d_0}/\delta_1\right)^{D d_0}$ (Chen et al., 2020a, Lemma 8). The $\delta_2$-covering number of $f$ needs further discussion, as there is a reshaping process in our network. For the input reshaped from $\bar{h} \in \mathbb{R}^{d_0}$ to $H \in \mathbb{R}^{d \times L}$, we have

\begin{align*}
\|\bar{h}\|_2 \le r_x \Longleftrightarrow \|H\|_F \le r_x,
\end{align*}

and

\begin{align*}
& \sup_{\|\bar{h}\|_2 \le 3r_x + \sqrt{D \log D},\, t \in [T_0, T]} \left\| f_1(\bar{h}, t) - f_2(\bar{h}, t) \right\|_2 \le \delta_2 \\
\Longleftrightarrow\ & \sup_{\|H\|_F \le 3r_x + \sqrt{D \log D},\, t \in [T_0, T]} \left\| f_{\mathcal{T}1}(H) - f_{\mathcal{T}2}(H) \right\|_2 \le \delta_2.
\end{align*}
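The two facts used here are elementary and can be illustrated with a short sketch: reshaping a vector into a matrix only rearranges its entries, so the Euclidean norm of $\bar{h}$ equals the Frobenius norm of $H$, and the logarithm of the covering bound for $W_B$ is $D d_0 \log(1 + 2\sqrt{d_0}/\delta_1)$. The function names, the row-major reshape order, and the numeric values below are hypothetical choices for illustration, not taken from the paper.

```python
import math

def reshape(h_bar, d, L):
    """Reshape a flat vector of length d*L into a d x L matrix (row-major order assumed)."""
    if len(h_bar) != d * L:
        raise ValueError("length of h_bar must equal d * L")
    return [h_bar[r * L:(r + 1) * L] for r in range(d)]

def l2_norm(v):
    return math.sqrt(sum(a * a for a in v))

def frobenius_norm(H):
    return math.sqrt(sum(a * a for row in H for a in row))

def log_cover_bound_WB(D, d0, delta1):
    """Log of the bound (1 + 2*sqrt(d0)/delta1)^(D*d0) on the delta1-covering number."""
    return D * d0 * math.log(1.0 + 2.0 * math.sqrt(d0) / delta1)

# Toy example with d = 4 and L = 2, so d0 = d * L = 8.
h_bar = [0.3, -1.0, 2.0, 0.5, -0.25, 1.5, 0.0, 4.0]
H = reshape(h_bar, d=4, L=2)

print(l2_norm(h_bar), frobenius_norm(H))          # the two norms coincide
print(log_cover_bound_WB(D=16, d0=8, delta1=0.1)) # grows as delta1 shrinks
```

Because the norms coincide, any norm ball constraint on $\bar{h}$ transfers unchanged to $H$, which is what lets the transformer covering bound be applied directly.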

Thus we can follow the covering number property for the sequence-to-sequence transformer $\mathcal{T}_p^{2,1,4}$, i.e., Lemma F.6, and get the following $\delta_2$-covering number

\begin{align*}
\frac{\log(nL)}{\delta_2^{2}} \cdot \left( \sum_{i=1}^{K} \alpha_i^{\frac{2}{3}} \left( d^{\frac{2}{3}} \left( C_F^{2,\infty} \right)^{\frac{4}{3}} + d^{\frac{2}{3}} \left( 2 (C_F)^{2} C_{OV} C_{KQ}^{2,\infty} \right)^{\frac{2}{3}} + \tau m^{\frac{2}{3}} \left( (C_F)^{2} C_{OV}^{2,\infty} \right)^{\frac{2}{3}} \right) \right)^{3},
\end{align*}

where

\begin{align*}
\alpha_i \coloneqq \prod_{j < i} (C_F)^{2} C_{OV} (1 + 4 C_{KQ}) (C_X + C_E).
\end{align*}

According to (LABEL:eq:K_est), (LABEL:eq:L_tau_est), (LABEL:eq:W_ov_est_inf), (LABEL:eq:W_ov_est_2), (LABEL:eq:W_kq_est_inf), (LABEL:eq:W_kq_est_2), (F.12), (F.13), (LABEL:eq:C_e_est) and (LABEL:eq:C_tau_est) in Section F.1.2, we derive the following with $\delta = \mathcal{O}(\epsilon^{2/d})$ (Section E.4) and $d = 4$ (Theorem 3.1):

\begin{align*}
& K = \mathcal{O}\left( \epsilon^{-2L} \right), \quad L_{\mathcal{T}} = \mathcal{O}\left( d_0 L_{s_+} \right), \quad C_{OV}^{2,\infty} = \mathcal{O}(d \epsilon^{-4L}), \quad C_{OV} = \mathcal{O}(\epsilon^{-4L}), \\
& C_{KQ}^{2,\infty} = \mathcal{O}(\epsilon^{-4}), \quad C_{KQ} = \mathcal{O}(\epsilon^{-4}), \quad C_F^{2,\infty} = \mathcal{O}(\epsilon^{-4}), \quad C_F = \mathcal{O}(\epsilon^{-2}), \quad C_E = \mathcal{O}(L^{3/2}), \tag{F.16} \\
& C_{\mathcal{T}} = \mathcal{O}\left( d_0 L_{s_+} \cdot \sqrt{d_0 \log(d_0/T_0) + \log(1/\epsilon)} \right), \quad r_x = \mathcal{O}\left( \sqrt{d_0 \log d_0 + \log C_{\mathcal{T}} + \log(n/\bar{\delta})} \right).
\end{align*}

We assume that each element of the input data lies in $[0, 1]$, as in Appendix E.

Recall that $\iota = 1/(n^{1/4} T_0 (T - T_0))$. We then get the log-covering number of $\mathcal{T}_p^{2,1,4}$,

\begin{align*}
\log \mathcal{N}\left( \iota, \mathcal{T}_p^{2,1,4}, \|\cdot\|_2 \right) &= \mathcal{O}\left( \frac{\epsilon^{-8K} \cdot L^{K} d^{2} \log(nL)}{\iota} \right) \\
&= \mathcal{O}(1) \cdot \left( \frac{2^{8K \log(L/\epsilon)} d^{2} \log(nL)}{\iota} \right).
\end{align*}

Following (Chen et al., 2023a, Appendix B.2), the log-covering number of $\mathcal{S}_{\mathcal{T}_p^{2,1,4}}$ is

\begin{align*}
& \log \mathcal{N}\left( \iota, \mathcal{S}_{\mathcal{T}_p^{2,1,4}}, \|\cdot\|_2 \right) \\
&= \mathcal{O}\left( 2 D d_0 \cdot \log\left( 1 + \frac{6 C_{\mathcal{T}} L_{\mathcal{T}} \sqrt{d_0} \left( 3r_x + \sqrt{D \log D} \right)}{T_0 \iota} \right) + \frac{2^{8K \log(L/\epsilon)} d^{2} \log(nL)}{T_0^{2} \iota^{2}} \right) && \text{(By (F.2.2))} \\
&= \mathcal{O}\left( n^{1/2} 2^{8 (1/\epsilon)^{L} \log(L/\epsilon)} D d^{2} d_0^{6} L_{s_+}^{2} (T - T_0)^{2} \cdot \log(nL) \right) && \text{(By (F.2.2))} \\
&= \mathcal{O}\left( n^{1/2} 2^{(1/\epsilon)^{2L}} D d^{2} d_0^{6} L_{s_+}^{2} (T - T_0)^{2} \cdot \log(nL) \right) && \text{(By $(1/\epsilon)^{L} \ge 8 \log(L/\epsilon)$)} \\
&= \widetilde{\mathcal{O}}\left( n^{1/2} 2^{(1/\epsilon)^{2L}} D d^{2} d_0^{6} L_{s_+}^{2} (T - T_0)^{2} \right) && \text{(By ignoring the log factors)} \\
&= \widetilde{\mathcal{O}}\left( n^{1/2} 2^{(1/\epsilon)^{2L}} D d^{2} d_0^{6} L_{s_+}^{2} T^{2} \right).
\end{align*}
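Two algebraic steps in this chain are easy to sanity-check numerically. First, with $\iota = 1/(n^{1/4} T_0 (T - T_0))$, the factor $1/(T_0^{2} \iota^{2})$ equals $n^{1/2} (T - T_0)^{2}$, which is where the $n^{1/2} (T - T_0)^{2}$ prefactor comes from. Second, whenever $(1/\epsilon)^{L} \ge 8 \log(L/\epsilon)$, the exponent $8 (1/\epsilon)^{L} \log(L/\epsilon)$ is at most $(1/\epsilon)^{2L}$. The sketch below assumes base-2 logarithms in the exponent and uses hypothetical values for $n$, $T_0$, $T$, $\epsilon$, and $L$.

```python
import math

def iota(n, T0, T):
    """Covering radius iota = 1 / (n^{1/4} * T0 * (T - T0))."""
    return 1.0 / (n ** 0.25 * T0 * (T - T0))

def inverse_T0_iota_squared(n, T0, T):
    """The factor 1 / (T0^2 * iota^2) appearing in the second covering term."""
    return 1.0 / (T0 ** 2 * iota(n, T0, T) ** 2)

def exponent_step(eps, L):
    """Return (premise, conclusion) for the exponent simplification.

    premise:    (1/eps)^L >= 8 * log2(L/eps)
    conclusion: 8 * (1/eps)^L * log2(L/eps) <= (1/eps)^(2L)
    """
    big = (1.0 / eps) ** L
    log_term = math.log2(L / eps)  # base 2 is an assumption of this sketch
    premise = big >= 8.0 * log_term
    conclusion = 8.0 * big * log_term <= big ** 2
    return premise, conclusion

# Hypothetical values, chosen only for illustration.
n, T0, T = 10_000, 0.1, 5.0
lhs = inverse_T0_iota_squared(n, T0, T)
rhs = math.sqrt(n) * (T - T0) ** 2
premise, conclusion = exponent_step(eps=0.5, L=6)

print(lhs, rhs)             # agree up to floating-point error
print(premise, conclusion)  # the conclusion holds whenever the premise does
```

The second check is just the observation that multiplying both sides of the premise by $(1/\epsilon)^{L} > 0$ gives the conclusion.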

Substituting the log-covering number into (F.2.2), we have

\begin{align}
& \frac{1}{T-T_{0}} \int_{T_{0}}^{T} \left\| s_{\widehat{W}}(\cdot,t) - \nabla \log p_{t}(\cdot) \right\|_{L^{2}(P_{t})}^{2} \, \mathrm{d}t \nonumber \\
= & \; \mathcal{O}\Big( \frac{C_{\mathcal{T}}^{2} + r_{x}^{2}}{\epsilon^{2} n T_{0} (T-T_{0})} \big(\log(\mathcal{N}) + \log(1/\bar{\delta})\big) + \frac{d_{0}^{2}}{T_{0}(T-T_{0})} \epsilon^{2} + \frac{1}{n} \Big) \nonumber \\
= & \; \mathcal{O}\Big( \underbrace{\frac{C_{\mathcal{T}}^{2} + r_{x}^{2}}{\epsilon^{2} n T_{0} T} \big(\log(\mathcal{N}) + \log(1/\bar{\delta})\big)}_{\mathrm{1st\ term}} + \underbrace{\frac{d_{0}^{2}}{T_{0} T} \epsilon^{2}}_{\mathrm{2nd\ term}} + \frac{1}{n} \Big). \tag{F.17}
\end{align}

Recall the following parameters,

  • $C_{\mathcal{T}}^{2} = \mathcal{O}(d_{0}^{2} L_{s_{+}}^{2} d_{0} \log(d_{0}/T_{0}) + \log(1/\epsilon))$

  • $r_{x}^{2} = \mathcal{O}(d_{0} \log d_{0} + \log C_{\mathcal{T}} + \log(n/\bar{\delta}))$

  • $\bar{\delta}$: probability error

  • $\epsilon$: approximation error

  • $n$: sample size

  • $T_{0} < T/2$

  • $D, d, d_{0} > 1$: feature dimension

  • $L > 1$: sequence length

  • $d_{0} = L \cdot d$

  • $L_{s_{+}}$: Lipschitz coefficient

Ignoring the $\log$ factors and $\mathrm{poly}(D, d, d_{0}, L_{s_{+}})$, the first term in (F.17) becomes

\begin{align*}
\frac{1}{n^{1/2}} \cdot \frac{T}{T_{0}} \cdot 2^{(1/\epsilon)^{2L}}.
\end{align*}

The second term simplifies to

\begin{align*}
\frac{1}{T_{0} T} \epsilon^{2}.
\end{align*}

Thus, the final bound is

\begin{align*}
\widetilde{\mathcal{O}}\Bigg( \frac{1}{n^{1/2}} \frac{T}{T_{0}} \cdot 2^{(1/\epsilon)^{2L}} + \frac{1}{T_{0} T} \epsilon^{2} + \frac{1}{n} \Bigg).
\end{align*}
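
To make the trade-off among the three terms concrete, the following minimal sketch evaluates them numerically. It is only an illustration under stated assumptions: the function name `xi_terms` and all sample values of $n$, $\epsilon$, $L$, $T$, $T_0$ are hypothetical, constants and the dropped log and polynomial factors are ignored, and the first term is reported as a base-2 logarithm because $2^{(1/\epsilon)^{2L}}$ overflows floating point for even moderately small $\epsilon$.

```python
import math


def xi_terms(n, eps, L, T, T0):
    """Evaluate the three terms of the simplified bound (constants dropped).

    Returns (log2 of the first term, second term, third term).
    """
    # log2( n^{-1/2} * (T / T0) * 2^{(1/eps)^{2L}} )
    log2_first = -0.5 * math.log2(n) + math.log2(T / T0) + (1.0 / eps) ** (2 * L)
    second = eps ** 2 / (T0 * T)
    third = 1.0 / n
    return log2_first, second, third


# Hypothetical values, chosen only for illustration.
coarse = xi_terms(n=10**6, eps=0.9, L=2, T=10.0, T0=0.1)
fine = xi_terms(n=10**6, eps=0.5, L=2, T=10.0, T0=0.1)
print("eps=0.9:", coarse)
print("eps=0.5:", fine)
```

Shrinking $\epsilon$ lowers the second term but raises the exponent in the first term, so for a fixed $n$ the two terms pull in opposite directions.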

Thus, we complete the proof of Corollary 3.1.1. ∎

F.3 Proof of Corollary 3.1.2

Our proof is built on (Chen et al., 2023a, Appendix C). The main difference between our work and (Chen et al., 2023a) is our score estimation error from Corollary 3.1.1. Consequently, only the subspace error and the total variation distance differ from (Chen et al., 2023a, Theorem 3).

Proof Sketch of (i).

We show that if the orthogonal score increases significantly, the mismatch between the column span of $B$ and $W_{B}$ will be greatly amplified. Therefore, an accurate score network estimator forces $B$ and $W_{B}$ to align with each other.

Proof Sketch of (ii).

We conduct the proof in two steps:

  • Step 1: Total Variation Distance Bound. We obtain the discrete result from the continuous-time generated distribution $\widehat{P}_{T_{0}}$ by adding discretization error (Chen et al., 2023a, Lemma 4). It suffices to bound the divergence between the following two stochastic processes:

    • For the ground-truth backward process, consider $h_{t}^{\leftarrow} = B^{\top} y_{t}$ and the following SDE:

      \begin{align*}
      \mathrm{d} h_{t}^{\leftarrow} = \left[ \frac{1}{2} h_{t}^{\leftarrow} + \nabla \log p^{h}_{T-t}(h_{t}^{\leftarrow}) \right] \mathrm{d}t + \mathrm{d}\bar{B}_{t}^{h}.
      \end{align*}

      Denote the marginal distribution of the ground-truth process as $P_{T_{0}}^{h}$.

    • For the learned process, consider $\widetilde{h}^{\leftarrow,r}_{t}$ and the following SDE:

      \begin{align*}
      \mathrm{d} \widetilde{h}^{\leftarrow,r}_{t} = \left[ \frac{1}{2} \widetilde{h}^{\leftarrow,r}_{t} + \widetilde{s}^{h}_{f,U}(\widetilde{h}^{\leftarrow,r}_{t}, T-t) \right] \mathrm{d}t + \mathrm{d}\bar{B}^{h}_{t},
      \end{align*}

      where $\widetilde{s}_{f,U}^{h}(z,t) \coloneqq [U^{\top} f(Uz,t) - z]/\sigma(t)$ and $U$ is an orthogonal matrix. Following the notation in (Chen et al., 2023a), we use $(W_{B}U)_{\sharp}^{\top} \widehat{P}_{T_{0}}$ to denote the marginal distribution of $\widehat{P}_{T_{0}}$. We first calculate the latent score matching error, i.e., the error between $\nabla \log p^{h}_{t}(h)$ and $\widetilde{s}_{U,f}^{h}(h,t)$. Then, we adopt Girsanov’s Theorem (Chen et al., 2023b) and bound the difference in the KL divergence of the above two processes to derive the score-matching error bound.

  • Step 2: Wasserstein-2 Distance Bound. We use the same technique as (Chen et al., 2023a, Theorem 3).
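
As a sanity check on the ground-truth backward SDE in Step 1, the sketch below simulates it with Euler-Maruyama in a one-dimensional Gaussian toy case where the score is available in closed form. Everything specific here is an assumption made for illustration: the latent law $P_h = N(0, s_0^2)$, the forward process $\mathrm{d}h = -\tfrac{1}{2} h\,\mathrm{d}t + \mathrm{d}W$ (so that $p^h_t = N(0, e^{-t}s_0^2 + 1 - e^{-t})$), the helper names, and all numeric values. It is not the paper's learned process, only a check that running the backward drift from time $T$ down to $T_0$ recovers the marginal at $T_0$.

```python
import numpy as np


def marginal_var(t, s0_sq):
    # Variance of the forward OU marginal p_t^h when P_h = N(0, s0_sq).
    return np.exp(-t) * s0_sq + (1.0 - np.exp(-t))


def simulate_backward(s0_sq=4.0, T=5.0, T0=0.05, n_steps=1000, n_paths=40000, seed=0):
    rng = np.random.default_rng(seed)
    dt = (T - T0) / n_steps
    # The ground-truth backward process starts from the forward marginal at time T.
    h = rng.normal(0.0, np.sqrt(marginal_var(T, s0_sq)), size=n_paths)
    for k in range(n_steps):
        t = k * dt
        score = -h / marginal_var(T - t, s0_sq)  # grad log p^h_{T-t}(h) for a Gaussian
        h = h + (0.5 * h + score) * dt + np.sqrt(dt) * rng.normal(size=n_paths)
    return h


samples = simulate_backward()
print("empirical variance:", samples.var())
print("target variance   :", marginal_var(0.05, 4.0))
```

In this toy setting the two printed variances should roughly agree; replacing `score` with an imperfect estimate is what introduces the latent score matching error bounded in the proof.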

Proof Sketch of (iii).

We derive item (iii) by solving the orthogonal backward process of the diffusion model.

Next, we present the auxiliary theoretical results in Section F.3.1 to prepare our main proof of Corollary 3.1.2. Then we give the detailed proof of Corollary 3.1.2 in Section F.3.2.

F.3.1 Auxiliary Lemmas

Here we include a few auxiliary lemmas from (Chen et al., 2023a) without proofs. Recall the definition of the Lipschitz norm: for a given function $f$, $\|f(\cdot)\|_{\rm Lip} = \sup_{x \neq y} \left( \|f(x) - f(y)\|_{2} / \|x - y\|_{2} \right)$.
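
The Lipschitz norm above can be probed numerically. The sketch below is a hedged illustration, not part of the proof: the helper `empirical_lipschitz` and the linear map $f(x) = Ax$ are hypothetical choices. It takes the largest ratio $\|f(x)-f(y)\|_2/\|x-y\|_2$ over random pairs, which is always a lower bound on the supremum, and compares it with the spectral norm of $A$, the exact Lipschitz norm of a linear map.

```python
import numpy as np


def empirical_lipschitz(f, dim, n_pairs=5000, seed=0):
    """Largest observed ratio ||f(x) - f(y)||_2 / ||x - y||_2 over random pairs."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_pairs, dim))
    y = rng.normal(size=(n_pairs, dim))
    num = np.linalg.norm(f(x) - f(y), axis=1)
    den = np.linalg.norm(x - y, axis=1)
    return float(np.max(num / den))


A = np.array([[2.0, 0.0], [0.0, 0.5]])    # hypothetical linear map f(x) = A x
estimate = empirical_lipschitz(lambda z: z @ A.T, dim=2)
exact = float(np.linalg.norm(A, 2))       # spectral norm of A
print("sampled lower bound:", estimate)
print("exact Lipschitz norm:", exact)
```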

Lemma F.7 (Lemma 3 of (Chen et al., 2023a)).

Assume that the following holds

\begin{align*}
\mathbb{E}_{h \sim P_{h}} \|\nabla \log p_{h}(h)\|_{2}^{2} \leq C_{sh}, \quad \lambda_{\min}\left(\mathbb{E}_{h \sim P_{h}}[h h^{\top}]\right) \geq c_{0}, \quad \mathbb{E}_{h \sim P_{h}} \|h\|_{2}^{2} \leq C_{h},
\end{align*}

where $\lambda_{\min}$ denotes the smallest eigenvalue. We denote

\begin{align*}
\bar{\mathbb{E}}[\phi(\cdot,t)] = \int_{T_{0}}^{T} \frac{1}{\sigma^{2}(t)} \mathbb{E}_{x \sim P_{t}}[\phi(\cdot,t)] \, \mathrm{d}t.
\end{align*}

We set $T_{0} \leq \min\{2\log(d_{0}/C_{sh}), 1, 2\log(c_{0}), c_{0}\}$ and $T \geq \max\{2\log(C_{h}/d_{0}), 1\}$. Suppose we have

\begin{align*}
\bar{\mathbb{E}} \|W_{B} f(W_{B}^{\top} x, t) - B q(B^{\top} x, t)\|_{2}^{2} \leq \epsilon.
\end{align*}

Then we have

\begin{align*}
\|W_{B} W_{B}^{\top} - B B^{\top}\|_{\rm F}^{2} = \mathcal{O}(\epsilon T_{0}/c_{0}),
\end{align*}

and there exists an orthogonal matrix $U \in \mathbb{R}^{d_{0} \times d_{0}}$ such that

\begin{align*}
& \bar{\mathbb{E}} \|U^{\top} f(Uh, t) - q(h, t)\|_{2}^{2} \\
& = \epsilon \cdot \mathcal{O}\left( 1 + \frac{T_{0}}{c_{0}} \left[ (T - \log T_{0}) d_{0} \cdot \max_{t} \|f(\cdot,t)\|_{\rm Lip}^{2} + C_{sh} \right] + \frac{\max_{t} \|f(\cdot,t)\|_{\rm Lip}^{2} \cdot C_{h}}{c_{0}} \right).
\end{align*}
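
The role of the orthogonal matrix $U$ in Lemma F.7 can be illustrated numerically: the subspace error $\|W_B W_B^\top - BB^\top\|_{\rm F}^2$ depends only on the column spans, so it cannot distinguish $W_B$ from $W_B U$. The sketch below is a minimal illustration under assumed values; the dimensions, the perturbation size, and the helper `subspace_error` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d0 = 8, 3  # hypothetical ambient and latent dimensions

# Orthonormal bases: B spans the "true" subspace, W_B a slightly perturbed estimate.
B, _ = np.linalg.qr(rng.normal(size=(D, d0)))
W_B, _ = np.linalg.qr(B + 0.05 * rng.normal(size=(D, d0)))


def subspace_error(W, B):
    # Squared Frobenius distance between the two orthogonal projectors.
    return float(np.linalg.norm(W @ W.T - B @ B.T, "fro") ** 2)


# Rotating the basis inside its own span leaves the projector unchanged.
U, _ = np.linalg.qr(rng.normal(size=(d0, d0)))
err = subspace_error(W_B, B)
err_rotated = subspace_error(W_B @ U, B)
print("subspace error          :", err)
print("after rotating W_B by U :", err_rotated)
```

Because the two errors coincide, a small subspace error can only pin down the latent score up to such a rotation, which is why the second conclusion of the lemma is stated for $U^\top f(Uh, t)$.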
Lemma F.8 (Lemma 4 of (Chen et al., 2023a)).

Assume that $P_{h}$ is sub-Gaussian, $f(h,t)$ and $\nabla \log p_{t}^{h}(h)$ are Lipschitz in both $h$ and $t$. Assume we have the latent score matching error bound

\begin{align*}
\int_{T_{0}}^{T} \mathbb{E}_{h \sim P_{t}^{h}} \left\| \widetilde{s}_{U,f}^{h}(h_{t}, t) - \nabla \log p_{t}^{h}(h_{t}) \right\|_{2}^{2} \, \mathrm{d}t \leq \epsilon_{\rm latent} (T - T_{0}).
\end{align*}

Then we have the following latent distribution estimation error for the undiscretized backward SDE

\begin{align*}
\operatorname{TV}\left( P_{T_{0}}^{h}, \widehat{P}_{T_{0}}^{h} \right) \lesssim \sqrt{\epsilon_{\rm latent} (T - T_{0})} + \sqrt{\mathrm{KL}\left( P_{h} \,\|\, N(0, I_{d_{0}}) \right)} \cdot \exp(-T).
\end{align*}

Furthermore, we have the following latent distribution estimation error for the discretized backward SDE

\begin{align*}
\operatorname{TV}\left( P_{T_{0}}^{h}, \widehat{P}_{T_{0}}^{h,\mathrm{dis}} \right) \lesssim \sqrt{\epsilon_{\rm latent} (T - T_{0})} + \sqrt{\mathrm{KL}\left( P_{h} \,\|\, N(0, I_{d_{0}}) \right)} \cdot \exp(-T) + \sqrt{\epsilon_{\rm dis} (T - T_{0})},
\end{align*}

where

\begin{align*}
\epsilon_{\rm dis} = & \left( \frac{\max_{h} \|f(h,\cdot)\|_{\rm Lip}}{\sigma(T_{0})} + \frac{\max_{h,t} \|f(h,t)\|_{2}}{T_{0}^{2}} \right)^{2} \eta^{2} \\
& + \left( \frac{\max_{t} \|f(\cdot,t)\|_{\rm Lip}}{\sigma(T_{0})} \right)^{2} \eta^{2} \max\left\{ \mathbb{E}\|h_{0}\|^{2}, d_{0} \right\} + \eta d_{0},
\end{align*}

and $\eta$ is the step size in the backward process.
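
To show how the three contributions in the discretized bound of Lemma F.8 behave as $T$ varies, the sketch below evaluates the right-hand side with the constant hidden in $\lesssim$ set to 1. This is an assumption made purely for illustration, as are the function name `tv_bound` and every numeric input.

```python
import math


def tv_bound(eps_latent, eps_dis, kl_to_gaussian, T, T0):
    """Right-hand side of the discretized TV bound, hidden constant set to 1."""
    score_term = math.sqrt(eps_latent * (T - T0))
    mixing_term = math.sqrt(kl_to_gaussian) * math.exp(-T)
    dis_term = math.sqrt(eps_dis * (T - T0))
    return score_term + mixing_term + dis_term


# Hypothetical error levels, for illustration only.
for T in (1.0, 5.0, 10.0):
    value = tv_bound(eps_latent=1e-4, eps_dis=1e-5, kl_to_gaussian=10.0, T=T, T0=0.01)
    print(f"T = {T:4.1f}  bound = {value:.4f}")
```

The mixing term decays exponentially in $T$, while the score and discretization terms grow like $\sqrt{T - T_0}$, which is consistent with choosing $T$ of order $\log n$ in the main proof.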

Lemma F.9 (Lemma 6 of (Chen et al., 2023a)).

Consider the following discretized SDE with step size $\mu$ satisfying $T - T_{0} = K_{T}\mu$

\begin{align*}
\mathrm{d}y_{t} = \left[ \frac{1}{2} - \frac{1}{\sigma(T - k\mu)} \right] y_{k\mu} \, \mathrm{d}t + \mathrm{d}B_{t}, \quad \text{for } t \in [k\mu, (k+1)\mu),
\end{align*}

where $Y_{0} \sim \mathrm{N}(0, I)$. Then when $T > 1$ and $T_{0} + \mu \leq 1$, we have $Y_{T-T_{0}} \sim \mathrm{N}(0, \sigma^{2} I)$ with $\sigma^{2} \leq e(T_{0} + \mu)$.
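
Since the discretized SDE in Lemma F.9 is linear with a drift frozen on each step, the per-coordinate variance obeys an exact recursion, which the sketch below iterates. It assumes the noise schedule $\sigma(t) = 1 - e^{-t}$ (an assumption here, since the schedule is not restated in this section), and the function names and numeric values are hypothetical.

```python
import math


def sigma(t):
    # Assumed noise schedule sigma(t) = 1 - exp(-t).
    return 1.0 - math.exp(-t)


def terminal_variance(T, T0, mu):
    """Propagate Var(y) through the discretized linear SDE, starting from Var(Y_0) = 1."""
    n_steps = round((T - T0) / mu)  # K_T, assuming T - T0 is a multiple of mu
    var = 1.0
    for k in range(n_steps):
        drift = 0.5 - 1.0 / sigma(T - k * mu)
        # One step: y_{(k+1)mu} = (1 + mu * drift) * y_{k mu} + N(0, mu)
        var = (1.0 + mu * drift) ** 2 * var + mu
    return var


T, T0, mu = 5.0, 0.1, 0.01  # hypothetical values with T > 1 and T0 + mu <= 1
var = terminal_variance(T, T0, mu)
print("terminal variance:", var)
print("lemma upper bound:", math.e * (T0 + mu))
```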

Lemma F.10 (Lemma 10 in (Chen et al., 2023a)).

Assume that $\nabla \log p_{h}(h)$ is $L_{h}$-Lipschitz. Then we have $\mathbb{E}_{h \sim P_{h}} \|\nabla \log p_{h}(h)\|_{2}^{2} \leq d_{0} L_{h}$.
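
A quick Monte Carlo check of Lemma F.10 in a case where everything is explicit: for the hypothetical choice $P_h = N(0, s^2 I_{d_0})$, the score is $-h/s^2$, which is $L_h$-Lipschitz with $L_h = 1/s^2$, and the expected squared score norm equals $d_0/s^2 = d_0 L_h$, so the bound holds with equality. The dimension, variance, and sample size below are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d0, s_sq = 4, 0.5   # hypothetical latent dimension and variance
L_h = 1.0 / s_sq    # score(h) = -h / s_sq is (1 / s_sq)-Lipschitz

h = rng.normal(0.0, np.sqrt(s_sq), size=(200000, d0))
score = -h / s_sq
second_moment = float(np.mean(np.sum(score ** 2, axis=1)))
print("Monte Carlo E||score||^2:", second_moment)
print("d0 * L_h                :", d0 * L_h)
```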

F.3.2 Main Proof of Corollary 3.1.2
Proof.

Recall

ξ(n,ϵ,L):=1n1/2TT02(1/ϵ)2L+1T0Tϵ2+1n.assign𝜉𝑛italic-ϵ𝐿1superscript𝑛12𝑇subscript𝑇0superscript2superscript1italic-ϵ2𝐿1subscript𝑇0𝑇superscriptitalic-ϵ21𝑛\displaystyle\xi(n,\epsilon,L):=\frac{1}{n^{1/2}}\frac{T}{T_{0}}\cdot 2^{(1/% \epsilon)^{2L}}+\frac{1}{T_{0}T}\epsilon^{2}+\frac{1}{n}.italic_ξ ( italic_n , italic_ϵ , italic_L ) := divide start_ARG 1 end_ARG start_ARG italic_n start_POSTSUPERSCRIPT 1 / 2 end_POSTSUPERSCRIPT end_ARG divide start_ARG italic_T end_ARG start_ARG italic_T start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT end_ARG ⋅ 2 start_POSTSUPERSCRIPT ( 1 / italic_ϵ ) start_POSTSUPERSCRIPT 2 italic_L end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT + divide start_ARG 1 end_ARG start_ARG italic_T start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT italic_T end_ARG italic_ϵ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT + divide start_ARG 1 end_ARG start_ARG italic_n end_ARG .
  • Proof of (i). In Lemma F.7, we replace $\epsilon$ with $\epsilon(T-T_0)$ and set $C_{sh} = L_h d_0$ by Lemma F.10. This gives

    $$\left\|W_B W_B^\top - BB^\top\right\|_F^2 = \mathcal{O}\left(\frac{T_0\,\xi(n,\epsilon,L)}{c_0}\right).$$

    Substituting the score estimation error from Corollary 3.1.1 and $T = \mathcal{O}(\log n)$ into the bound above, we deduce

    $$\left\|W_B W_B^\top - BB^\top\right\|_F^2 = \widetilde{\mathcal{O}}\left(\frac{1}{c_0}\, n^{-\zeta(n)} \cdot \log^3 n\right),$$

    where $\zeta_1(n) = 1/2 - 9L^{2L} \cdot n^{(2L^{2L+1}/(37L \cdot \log n))}/(37 \log n)$.

    We note that $\log n$ is large enough for $T$ to satisfy $T \geq \max\{\log(C_h/d_0 + 1), 1\}$, where $C_h \geq \mathbb{E}_{h \sim P_h}\left\|h\right\|_2^2$.

  • Proof of (ii). Lemma F.7 and Lemma F.10 imply that

    $$\bar{\mathbb{E}}\left\|U^\top f(Uh,t) - q(h,t)\right\|_2^2 = \mathcal{O}\left(\epsilon_{\text{latent}}(T-T_0)\right),$$

    where

    $$\epsilon_{\text{latent}} = \epsilon \cdot \mathcal{O}\left(\frac{T_0}{c_0}\left[(T - \log T_0)\, d_0 \cdot L_{s_+}^2 + d_0 L_h\right] + \frac{L_{s_+}^2 \cdot C_h}{c_0}\right).$$

    By direct algebraic calculation, we get

    \begin{align*}
    \bar{\mathbb{E}}\left\|U^\top f(Uh,t) - q(h,t)\right\|_2^2
    &= \int_{T_0}^{T} \mathbb{E}_{h \sim P_t^h}\left\|\frac{U^\top f(Uh,t) - h}{\sigma(t)} - \nabla \log p_t^h(h)\right\|_2^2 \,\mathrm{d}t \\
    &\leq \epsilon_{\text{latent}}(T-T_0).
    \end{align*}

    With $\epsilon_{\text{latent}}$ and Lemma F.8, we obtain

    \begin{align*}
    \mathsf{TV}\left(P_{T_0}^h, (W_B U)^\top_\sharp \widehat{P}_{T_0}^{\rm dis}\right)
    &\lesssim \sqrt{\epsilon_{\text{latent}}(T-T_0)} + \sqrt{\mathrm{KL}\left(P_h \,\|\, N(0, I_{d_0})\right)}\exp(-T) + \sqrt{\epsilon_{\text{dis}}(T-T_0)} \\
    &= \widetilde{\mathcal{O}}\left(\frac{1}{\sqrt{c_0}}\sqrt{\xi(n,\epsilon,L)} + \frac{1}{n} + \mu\frac{\sqrt{d_0^2 \log d_0}}{T_0^2} + \sqrt{\mu}\sqrt{d_0}\right).
    \end{align*}

    Since we choose the time step $\mu = \mathcal{O}\left(\xi(n,\epsilon,L) \cdot T_0^2 / d_0\sqrt{\log d_0}\right)$, we obtain

    $$\mathsf{TV}\left(P_{T_0}^h, (W_B U)^\top_\sharp \widehat{P}_{T_0}^{\rm dis}\right) = \widetilde{\mathcal{O}}\left(\sqrt{\xi(n,\epsilon,L)}\right).$$

    By definition, $\widehat{P}_{T_0}^{h,{\rm dis}} = (UW_B)_\sharp^\top \widehat{P}_{T_0}^{\rm dis}$. This completes the proof of the total variation distance in (3.2).

    For the Wasserstein-2 distance $\mathsf{W}_2(P_{T_0}^h, P_h)$, we bound it using the same technique as (Chen et al., 2023b, Lemma 16). Specifically, our proof only requires a finite second moment of $P_h$, which is verified in Assumption 2.2. As a result, we have

    $$\mathsf{W}_2(P_{T_0}^h, P_h) = \mathcal{O}\left(\sqrt{d_0 T_0}\right).$$
  • Proof of (iii). We apply Lemma F.9 due to our score decomposition. With the marginal distribution at time $T - T_0$ and observing $\mu \ll T_0$, we obtain the last property.

This completes the proof. ∎

Appendix G Proofs of Section 4

Our proofs are motivated by the observation of low-rank gradient decomposition in transformer-like models (Alman and Song, 2024a; Gu et al., 2024). With our simplifications and observations made in Section 4, we utilize the fine-grained complexity results of transformer and attention (Hu et al., 2024c; Alman and Song, 2024b, 2023) and the tensor trick (Lemma D.1 and (Diao et al., 2019, 2018)) to carry out our proofs. Specifically, we approximate DiT training gradients with a series of low-rank approximations in Sections G.1.1, G.1.2 and G.1.3, and carefully match the multiplication dimensions so that the computation of $\frac{\mathrm{d} g_2}{\mathrm{d} \underline{W}}$ forms a chained low-rank approximation in Section G.2.

G.1 Auxiliary Theoretical Results for Theorem 4.1

Here we present some auxiliary theoretical results to prepare the main proof of Theorem 4.1 (Existence of Almost-Linear Time Algorithms for ADITGC).

G.1.1 Low-Rank Decomposition of DiT Gradients

We start with some definitions. Recall that $W \in \mathbb{R}^{d \times d}$, and that $\underline{W} \in \mathbb{R}^{d^2}$ denotes the vectorization of $W$ following Definition D.1.

Definition G.1.

Let $A_1, A_2 \in \mathbb{R}^{d \times L}$ be two matrices. Suppose $\mathsf{A} = A_1^\top \otimes A_2^\top \in \mathbb{R}^{L^2 \times d^2}$. Define $\mathsf{A}_{j_0} \in \mathbb{R}^{L \times d^2}$ as an $L \times d^2$ sub-block of $\mathsf{A}$. There are $L$ such sub-blocks in total.
For each $j_0 \in [L]$, define the function $u(\underline{W})_{j_0} : \mathbb{R}^{d^2} \to \mathbb{R}^L$ by $u(\underline{W})_{j_0} := \exp(\mathsf{A}_{j_0}\underline{W}) \in \mathbb{R}^L$.

Definition G.2.

Let $A_1, A_2 \in \mathbb{R}^{d \times L}$ be two matrices. Suppose $\mathsf{A} = A_1^\top \otimes A_2^\top \in \mathbb{R}^{L^2 \times d^2}$. Define $\mathsf{A}_{j_0} \in \mathbb{R}^{L \times d^2}$ as an $L \times d^2$ sub-block of $\mathsf{A}$. There are $L$ such sub-blocks in total.
For every index $j_0 \in [L]$, consider the function $\alpha(\underline{W})_{j_0} : \mathbb{R}^{d^2} \to \mathbb{R}$ defined by $\alpha(\underline{W})_{j_0} := \langle \underbrace{\exp(\mathsf{A}_{j_0}\underline{W})}_{L \times 1}, \underbrace{\mathds{1}_L}_{L \times 1} \rangle$.

Definition G.3.

Suppose that $\alpha(\underline{W})_{j_0} \in \mathbb{R}$ and $u(\underline{W})_{j_0} \in \mathbb{R}^L$ are defined as in Definitions G.2 and G.1, respectively. For a fixed $j_0 \in [L]$, consider the function $f(\underline{W})_{j_0} : \mathbb{R}^{d^2} \rightarrow \mathbb{R}^L$ defined by

$$f(\underline{W})_{j_0} := \underbrace{\alpha(\underline{W})_{j_0}^{-1}}_{\mathrm{scalar}} \underbrace{u(\underline{W})_{j_0}}_{L \times 1}.$$

Define $f(\underline{W}) \in \mathbb{R}^{L \times L}$ as the matrix where the $j_0$-th row is $(f(\underline{W})_{j_0})^\top$.
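Definitions G.1 to G.3 can be checked numerically on a toy instance. The sketch below assumes that the vectorization of Definition D.1 is row-major and that $\mathsf{A}_{j_0}$ consists of the $L$ consecutive rows of $\mathsf{A}$ starting at row $j_0 L$ (zero-indexed); under these assumptions, $f(\underline{W})$ coincides with the row-wise softmax of $A_1^\top W A_2$. The helper names (`kron`, `f_via_blocks`, `f_direct`) and the toy matrices are illustrative and not taken from the paper.

```python
import math


def transpose(X):
    return [list(col) for col in zip(*X)]


def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]


def kron(X, Y):
    # Kronecker product: entry (i*p + k, j*q + l) equals X[i][j] * Y[k][l].
    return [[X[i][j] * Y[k][l]
             for j in range(len(X[0])) for l in range(len(Y[0]))]
            for i in range(len(X)) for k in range(len(Y))]


def f_via_blocks(A1, A2, W):
    """Build f(W) block by block, following Definitions G.1 to G.3."""
    L = len(A1[0])
    A = kron(transpose(A1), transpose(A2))   # shape L^2 x d^2
    w = [x for row in W for x in row]        # row-major vec(W) (assumed)
    rows = []
    for j0 in range(L):
        block = A[j0 * L:(j0 + 1) * L]       # sub-block A_{j0}, shape L x d^2
        u = [math.exp(sum(a * x for a, x in zip(r, w))) for r in block]
        alpha = sum(u)                       # <exp(A_{j0} w), 1_L>
        rows.append([v / alpha for v in u])  # alpha^{-1} * u
    return rows


def f_direct(A1, A2, W):
    """Row-wise softmax of A1^T W A2, computed without the Kronecker product."""
    S = matmul(matmul(transpose(A1), W), A2)
    out = []
    for row in S:
        e = [math.exp(s) for s in row]
        z = sum(e)
        out.append([v / z for v in e])
    return out


# Toy instance with d = 2 and L = 3 (values are arbitrary).
A1 = [[0.1, 0.2, 0.3], [0.0, -0.1, 0.4]]
A2 = [[0.5, -0.2, 0.1], [0.3, 0.0, -0.4]]
W = [[1.0, -0.5], [0.25, 2.0]]

F_blocks = f_via_blocks(A1, A2, W)
F_direct = f_direct(A1, A2, W)
max_gap = max(abs(a - b) for ra, rb in zip(F_blocks, F_direct)
              for a, b in zip(ra, rb))
print("max entrywise gap:", max_gap)
```

The block-wise route materializes the $L^2 \times d^2$ matrix $\mathsf{A}$ and is meant only as a check of the definitions; it is not an efficient way to compute $f(\underline{W})$.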

Definition G.4.

For every $i_0 \in [d]$, define the function $h(\underline{W}_{OV})_{i_0} : \mathbb{R}^{d^2} \rightarrow \mathbb{R}^L$ by

$$h(\underline{W}_{OV})_{i_0} := \underbrace{A_3^\top}_{L \times d} \underbrace{(W_{OV}^\top)_{*,i_0}}_{d \times 1}.$$

Here, $W_{OV} \in \mathbb{R}^{d \times d}$ denotes the matrix representation of $\underline{W}_{OV} \in \mathbb{R}^{d^2}$, and $(W_{OV})^\top_{*,i_0}$ represents the $i_0$-th column of $W_{OV}^\top$. Define $h(\underline{W}_{OV}) \in \mathbb{R}^{L \times d}$ as the matrix where the $i_0$-th column is $h(\underline{W}_{OV})_{i_0}$.

Definition G.5.

For each $j_0 \in [L]$, we denote $f(\underline{W})_{j_0} \in \mathbb{R}^L$ as the normalized vector defined by Definition G.3. For each $i_0 \in [d]$, $h(\underline{W}_{OV})_{i_0}$ is defined as per Definition G.4. For every pair $(j_0, i_0) \in [L] \times [d]$, define the function $c(\underline{W})_{j_0,i_0} : \mathbb{R}^{d^2} \times \mathbb{R}^{d^2} \rightarrow \mathbb{R}$ by

$$c(\underline{W})_{j_0,i_0} := \langle f(\underline{W})_{j_0}, h(\underline{W}_{OV})_{i_0} \rangle - Y^\top_{j_0,i_0},$$

where $Y^\top_{j_0,i_0}$ is the element at the $(j_0, i_0)$ position of the matrix $Y^\top \in \mathbb{R}^{L \times d}$. $c(\cdot)$ has matrix form

$$\underbrace{c(\underline{W})}_{L \times d} = \underbrace{f(\underline{W})}_{L \times L} \underbrace{h(\underline{W}_{OV})}_{L \times d} - \underbrace{Y^\top}_{L \times d}.$$

With the tensor trick (Section D.3), we compute the gradient $\frac{\mathrm{d} g_2}{\mathrm{d} \underline{W}}$ of the DiT loss as follows:

dg2dW¯=ddW¯[12j0=1Li0=1dcj0,i02(W¯)].derivative¯𝑊subscript𝑔2derivative¯𝑊delimited-[]12superscriptsubscriptsubscript𝑗01𝐿superscriptsubscriptsubscript𝑖01𝑑superscriptsubscript𝑐subscript𝑗0subscript𝑖02¯𝑊\displaystyle\derivative{g_{2}}{\underline{W}}=\derivative{\underline{W}}\left% [{\frac{1}{2}}\sum_{j_{0}=1}^{L}\sum_{i_{0}=1}^{d}c_{j_{0},i_{0}}^{2}(% \underline{W})\right].divide start_ARG roman_d start_ARG italic_g start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT end_ARG end_ARG start_ARG roman_d start_ARG under¯ start_ARG italic_W end_ARG end_ARG end_ARG = start_DIFFOP divide start_ARG roman_d end_ARG start_ARG roman_d start_ARG under¯ start_ARG italic_W end_ARG end_ARG end_ARG end_DIFFOP [ divide start_ARG 1 end_ARG start_ARG 2 end_ARG ∑ start_POSTSUBSCRIPT italic_j start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_i start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT italic_c start_POSTSUBSCRIPT italic_j start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT , italic_i start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ( under¯ start_ARG italic_W end_ARG ) ] . (G.1)

(G.1) decomposes $\frac{\mathrm{d} g_2}{\mathrm{d}\underline{W}}$ into a sum of terms, each of which is easy to handle on its own. Thus, we arrive at the following lemma. Let $Z[i,\cdot]$ and $Z[\cdot,j]$ be the $i$-th row and $j$-th column of matrix $Z$.

Lemma G.1 (Low-Rank Decomposition of DiT Gradient).

Let matrices $A_1, A_2, A_3, W, W_{OV}, Y$ and loss function $\mathcal{L}$ follow Definition 4.1, and let $\mathsf{A} \coloneqq A_1^{\top}\otimes A_2^{\top}$. It holds that

$$\frac{\mathrm{d} g_2}{\mathrm{d}\underline{W}} = \sum_{j_0=1}^{L}\sum_{i_0=1}^{d} c(\underline{W})_{j_0,i_0}\, \mathsf{A}_{j_0}^{\top} \underbrace{\Big(\overbrace{\mathrm{diag}\left(f(\underline{W})_{j_0}\right)}^{(II)} - \overbrace{f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top}}^{(III)}\Big)}_{(I)}\, h(\underline{W}_{OV})_{i_0}. \tag{G.2}$$
Proof.

Let $Z[i,\cdot]$ and $Z[\cdot,j]$ be the $i$-th row and $j$-th column of matrix $Z$.

By the DiT loss in Definition 4.1, for each coordinate $\underline{W}_i$ of $\underline{W}$ we have

\begin{align*}
\frac{\mathrm{d} g_2}{\mathrm{d}\underline{W}_i}
&= \frac{1}{2}\sum_{j_0=1}^{L}\sum_{i_0=1}^{d} \frac{\mathrm{d}}{\mathrm{d}\underline{W}_i}\, c^{2}_{j_0,i_0}(\underline{W}) \\
&= \sum_{j_0=1}^{L}\sum_{i_0=1}^{d} c(\underline{W})_{j_0,i_0}\cdot \frac{\mathrm{d}\, c(\underline{W})_{j_0,i_0}}{\mathrm{d}\underline{W}_i} \\
&= \sum_{j_0=1}^{L}\sum_{i_0=1}^{d} c(\underline{W})_{j_0,i_0}\cdot \frac{\mathrm{d} \left\langle f(\underline{W})_{j_0}, h(\underline{W}_{OV})_{i_0}\right\rangle}{\mathrm{d}\underline{W}_i} && \text{(By Definition G.5)} \\
&= \sum_{j_0=1}^{L}\sum_{i_0=1}^{d} c(\underline{W})_{j_0,i_0}\cdot \left\langle \frac{\mathrm{d} f(\underline{W})_{j_0}}{\mathrm{d}\underline{W}_i}, h(\underline{W}_{OV})_{i_0}\right\rangle \\
&= \sum_{j_0=1}^{L}\sum_{i_0=1}^{d} c(\underline{W})_{j_0,i_0}\cdot \left\langle \frac{\mathrm{d} \left(\alpha(\underline{W})_{j_0}^{-1} u(\underline{W})_{j_0}\right)}{\mathrm{d}\underline{W}_i}, h(\underline{W}_{OV})_{i_0}\right\rangle && \text{(By Definition G.3)} \\
&= \sum_{j_0=1}^{L}\sum_{i_0=1}^{d} c(\underline{W})_{j_0,i_0}\cdot \left\langle \alpha(\underline{W})_{j_0}^{-1}\cdot \frac{\mathrm{d} u(\underline{W})_{j_0}}{\mathrm{d}\underline{W}_i} + \frac{\mathrm{d}\, \alpha(\underline{W})_{j_0}^{-1}}{\mathrm{d}\underline{W}_i}\cdot u(\underline{W})_{j_0}, h(\underline{W}_{OV})_{i_0}\right\rangle \\
&= \sum_{j_0=1}^{L}\sum_{i_0=1}^{d} c(\underline{W})_{j_0,i_0}\cdot \left\langle \alpha(\underline{W})_{j_0}^{-1}\cdot \frac{\mathrm{d} u(\underline{W})_{j_0}}{\mathrm{d}\underline{W}_i} - \alpha(\underline{W})_{j_0}^{-2}\, \frac{\mathrm{d}\, \alpha(\underline{W})_{j_0}}{\mathrm{d}\underline{W}_i}\cdot u(\underline{W})_{j_0}, h(\underline{W}_{OV})_{i_0}\right\rangle. && \text{(By chain rule)}
\end{align*}

For each $j_0\in[L]$, we have

$$\frac{\mathrm{d} \left(\mathsf{A}_{j_0}\underline{W}\right)}{\mathrm{d}\underline{W}_i} = \mathsf{A}_{j_0}\cdot \frac{\mathrm{d}\underline{W}}{\mathrm{d}\underline{W}_i} = \left(\mathsf{A}_{j_0}\right)[\cdot,i].$$

Therefore, for each $j_0\in[L]$, we have

\begin{align*}
\frac{\mathrm{d} u(\underline{W})_{j_0}}{\mathrm{d}\underline{W}_i}
&= \frac{\mathrm{d} \exp\left(\mathsf{A}_{j_0}\underline{W}\right)}{\mathrm{d}\underline{W}_i} && \text{(By Definition G.1)} \\
&= \exp\left(\mathsf{A}_{j_0}\underline{W}\right)\odot \frac{\mathrm{d} \left(\mathsf{A}_{j_0}\underline{W}\right)}{\mathrm{d}\underline{W}_i} && \text{(By entry-wise product rule)} \\
&= \mathsf{A}_{j_0}[\cdot,i]\odot u(\underline{W})_{j_0}. && \text{(By Definition G.1 again)}
\end{align*}

Similarly,

\begin{align*}
\frac{\mathrm{d}\, \alpha(\underline{W})_{j_0}}{\mathrm{d}\underline{W}_i}
&= \frac{\mathrm{d} \left\langle u(\underline{W})_{j_0}, \mathds{1}_L\right\rangle}{\mathrm{d}\underline{W}_i} && \text{(By Definition G.2)} \\
&= \left\langle \mathsf{A}_{j_0}[\cdot,i]\odot u(\underline{W})_{j_0}, \mathds{1}_L\right\rangle && \text{(By entry-wise product rule)} \\
&= \left\langle \mathsf{A}_{j_0}[\cdot,i], u(\underline{W})_{j_0}\right\rangle. && \text{(By Definition G.1 again)}
\end{align*}

Putting everything together and using $f(\underline{W})_{j_0} = \alpha(\underline{W})_{j_0}^{-1} u(\underline{W})_{j_0}$, we have

\begin{align*}
& \frac{\mathrm{d}\, g_2(\underline{W})_{j_0,i_0}}{\mathrm{d}\underline{W}_i} \\
={} & \left[\left\langle h(\underline{W}_{OV})_{i_0}, \mathsf{A}_{j_0}[\cdot,i]\odot f(\underline{W})_{j_0}\right\rangle - \left\langle h(\underline{W}_{OV})_{i_0}, f(\underline{W})_{j_0}\right\rangle\cdot \left\langle \mathsf{A}_{j_0}[\cdot,i], f(\underline{W})_{j_0}\right\rangle\right]\cdot c(\underline{W})_{j_0,i_0},
\end{align*}

where

\begin{align*}
& \left\langle h(\underline{W}_{OV})_{i_0}, \mathsf{A}_{j_0}[\cdot,i]\odot f(\underline{W})_{j_0}\right\rangle - \left\langle h(\underline{W}_{OV})_{i_0}, f(\underline{W})_{j_0}\right\rangle\cdot \left\langle \mathsf{A}_{j_0}[\cdot,i], f(\underline{W})_{j_0}\right\rangle \\
={} & \mathsf{A}_{j_0}[\cdot,i]^{\top}\left(\mathrm{diag}\left(f(\underline{W})_{j_0}\right) - f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top}\right) h(\underline{W}_{OV})_{i_0},
\end{align*}

which is the $i$-th entry of $\mathsf{A}_{j_0}^{\top}\left(\mathrm{diag}\left(f(\underline{W})_{j_0}\right) - f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top}\right) h(\underline{W}_{OV})_{i_0}$. Stacking these entries over $i$ and summing over $j_0$ and $i_0$ yields (G.2).

This completes the proof. ∎
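As a sanity check of (G.2), the following minimal NumPy sketch compares the closed-form gradient against central finite differences on a small random instance. It rests on assumptions not fixed in this section: `A_blocks[j0]` stands in for $\mathsf{A}_{j_0}$, `H` for $h(\underline{W}_{OV})$, and `Yt` for $Y^{\top}$, all drawn at random, with $f(\underline{W})_{j_0}$ taken to be the normalized exponential $\alpha^{-1}u$ used in the proof. All function and variable names are hypothetical.

```python
import numpy as np

def softmax_rows(A_blocks, w):
    # A_blocks has shape (L, L, D); block j0 plays the role of A_{j0}.
    logits = A_blocks @ w                  # (L, L); row j0 is A_{j0} w
    u = np.exp(logits)                     # u(W)_{j0}
    alpha = u.sum(axis=1, keepdims=True)   # alpha(W)_{j0} = <u_{j0}, 1_L>
    return u / alpha                       # f(W); row j0 is f(W)_{j0}

def loss(A_blocks, w, H, Yt):
    # g_2 = (1/2) * sum of squared entries of c = f h - Y^T
    c = softmax_rows(A_blocks, w) @ H - Yt
    return 0.5 * np.sum(c ** 2)

def grad_formula(A_blocks, w, H, Yt):
    # Direct evaluation of the double sum in (G.2).
    f = softmax_rows(A_blocks, w)
    c = f @ H - Yt
    L, d = c.shape
    g = np.zeros_like(w)
    for j0 in range(L):
        S = np.diag(f[j0]) - np.outer(f[j0], f[j0])   # term (I)
        for i0 in range(d):
            g += c[j0, i0] * (A_blocks[j0].T @ (S @ H[:, i0]))
    return g

def grad_fd(A_blocks, w, H, Yt, eps=1e-6):
    # Central finite differences, one coordinate of w at a time.
    g = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = eps
        g[i] = (loss(A_blocks, w + e, H, Yt)
                - loss(A_blocks, w - e, H, Yt)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
L, d = 4, 3
D = d * d                                   # assumed length of vectorized W
A_blocks = 0.5 * rng.normal(size=(L, L, D))
w = rng.normal(size=D)
H = rng.normal(size=(L, d))
Yt = rng.normal(size=(L, d))

err = np.max(np.abs(grad_formula(A_blocks, w, H, Yt)
                    - grad_fd(A_blocks, w, H, Yt)))
print("max abs deviation:", err)
```

The direct double loop is deliberately naive; it only confirms the formula numerically and does not reflect the almost-linear time algorithm developed afterwards.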

Observe (G.2) carefully. Within (I), the term (II) is diagonal and the term (III) is rank one, hence low-rank. This hints at an algorithmic speedup through low-rank approximation: if we approximate the other parts with low-rank approximations and carefully match the multiplication dimensions, we might formulate the computation of $\frac{\mathrm{d} g_2}{\mathrm{d}\underline{W}}$ as a chain of low-rank approximations.

Surprisingly, this approach allows us to compute (G.2) in almost-linear time. To proceed, we further decompose (G.2) according to the chain rule in the next lemma, and then carry out the approximation term by term.

To facilitate our proof, we introduce the following notation.

Definition G.6 ($q(\cdot)$).

Define $c(\underline{W}) \in \mathbb{R}^{L \times d}$ as specified in Definition G.5 and $h(\underline{W}_{OV}) \in \mathbb{R}^{L \times d}$ as described in Definition G.4. Define $q(\underline{W}) \in \mathbb{R}^{L \times L}$ by

$$
q(\underline{W}) := \underbrace{c(\underline{W})}_{L \times d}\ \underbrace{h(\underline{W}_{OV})^{\top}}_{d \times L}.
$$

In addition, $q(\underline{W})_{j_0}^{\top}$ denotes the $j_0$-th row of $q(\underline{W})$, so that $q(\underline{W})_{j_0}$ is an $L \times 1$ vector.

Definition G.7 ($p(\cdot)$, $p_1(\cdot)$, $p_2(\cdot)$).

For each index $j_0 \in [L]$, we define $p(\underline{W})_{j_0} \in \mathbb{R}^{L}$ as follows:

$$
p(\underline{W})_{j_0} := \left( \operatorname{diag}\left(f(\underline{W})_{j_0}\right) - f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top} \right) q(\underline{W})_{j_0}.
$$

We define $p(\underline{W}) \in \mathbb{R}^{L \times L}$ such that $p(\underline{W})_{j_0}^{\top}$ forms the $j_0$-th row of $p(\underline{W})$. In addition, for every index $j_0 \in [L]$, we define $p_1(\underline{W})_{j_0}, p_2(\underline{W})_{j_0} \in \mathbb{R}^{L}$ as

$$
p_1(\underline{W})_{j_0} \coloneqq \operatorname{diag}\left(f(\underline{W})_{j_0}\right) q(\underline{W})_{j_0}, \quad p_2(\underline{W})_{j_0} \coloneqq f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top} q(\underline{W})_{j_0},
$$

such that $p(\underline{W}) = p_1(\underline{W}) - p_2(\underline{W})$.
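The split $p = p_1 - p_2$ can be checked numerically on a small instance. The sketch below is illustrative only: the random matrices `f` and `q` are stand-ins for $f(\underline{W})$ and $q(\underline{W})$, and all variable names are assumptions rather than notation from the paper. It shows that the diagonal part reduces to an entry-wise product and the rank-one part to a row-wise inner product followed by a scaling, so no $L \times L$ matrix per row is ever formed.

```python
import numpy as np

rng = np.random.default_rng(0)
L = 6

# Stand-in for f(W): a row-stochastic (softmax) L x L matrix.
scores = rng.normal(size=(L, L))
f = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
# Stand-in for q(W): an arbitrary L x L matrix.
q = rng.normal(size=(L, L))

# Row j of p1 is diag(f_j) q_j, i.e. the entry-wise product f_j * q_j.
p1 = f * q
# Row j of p2 is f_j (f_j^T q_j): a scalar inner product times f_j.
p2 = f * np.sum(f * q, axis=1, keepdims=True)
p_fast = p1 - p2

# Reference: build (diag(f_j) - f_j f_j^T) explicitly for every row.
p_ref = np.stack(
    [(np.diag(f[j]) - np.outer(f[j], f[j])) @ q[j] for j in range(L)]
)
```

The fast path costs $O(L^2)$ arithmetic given dense `f` and `q`, whereas the reference costs $O(L^3)$; the lemmas that follow replace the dense `f` and `q` with low-rank factors.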

The function $p(\cdot)$ allows us to express $\frac{\mathrm{d} g_2}{\mathrm{d} \underline{W}}$ in a neat form:

Lemma G.2.

Define the functions $f(\underline{W}) \in \mathbb{R}^{L \times L}$, $c(\underline{W}) \in \mathbb{R}^{L \times d}$, $h(\underline{W}_{OV}) \in \mathbb{R}^{L \times d}$, $q(\underline{W}) \in \mathbb{R}^{L \times L}$, and $p(\underline{W}) \in \mathbb{R}^{L \times L}$ as specified in Definitions G.3, G.5, G.4, G.6 and G.7, respectively. Let $A_1, A_2 \in \mathbb{R}^{d \times L}$ be two given matrices, and define $\mathsf{A} = A_1^{\top} \otimes A_2^{\top}$. Define $g_2$ according to (O1), and let $g_2(\underline{W})_{j_0,i_0}$ be as described in (G.1). Then it holds that

$$
\frac{\mathrm{d} g_2}{\mathrm{d} \underline{W}} = \operatorname{vec}\left( A_1\, p(\underline{W})\, A_2^{\top} \right). \tag{G.3}
$$
Proof.

By definitions, (G.1) gives

$$
\begin{aligned}
& \frac{\mathrm{d}(g_2)_{j_0,i_0}}{\mathrm{d}\underline{W}_{i_0}} \\
={} & c_{j_0,i_0} \cdot \Big( \underbrace{\left\langle f(\underline{W})_{j_0} \odot \mathsf{A}_{j_0,i_0},\ h(\underline{W}_{OV})_{i_0} \right\rangle}_{= \mathsf{A}_{j_0,i}^{\top} \operatorname{diag}(f(\underline{W})_{j_0})\, h(\underline{W}_{OV})_{i_0}} - \underbrace{\left\langle f(\underline{W})_{j_0},\ h(\underline{W}_{OV})_{i_0} \right\rangle \cdot \left\langle f(\underline{W})_{j_0},\ \mathsf{A}_{j_0,i_0} \right\rangle}_{= \mathsf{A}_{j_0,i}^{\top} f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top} h(\underline{W}_{OV})_{i_0}} \Big).
\end{aligned}
\tag{G.4}
$$

(By $\langle a \odot b, c \rangle = a^{\top} \operatorname{diag}(b)\, c$ for $a, b, c \in \mathbb{R}^{L}$.)

Therefore, (G.4) becomes

$$
\begin{aligned}
\frac{\mathrm{d}(g_2)_{j_0,i_0}}{\mathrm{d}\underline{W}_{i_0}} ={} & c_{j_0,i_0} \cdot \left( \mathsf{A}_{j_0,i}^{\top} \operatorname{diag}(f(\underline{W})_{j_0})\, h(\underline{W}_{OV})_{i_0} - \mathsf{A}_{j_0,i}^{\top} f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top} h(\underline{W}_{OV})_{i_0} \right) \\
={} & c_{j_0,i_0} \cdot \mathsf{A}_{j_0,i}^{\top} \left( \operatorname{diag}(f(\underline{W})_{j_0}) - f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top} \right) h(\underline{W}_{OV})_{i_0}.
\end{aligned}
\tag{G.5}
$$

Then, by the definitions of $q(\cdot)$ and $p(\cdot)$, we complete the proof. ∎
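To see why the form $\operatorname{vec}(A_1\, p(\underline{W})\, A_2^{\top})$ is computationally attractive, the following sketch compares it with an explicit Kronecker-product computation. It assumes a row-major $\operatorname{vec}(\cdot)$ convention and uses a random stand-in for $p(\underline{W})$; both are assumptions made here for illustration, and the identity checked is the standard one $(B \otimes C)\operatorname{vec}(X) = \operatorname{vec}(B X C^{\top})$ under that convention.

```python
import numpy as np

rng = np.random.default_rng(0)
d, L = 3, 5

A1 = rng.normal(size=(d, L))
A2 = rng.normal(size=(d, L))
p = rng.normal(size=(L, L))  # stand-in for p(W)

# Explicit Kronecker matrix A = A1^T (x) A2^T, of shape (L*L, d*d).
A = np.kron(A1.T, A2.T)

# Route 1: multiply by the (L^2 x d^2) matrix directly.
grad_kron = A.T @ p.reshape(-1)  # row-major vec of p

# Route 2: two small matrix products, never forming A.
grad_fast = (A1 @ p @ A2.T).reshape(-1)
```

Route 1 touches all $L^2 d^2$ entries of $\mathsf{A}$, while Route 2 needs only the products $A_1 p$ and $(A_1 p) A_2^{\top}$. Once $p(\underline{W})$ is itself available in low-rank form, these products become cheaper still.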

G.1.2 Low-Rank Approximations of Building Blocks I

The definitions of $p$, $p_1$, $p_2$, and Lemma G.2 show that the DiT training gradient $\frac{\mathrm{d} g_2}{\mathrm{d} \underline{W}}$ involves entry-wise products of $f$, $q$, and $c$. Therefore, if we approximate these with inner-dimension-matched low-rank approximations, computing $\frac{\mathrm{d} g_2}{\mathrm{d} \underline{W}}$ itself becomes a low-rank approximation. In the following sections, we present low-rank approximations for $f$, $q$, and $c$.

Lemma G.3 (Approximate $f(\cdot)$, Modified from (Alman and Song, 2023)).

Let $\Gamma = o(\sqrt{\log L})$ and $k_1 = L^{o(1)}$. Let $A_1, A_2 \in \mathbb{R}^{d \times L}$, $W \in \mathbb{R}^{d \times d}$, and let $f(\underline{W}) = D^{-1} \exp(A_1^{\top} W A_2)$ with $D = \operatorname{diag}\left( \exp\left(A_1^{\top} W A_2\right) \mathds{1}_L \right)$ follow Definitions G.1, G.2, G.5 and G.3. If $\max\left( \|A_1^{\top} W\|_{\max},\ \|A_2\|_{\max} \right) \leq \Gamma$, then there exist two matrices $U_1, V_1 \in \mathbb{R}^{L \times k_1}$ such that $\|U_1 V_1^{\top} - f(\underline{W})\|_{\max} \leq \epsilon / \mathrm{poly}(L)$. In addition, it takes $L^{1+o(1)}$ time to construct $U_1$ and $V_1$.

Proof.

By (Alman and Song, 2023, Theorem 3), we complete the proof. ∎
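The flavor of Lemma G.3 can be illustrated with a toy construction. The sketch below is not the construction of (Alman and Song, 2023): it substitutes a plain truncated Taylor series for the polynomial approximation of $\exp$, and the sizes `L`, `d`, `degree` and the helper `taylor_features` are hypothetical choices made only for this example. The idea it demonstrates is that when the entries of $A_1^{\top} W A_2$ are bounded, $\exp$ of each entry is close to a low-degree polynomial in an inner product, and every monomial $(x^{\top} y)^t$ factors as $\langle x^{\otimes t}, y^{\otimes t} \rangle$, which yields explicit factors $U_1, V_1$.

```python
import math
import numpy as np

def taylor_features(X, degree):
    """Map each row x of X to [x^{(x)t} / sqrt(t!)] for t = 0..degree."""
    n = X.shape[0]
    power = np.ones((n, 1))  # x^{(x)0}
    blocks = [power]
    for t in range(1, degree + 1):
        # Row-wise Kronecker product: x^{(x)t} = x^{(x)(t-1)} (x) x.
        power = np.einsum("ni,nj->nij", power, X).reshape(n, -1)
        blocks.append(power / math.sqrt(math.factorial(t)))
    return np.hstack(blocks)

rng = np.random.default_rng(0)
L, d, degree = 256, 2, 6  # hypothetical sizes for illustration

A1 = rng.uniform(-0.5, 0.5, size=(d, L))
A2 = rng.uniform(-0.5, 0.5, size=(d, L))
W = rng.uniform(-0.5, 0.5, size=(d, d))

# Bounded entries play the role of the Gamma condition in the lemma.
U = taylor_features(A1.T @ W, degree)  # L x k1
V1 = taylor_features(A2.T, degree)     # L x k1

# Normalization D^{-1}: row sums computed from the factors, in O(L k1) time.
row_sums = U @ (V1.T @ np.ones(L))
U1 = U / row_sums[:, None]

# Dense reference f(W) = D^{-1} exp(A1^T W A2), for checking only.
E = np.exp(A1.T @ W @ A2)
f_exact = E / E.sum(axis=1, keepdims=True)
max_err = np.abs(U1 @ V1.T - f_exact).max()
```

Here the rank is $k_1 = \sum_{t=0}^{6} 2^t = 127 < L$, and the factors are built without forming the $L \times L$ matrix; the dense `f_exact` appears only to measure the error.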

Lemma G.4 (Approximate $c(\cdot)$).

Assume all numerical values are in $O(\log L)$ bits. Let $d = O(\log L)$ and let $c(\underline{W}) \in \mathbb{R}^{L \times d}$ follow Definition G.5. There exist two matrices $U_1, V_1 \in \mathbb{R}^{L \times k_1}$ such that $\left\| U_1 V_1^{\top} h(W_{OV}) - Y^{\top} - c(\underline{W}) \right\|_{\max} \leq \epsilon / \mathrm{poly}(L)$.

Proof of Lemma G.4.
$$
\begin{aligned}
\left\| U_1 V_1^{\top} h(W_{OV}) - Y^{\top} - c(\underline{W}) \right\|_{\max}
&= \left\| U_1 V_1^{\top} h(W_{OV}) - Y^{\top} - \left( f(\underline{W}) h(W_{OV}) - Y^{\top} \right) \right\|_{\max} && \text{(By Definition G.5)} \\
&= \left\| \left[ U_1 V_1^{\top} - f(\underline{W}) \right] h(W_{OV}) \right\|_{\max} \\
&\leq \epsilon / \mathrm{poly}(L). && \text{(By (Alman and Song, 2023, Theorem 3))}
\end{aligned}
$$

Lemma G.5 (Approximate $q(\cdot)$).

Let $k_2 = L^{o(1)}$, let $c(\cdot) \in \mathbb{R}^{L \times d}$ follow Definition G.5, and let $q(\underline{W}) \coloneqq c(\underline{W}) h(\underline{W}_{OV})^{\mathsf{T}} \in \mathbb{R}^{L \times L}$ follow Definition G.6. There exist two matrices $U_2, V_2 \in \mathbb{R}^{L \times k_2}$ such that $\left\| U_2 V_2^{\top} - q(\underline{W}) \right\|_{\max} \leq \epsilon / \mathrm{poly}(L)$. In addition, it takes $L^{1+o(1)}$ time to construct $U_2, V_2$.

Proof of Lemma G.5.

Our proof is built on (Alman and Song, 2023, Lemma D.3).

Let $\widetilde{q}(\cdot)$ denote an approximation to $q(\cdot)$.

By Lemma G.4, $U_1 V_1^{\top} h(W_{OV}) - Y$ approximates $c(\underline{W})$ up to accuracy $\epsilon = 1/\mathrm{poly}(L)$.

Thus, by setting $\widetilde{q}(\underline{W}) = h(W_{OV}) \left( U_1 V_1^{\top} h(W_{OV}) - Y \right)^{\top}$, we find a low-rank form for $\widetilde{q}(\cdot)$:

$$
\widetilde{q}(\underline{W}) = h(W_{OV}) \left( h(W_{OV}) \right)^{\top} V_1 U_1^{\top} - h(W_{OV}) Y^{\top},
$$

such that

$$
\begin{aligned}
\left\| \widetilde{q}(\underline{W}) - q(\underline{W}) \right\|_{\max}
&= \left\| h(W_{OV}) \left( U_1 V_1^{\top} h(W_{OV}) - Y \right)^{\top} - h(W_{OV}) Y^{\top} \right\|_{\max} \\
&\leq d \left\| h(W_{OV}) \right\|_{\max} \left\| U_1 V_1^{\top} h(W_{OV}) - Y - c(\underline{W}) \right\|_{\max} \\
&\leq \epsilon / \mathrm{poly}(L).
\end{aligned}
$$

By $k_1, d = L^{o(1)}$, computing $\underbrace{\left(h(W_{OV})\right)^{\top}}_{d \times L} \underbrace{V_1}_{L \times k_1} \underbrace{U_1^{\top}}_{k_1 \times L}$ takes only $L^{1+o(1)}$ time. This completes the proof. ∎
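The running-time claim relies on multiplying the thin factors first and never forming an $L \times L$ matrix. The following NumPy sketch illustrates one way to keep $\widetilde{q}(\underline{W})$ in factored form. It assumes $h(W_{OV}), Y \in \mathbb{R}^{L \times d}$ (inferred from the dimension annotations above); the names `left`, `right`, and `hTV1` are illustrative and do not appear in the paper.

```python
import numpy as np

rng = np.random.default_rng(2)
L, d, k1 = 60, 4, 3
h = rng.standard_normal((L, d))   # stands in for h(W_OV), assumed L x d
Y = rng.standard_normal((L, d))   # assumed L x d
U1, V1 = rng.standard_normal((L, k1)), rng.standard_normal((L, k1))

# Multiply the thin factors first: (d x L)(L x k1) costs O(L * d * k1).
hTV1 = h.T @ V1                       # shape (d, k1)

# Keep q_tilde factored as left @ right.T with inner dimension k1 + d.
left = np.hstack([h @ hTV1, -h])      # shape (L, k1 + d)
right = np.hstack([U1, Y])            # shape (L, k1 + d)

# Dense reference (quadratic in L), formed only to check the factorization.
q_dense = h @ h.T @ V1 @ U1.T - h @ Y.T
gap = float(np.max(np.abs(left @ right.T - q_dense)))
```

Building `left` and `right` costs $O(L\,d\,k_1)$ operations, which is $L^{1+o(1)}$ when $k_1, d = L^{o(1)}$.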

G.1.3 Low-Rank Approximations of Building Blocks II

Now, we use the low-rank approximations of $f, q, c$ to construct low-rank approximations for $p_1(\cdot), p_2(\cdot), p(\cdot)$.

Lemma G.6 (Approximate $p_1(\cdot)$).

Let $k_1, k_2 = L^{o(1)}$. Suppose $U_1, V_1 \in \mathbb{R}^{L \times k_1}$ approximate $f(\underline{W}) \in \mathbb{R}^{L \times L}$ such that $\left\| U_1 V_1^{\top} - f(\underline{W}) \right\|_{\max} \leq \epsilon/\mathrm{poly}(L)$, and $U_2, V_2 \in \mathbb{R}^{L \times k_2}$ approximate $q(\underline{W}) \in \mathbb{R}^{L \times L}$ such that $\left\| U_2 V_2^{\top} - q(\underline{W}) \right\|_{\max} \leq \epsilon/\mathrm{poly}(L)$. Then there exist two matrices $U_3, V_3 \in \mathbb{R}^{L \times k_3}$ such that $\left\| U_3 V_3^{\top} - p_1(\underline{W}) \right\|_{\max} \leq \epsilon/\mathrm{poly}(L)$. In addition, it takes $L^{1+o(1)}$ time to construct $U_3, V_3$.

Proof of Lemma G.6.

By the tensor trick, we construct $U_3$, $V_3$ as tensor products of $U_1, V_1$ and $U_2, V_2$, respectively, while preserving their low-rank structures. Then, we show the low-rank approximation of $p_1(\cdot)$ with bounded error by Lemma G.3 and Lemma G.5.

Let $\oslash$ be the column-wise Kronecker product such that $A \oslash B \coloneqq [A[\cdot,1] \otimes B[\cdot,1] \mid \ldots \mid A[\cdot,k_1] \otimes B[\cdot,k_1]] \in \mathbb{R}^{L \times k_1 k_2}$ for $A \in \mathbb{R}^{L \times k_1}, B \in \mathbb{R}^{L \times k_2}$.

Let $\widetilde{f}(\underline{W}) \coloneqq U_1 V_1^{\top}$ and $\widetilde{q}(\underline{W}) \coloneqq U_2 V_2^{\top}$ denote matrix-multiplication approximations to $f(\underline{W})$ and $q(\underline{W})$, respectively.

For ease of presentation, let $U_3 = \overbrace{U_1}^{L \times k_1} \oslash \overbrace{U_2}^{L \times k_2}$ and $V_3 = \overbrace{V_1}^{L \times k_1} \oslash \overbrace{V_2}^{L \times k_2}$. It holds

\begin{align*}
& \left\| U_3 V_3^{\top} - p_1(\underline{W}) \right\|_{\max} \\
= {} & \left\| U_3 V_3^{\top} - f(\underline{W}) \odot q(\underline{W}) \right\|_{\max}
&& \text{(By $p_1(\underline{W}) = f(\underline{W}) \odot q(\underline{W})$)} \\
= {} & \left\| \left(U_1 \oslash U_2\right)\left(V_1 \oslash V_2\right)^{\top} - f(\underline{W}) \odot q(\underline{W}) \right\|_{\max} \\
= {} & \left\| \left(U_1 V_1^{\top}\right) \odot \left(U_2 V_2^{\top}\right) - f(\underline{W}) \odot q(\underline{W}) \right\|_{\max} \\
= {} & \left\| \widetilde{f}(\underline{W}) \odot \widetilde{q}(\underline{W}) - f(\underline{W}) \odot q(\underline{W}) \right\|_{\max} \\
\leq {} & \underbrace{\left\| \widetilde{f}(\underline{W}) \odot \widetilde{q}(\underline{W}) - \widetilde{f}(\underline{W}) \odot q(\underline{W}) \right\|_{\max}}_{\leq \epsilon/\mathrm{poly}(L)} + \underbrace{\left\| \widetilde{f}(\underline{W}) \odot q(\underline{W}) - f(\underline{W}) \odot q(\underline{W}) \right\|_{\max}}_{\leq \epsilon/\mathrm{poly}(L)} \\
\leq {} & \epsilon/\mathrm{poly}(L).
&& \text{(By Lemma G.3 and Lemma G.5)}
\end{align*}

Computationally, by $k_1, k_2 = L^{o(1)}$, computing $U_3$ and $V_3$ takes $L^{1+o(1)}$ time. This completes the proof. ∎
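The key identity in the proof, $(U_1 \oslash U_2)(V_1 \oslash V_2)^{\top} = (U_1 V_1^{\top}) \odot (U_2 V_2^{\top})$, can be checked numerically. The sketch below is a minimal illustration, not the authors' implementation. It assumes that $\oslash$ acts row by row (row $i$ of $A \oslash B$ is the Kronecker product of row $i$ of $A$ and row $i$ of $B$), which is the reading that yields the stated shape $L \times k_1 k_2$; the helper name `rowwise_kron` is hypothetical.

```python
import numpy as np

def rowwise_kron(A, B):
    """Row i of the result is kron(A[i], B[i]); output shape (L, kA * kB)."""
    L = A.shape[0]
    # Build all products A[i, a] * B[i, b], then flatten the (a, b) axes.
    return np.einsum("ia,ib->iab", A, B).reshape(L, -1)

rng = np.random.default_rng(0)
L, k1, k2 = 64, 3, 4
U1, V1 = rng.standard_normal((L, k1)), rng.standard_normal((L, k1))
U2, V2 = rng.standard_normal((L, k2)), rng.standard_normal((L, k2))

U3 = rowwise_kron(U1, U2)   # shape (L, k1 * k2)
V3 = rowwise_kron(V1, V2)   # shape (L, k1 * k2)

# Dense L x L matrices are formed here only to verify the identity.
lhs = U3 @ V3.T
rhs = (U1 @ V1.T) * (U2 @ V2.T)   # Hadamard product of the two approximations
max_gap = float(np.max(np.abs(lhs - rhs)))
```

Constructing `U3` and `V3` touches $L k_1 k_2$ entries each, matching the $L^{1+o(1)}$ bound when $k_1, k_2 = L^{o(1)}$.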

Lemma G.7 (Approximate $p_2(\cdot)$).

Let $k_1, k_2, k_4 = L^{o(1)}$. Let $p_2(\underline{W}) \in \mathbb{R}^{L \times L}$ follow Definition G.7 such that its $j_0$-th column is $p_2(\underline{W})_{j_0} = f(\underline{W})_{j_0} f(\underline{W})_{j_0}^{\top} q(\underline{W})_{j_0}$ for each $j_0 \in [L]$.
Suppose $U_1, V_1 \in \mathbb{R}^{L \times k_1}$ approximate $f(\underline{W}) \in \mathbb{R}^{L \times L}$ such that $\left\| U_1 V_1^{\top} - f(\underline{W}) \right\|_{\max} \leq \epsilon/\mathrm{poly}(L)$, and $U_2, V_2 \in \mathbb{R}^{L \times k_2}$ approximate $q(\underline{W}) \in \mathbb{R}^{L \times L}$ such that $\left\| U_2 V_2^{\top} - q(\underline{W}) \right\|_{\max} \leq \epsilon/\mathrm{poly}(L)$.
Then there exist matrices $U_4, V_4 \in \mathbb{R}^{L \times k_4}$ such that $\left\| U_4 V_4^{\top} - p_2(\underline{W}) \right\|_{\max} \leq \epsilon/\mathrm{poly}(L)$. In addition, it takes $L^{1+o(1)}$ time to construct $U_4, V_4$.

Proof of Lemma G.7.

From Definition G.7,

\begin{align*}
p_2(\underline{W})_{j_0} \coloneqq \overbrace{f(\underline{W})_{j_0} \underbrace{f(\underline{W})_{j_0}^{\top} q(\underline{W})_{j_0}}_{(I)}}^{(II)}.
\end{align*}

For (I), we show its low-rank approximation by observing the low-rank-preserving property of the multiplication between $f(\cdot)$ and $q(\cdot)$ (from Lemma G.3 and Lemma G.5). For (II), we show its low-rank approximation by the low-rank structure of $f(\cdot)$ and (I).

Part (I).

We define a function $r(\underline{W}): \mathbb{R}^{d^2} \to \mathbb{R}^{L}$ such that the $j_0$-th component is $r(\underline{W})_{j_0} \coloneqq \left(f(\underline{W})_{j_0}\right)^{\top} q(\underline{W})_{j_0}$ for all $j_0 \in [L]$. Let $\widetilde{r}(\underline{W})$ denote the approximation of $r(\underline{W})$ obtained by decomposing it into $f(\cdot)$ and $q(\cdot)$:

\begin{align*}
\widetilde{r}(\underline{W})_{j_0}
&\coloneqq \left\langle \widetilde{f}(\underline{W})_{j_0}, \widetilde{q}(\underline{W})_{j_0} \right\rangle
= \left(U_1 V_1^{\top}\right)[j_0,\cdot] \cdot \left[\left(U_2 V_2^{\top}\right)[j_0,\cdot]\right]^{\top} \\
&= U_1[j_0,\cdot] \underbrace{V_1^{\top}}_{k_1 \times L} \underbrace{V_2}_{L \times k_2} \left(U_2[j_0,\cdot]\right)^{\top},
\tag{G.6}
\end{align*}

for all $j_0 \in [L]$. This allows us to write $p_2(\underline{W}) = f(\underline{W}) \mathop{\rm diag}(r(\underline{W}))$, with $\mathop{\rm diag}(\widetilde{r}(\underline{W}))$ denoting a diagonal matrix whose diagonal entries are the components of $\widetilde{r}(\underline{W})$.

Part (II).

With $r(\cdot)$, we approximate $p_2(\cdot)$ with $\widetilde{p}_2(\underline{W}) = \widetilde{f}(\underline{W}) \mathop{\rm diag}(\widetilde{r}(\underline{W}))$ as follows.

Since $\widetilde{f}(\underline{W})$ has a low-rank representation and $\mathop{\rm diag}(\widetilde{r}(\underline{W}))$ is a diagonal matrix, $\widetilde{p}_2(\cdot)$ has a low-rank representation by definition. Thus, we set $\widetilde{p}_2(\underline{W}) = U_4 V_4^{\top}$ with $U_4 = U_1$ and $V_4 = \mathop{\rm diag}(\widetilde{r}(\underline{W})) V_1$. Then, we bound the approximation error

\begin{align*}
& \left\| U_4 V_4^{\top} - p_2(\underline{W}) \right\|_{\max} \\
= {} & \left\| \widetilde{p}_2(\underline{W}) - p_2(\underline{W}) \right\|_{\max} \\
= {} & \max_{j_0 \in [L]} \left\| \widetilde{f}(\underline{W})_{j_0} \widetilde{r}(\underline{W})_{j_0} - f(\underline{W})_{j_0} r(\underline{W})_{j_0} \right\|_{\max} \\
\leq {} & \max_{j_0 \in [L]} \left[ \left\| \widetilde{f}(\underline{W})_{j_0} \widetilde{r}(\underline{W})_{j_0} - \widetilde{f}(\underline{W})_{j_0} r(\underline{W})_{j_0} \right\|_{\max} + \left\| \widetilde{f}(\underline{W})_{j_0} r(\underline{W})_{j_0} - f(\underline{W})_{j_0} r(\underline{W})_{j_0} \right\|_{\max} \right]
&& \text{(By triangle inequality)} \\
\leq {} & \epsilon/\mathrm{poly}(L).
\end{align*}

Computationally, computing $V_1^\top V_2$ takes $L^{1+o(1)}$ time since $k_1, k_2 = L^{o(1)}$. Once $V_1^\top V_2$ is precomputed, (G.6) takes only $O(k_1 k_2)$ time for each $j_0 \in [L]$. Thus, the total time is $O(L k_1 k_2) = L^{1+o(1)}$.
Since $U_1$ and $V_1$ take $L^{1+o(1)}$ time to construct, and $V_4 = \underbrace{\mathrm{diag}(\widetilde{r}(\underline{W}))}_{L \times L} \underbrace{V_1}_{L \times k_1}$ also takes $L^{1+o(1)}$ time, $U_4$ and $V_4$ take $L^{1+o(1)}$ time to construct. This completes the proof. ∎
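The precomputation argument in this proof can be illustrated with a minimal sketch. It assumes, for illustration only, that (G.6) is a row-wise inner product between two low-rank matrices, written here as $U_1 V_1^\top$ and $U_2 V_2^\top$. The second factorization and all helper names (`r_tilde_fast`, `r_tilde_dense`, `scale_rows`) are hypothetical and do not come from the paper. Small integer lists stand in for the $L \times k$ factors.

```python
# Illustrative sketch only; not the paper's algorithm.
# Assumption: the two matrices whose rows are paired are U1 V1^T and U2 V2^T.

def matmul(A, B):
    """Plain product of an n x m matrix and an m x p matrix (lists of lists)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def r_tilde_fast(U1, V1, U2, V2):
    """Row-wise inner products via one shared k1 x k2 matrix M = V1^T V2."""
    M = matmul(transpose(V1), V2)            # computed once, O(L * k1 * k2)
    out = []
    for u1, u2 in zip(U1, U2):               # O(k1 * k2) work per row j0
        out.append(sum(u1[a] * M[a][b] * u2[b]
                       for a in range(len(u1))
                       for b in range(len(u2))))
    return out

def r_tilde_dense(U1, V1, U2, V2):
    """Reference version: materialize both L x L matrices (quadratic in L)."""
    F = matmul(U1, transpose(V1))
    Q = matmul(U2, transpose(V2))
    return [sum(f * q for f, q in zip(frow, qrow)) for frow, qrow in zip(F, Q)]

def scale_rows(r, V):
    """Compute diag(r) V row by row, without building the L x L diagonal."""
    return [[ri * v for v in row] for ri, row in zip(r, V)]

if __name__ == "__main__":
    U1 = [[1, 2], [0, 1], [3, -1]]           # L = 3, k1 = 2
    V1 = [[1, 0], [2, 1], [0, 1]]
    U2 = [[1], [2], [-1]]                    # k2 = 1
    V2 = [[2], [0], [1]]
    r = r_tilde_fast(U1, V1, U2, V2)
    print(r, r == r_tilde_dense(U1, V1, U2, V2))
    print(scale_rows(r, V1))
```

The fast path touches the $L \times L$ matrices only implicitly, which is where the $O(L k_1 k_2)$ total in the proof comes from; the dense path is included purely as a correctness reference.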

G.2 Proof of Theorem 4.1

Proof of Theorem 4.1.

By the definitions of the matrices $p(\cdot)$, $p_1(\cdot)$, and $p_2(\cdot)$ (Definition G.7), we have

$$p(\underline{W}) = p_1(\underline{W}) - p_2(\underline{W}).$$

By Lemma G.2, we have

$$\frac{\mathrm{d} g_2}{\mathrm{d} \underline{W}} = \operatorname{vec}\left(A_1 p(\underline{W}) A_2^\top\right). \tag{G.7}$$

To show the existence of $L^{1+o(1)}$-time algorithms for the DiT backward computation problem (Problem 1), we prove fast low-rank approximations for $A_1 p_1(\underline{W}) A_2^\top$ and $A_1 p_2(\underline{W}) A_2^\top$ as follows.

Let $\widetilde{p}_1(\underline{W})$ and $\widetilde{p}_2(\underline{W})$ denote the approximations to $p_1(\underline{W})$ and $p_2(\underline{W})$, respectively.

By Lemma G.6, it takes $L^{1+o(1)}$ time to construct $U_3, V_3 \in \mathbb{R}^{L \times k_3}$ such that

$$A_1 \widetilde{p}_1(\underline{W}) A_2^\top = A_1 U_3 V_3^\top A_2^\top.$$

Then, computing $\underbrace{A_1}_{d \times L} \underbrace{U_3}_{L \times k_3} \underbrace{V_3^\top}_{k_3 \times L} \underbrace{A_2^\top}_{L \times d}$ takes $L^{1+o(1)}$ time, since $d, k_1 k_3 = L^{o(1)}$.

Therefore, the total running time for $A_1 p_1(\underline{W}) A_2^\top$ is $L \cdot L^{o(1)} = L^{1+o(1)}$.

For the same reason (by Lemma G.7), the total running time for $A_1 p_2(\underline{W}) A_2^\top$ is $L \cdot L^{o(1)} = L^{1+o(1)}$.
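The multiplication order behind both running-time claims can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the function names `sandwich_low_rank` and `sandwich_dense` are hypothetical, and small integer lists stand in for the $d \times L$ matrices $A_1, A_2$ and the $L \times k_3$ factors $U_3, V_3$. The point is that grouping the product as $(A_1 U_3)(A_2 V_3)^\top$ never forms an $L \times L$ matrix.

```python
# Illustrative sketch only; names are hypothetical.

def matmul(A, B):
    """Plain product of an n x m matrix and an m x p matrix (lists of lists)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def sandwich_low_rank(A1, U3, V3, A2):
    """A1 U3 V3^T A2^T grouped as (A1 U3)(A2 V3)^T."""
    left = matmul(A1, U3)                    # d x k3, cost O(d * L * k3)
    right = matmul(A2, V3)                   # d x k3, cost O(d * L * k3)
    return matmul(left, transpose(right))    # d x d,  cost O(d * d * k3)

def sandwich_dense(A1, U3, V3, A2):
    """Reference version: builds the L x L product U3 V3^T first."""
    P = matmul(U3, transpose(V3))            # L x L, cost O(L * L * k3)
    return matmul(matmul(A1, P), transpose(A2))

if __name__ == "__main__":
    A1 = [[1, 0, 2], [0, 1, -1]]             # d = 2, L = 3
    A2 = [[1, 1, 0], [2, 0, 1]]
    U3 = [[1], [2], [3]]                     # k3 = 1
    V3 = [[0], [1], [1]]
    print(sandwich_low_rank(A1, U3, V3, A2))
    print(sandwich_dense(A1, U3, V3, A2))
```

With $d$ and $k_3$ both of size $L^{o(1)}$, every step of the grouped product is linear in $L$ up to $L^{o(1)}$ factors, whereas the dense reference is quadratic in $L$.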

Lastly, we have

\begin{align*}
& \left\| \frac{\partial g_2}{\partial \underline{W}} - \widetilde{G}^{(W)} \right\|_{\max} \\
={} & \left\| \operatorname{vec}\left(A_1 p(\underline{W}) A_2^\top\right) - \operatorname{vec}\left(A_1 \widetilde{p}(\underline{W}) A_2^\top\right) \right\|_{\max} && \text{(By Lemma G.2)} \\
={} & \left\| A_1 p(\underline{W}) A_2^\top - A_1 \widetilde{p}(\underline{W}) A_2^\top \right\|_{\max} && \text{(By definition, $\|A\|_{\max} \coloneqq \max_{i,j} |A_{ij}|$ for any matrix $A$)} \\
\leq{} & \left\| A_1 \left[ p_1(\underline{W}) - \widetilde{p}_1(\underline{W}) \right] A_2^\top \right\|_{\max} + \left\| A_1 \left[ p_2(\underline{W}) - \widetilde{p}_2(\underline{W}) \right] A_2^\top \right\|_{\max} && \text{(By Definition G.7 and triangle inequality)} \\
\leq{} & \|A_1\|_{\infty} \|A_2\|_{\infty} \left( \left\| p_1(\underline{W}) - \widetilde{p}_1(\underline{W}) \right\|_{\max} + \left\| p_2(\underline{W}) - \widetilde{p}_2(\underline{W}) \right\|_{\max} \right) && \text{(By the sub-multiplicative property of $\|\cdot\|_{\infty}$)} \\
\leq{} & \epsilon / \mathrm{poly}(L). && \text{(By Lemma G.6 and Lemma G.7)}
\end{align*}

Setting $\epsilon = 1/\mathrm{poly}(L)$ completes the proof. ∎
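The sub-multiplicative step in the last chain of inequalities can be unpacked entrywise. The derivation below is a sketch under one assumption that this section does not state: $\|\cdot\|_{\infty}$ is taken to be the maximum absolute row sum. Here $E$ is shorthand (not the paper's notation) for either error matrix $p_i(\underline{W}) - \widetilde{p}_i(\underline{W}) \in \mathbb{R}^{L \times L}$, and $A_1, A_2 \in \mathbb{R}^{d \times L}$.

```latex
% Assumption: \|A\|_\infty := \max_i \sum_k |A_{ik}| (maximum absolute row sum).
% E denotes p_i(\underline{W}) - \widetilde{p}_i(\underline{W}) for i = 1 or 2.
\begin{align*}
\left| \left( A_1 E A_2^\top \right)_{ij} \right|
  &= \left| \sum_{k \in [L]} \sum_{l \in [L]} (A_1)_{ik} \, E_{kl} \, (A_2)_{jl} \right| \\
  &\leq \|E\|_{\max} \left( \sum_{k \in [L]} \left| (A_1)_{ik} \right| \right)
        \left( \sum_{l \in [L]} \left| (A_2)_{jl} \right| \right) \\
  &\leq \|A_1\|_{\infty} \, \|A_2\|_{\infty} \, \|E\|_{\max}
  \qquad \text{for all } i, j \in [d].
\end{align*}
```

Taking the maximum over $i, j$ on the left-hand side gives the bound on $\|A_1 E A_2^\top\|_{\max}$ used above.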
