Learning-to-Cache:
Accelerating Diffusion Transformer via Layer Caching

Xinyin Ma1  Gongfan Fang1 Michael Bi Mi2 Xinchao Wang1
National University of Singapore1  Huawei Technologies Ltd.2
[email protected], [email protected]
Corresponding author
Abstract

Diffusion Transformers have recently demonstrated unprecedented generative capabilities for various tasks. The encouraging results, however, come at the cost of slow inference, since each denoising step requires inference on a transformer model with a large number of parameters. In this study, we make an interesting and somewhat surprising observation: by introducing a caching mechanism, the computation of a large proportion of layers in the diffusion transformer can be readily removed, even without updating the model parameters. In the case of U-ViT-H/2, for example, we may remove up to 93.68% of the computation in the cache steps (46.84% for all steps), with less than 0.01 drop in FID. To achieve this, we introduce a novel scheme, named Learning-to-Cache (L2C), that learns to conduct caching in a dynamic manner for diffusion transformers. Specifically, by leveraging the identical structure of layers in transformers and the sequential nature of diffusion, we explore redundant computations between timesteps by treating each layer as the fundamental unit for caching. To address the challenge of the exponential search space in deep models for identifying layers to cache and remove, we propose a novel differentiable optimization objective. An input-invariant yet timestep-variant router is then optimized, which can finally produce a static computation graph. Experimental results show that L2C largely outperforms samplers such as DDIM and DPM-Solver, as well as prior cache-based methods, at the same inference speed.

Figure 1: (a) Generate 512×512 images using DiT-XL/2, sampled by DDIM with 50 NFEs. (b) Generate 256×256 images using U-ViT-H/2, sampled by DPM-Solver-2 with 50 NFEs.

1 Introduction

In recent years, diffusion models [53, 52, 14] have achieved remarkable performance as powerful generative models for image generation [42, 7]. Among the various backbone designs for diffusion models, transformers [54] have emerged as a strong contender, showing exceptional capabilities not only in synthesizing high-fidelity images [40, 2] but also in video generation [32, 6, 3], text-to-speech synthesis [28, 16] and 3D generation [35, 4]. The diffusion transformer, while benefiting greatly from the scalability of the transformer architecture, also brings significant challenges in efficiency, including high deployment costs and slow inference speed.

Since the cost of sampling increases proportionally with the number of timesteps and the model size per timestep, current methods for increasing sampling efficiency naturally fall into two branches: reducing the number of sampling steps [50, 14, 29, 1] or reducing the inference cost per step [11, 59]. Methods that reduce the number of timesteps include distilling the trajectory into fewer steps [45, 51, 33] and discretizing the reverse-time SDE or the probability flow ODE [50, 63, 31]. Methods in the other branch mainly compress the model size [20, 25] or use a low-precision data format [13, 46]. A newer approach to dynamic inference in diffusion is a special cache mechanism in the denoising process [34, 56]. These methods leverage the high similarity between two consecutive steps and the special properties of U-Net to cache some of the computations, which are then directly reused in the next step. Some other dynamic inference methods employ a spectrum of diffusion models and allocate different networks for different steps [57, 38].

Previous approaches, especially those aimed at reducing model size, have predominantly targeted the compression of the U-Net architecture [43]. Our objective is to explore a paradigm for inference acceleration that is more suitable for transformer-based diffusion models. Unlike other architectures, transformers are distinctively composed of several layers with consistent structure. Based on this property, previous compression work on transformers mainly focuses on layer pruning [61] and random layer dropping [10, 41], as optimizing at the layer level tends to achieve higher speedup ratios compared to width pruning [19]. However, for diffusion transformers, we observed that dropping layers without retraining is not feasible. Removing even a few layers significantly degrades image quality (see Section 4.3). This observation highlights that the redundancy among layers at varying depths is not evident in DiT. Therefore, we consider another perspective of redundancy: the redundancy across layers situated at the same depths but occurring at different timesteps.

Motivated by cache-based methods [34, 56, 23], we aim to explore the existence and limitations of layer redundancy between timesteps within the diffusion transformer. A straightforward approach involves an exhaustive search where each layer is either cached or not, resulting in an exponentially growing search space with the depth of the layers. Additionally, heuristic-based layer selection cannot adequately address the mutual dependencies between layers. To overcome these challenges, we designed a framework that makes the problem of layer selection differentiable. Specifically, we interpolate predictions between two adjacent steps. This interpolation spans two extremes: a fast configuration where all layers are cached at the expense of image quality, and a slow configuration where all layers are retained, achieving optimal performance. We then search this interpolation space to identify an optimal caching scheme, optimizing a specialized router. This router is time-dependent but input-invariant, allowing the creation of a static computation graph for inference. We train this router by formulating an optimization problem that does not require updating model parameters, making it both cost-effective and easy to optimize.

Our results indicate that different percentages of layers can be cached in DiT [35] and U-ViT [2]. Notably, for U-ViT-H/2 on ImageNet, approximately 93.68% of layers are cacheable in the cache step, whereas for DiT-XL/2, the cacheable ratio is 47.43%, both with an almost negligible performance loss (ΔFID < 0.01). By comparison, with the same acceleration ratio, a sampler with fewer steps would compromise image quality. Our method L2C can significantly outperform the fast sampler, as well as previous cache-based methods. Additionally, we observed distinct sparsity patterns for layers between these two models, suggesting significant behavioral variations between different architecture designs for diffusion transformers.

In summary, our contribution is the proposal of a novel acceleration method, learning-to-cache (L2C), specifically for diffusion transformers. We convert the non-differentiable layer selection problem into a differentiable optimization problem by interpolation, facilitating the learning of layer caching. Our results demonstrate that a large proportion of layers in the diffusion transformer can be cached without compromising performance. Furthermore, our approach significantly outperforms samplers with fewer steps and other cache-based methods. The code is available at https://github.com/horseee/learning-to-cache

2 Related Work

Transformers in Diffusion Models.

Transformers [54] are applied in diffusion models as an alternative to U-Net [43]. GenViT [60] integrates the ViT [9] architecture into DDPM. U-ViT [2] employs long skip connections between shallow and deep layers. DiT [40] shows the scalability of diffusion transformers and is further used as a general architecture for text-to-video generation [3, 32], speech synthesis [28] and 3D generation [4]. [12, 66] further show that masked modeling can reduce the training cost of diffusion transformers.

Acceleration of Diffusion Models.

Generating images with diffusion models requires several rounds of model evaluation, which is time-consuming. Some works focus on reducing the number of sampling steps in a training-free manner. DDIM [50] extends the original DDPM to non-Markovian cases. DPM-Solver [30, 31] further approximates the solution of diffusion ODEs using exponential integrators. EDM [18] finds that Heun's 2nd-order method provides an excellent tradeoff between truncation error and NFE. More works try to solve either SDEs [53, 17, 8] or ODEs [29, 63, 62] more accurately and quickly. Other training-based methods [45, 26] distill the model and halve the sampling steps. [51, 33] learn to map any point on the ODE trajectory to its origin. Another line of work reduces the workload per step. The model per step is compressed by reducing the parameter size [11, 5, 61, 55], using reduced precision [24, 13], or redesigning the structure of the diffusion model [59, 25, 65, 20]. In addition to static model inference, dynamic model inference, which employs different models at different steps, has also been extensively explored for diffusion models. [27, 39] switch between models of different sizes in a model zoo. [36] designs a time-dependent exit schedule to skip a subset of parameters. Besides these two branches of work, there is also work on running denoising diffusion models in parallel, either through iterative optimization [47] or image splitting [22].

Cache in Diffusion Models

A cache [48] is used in computer systems to temporarily hold the portions of main-memory contents that are expected to be used in the near future. Recently, [34, 56] explore the cache mechanism in diffusion models. Based on the observation that the change in high-level features is typically very small across consecutive steps, they propose to reuse the feature maps. By utilizing the computation flow of U-Net, [34] reuses the high-level features while updating the low-level features. [56, 23] further identify better positions in U-Net for caching. [15] proposes to reuse the attention map. [56, 49, 34] adjust the lifetime of each cached feature, and [56] further scales and shifts the reused features. [64] finds that cross-attention is redundant in the fidelity-improving stage and can be cached. [58] hashes and caches the images rendered from camera positions and diffusion timesteps to improve the efficiency of 3D generative modeling.

3 Method

Figure 2: Illustration of Learning-to-Cache. When a layer is activated, the calculation proceeds as usual. In contrast, when a layer is disabled, the computation of the non-residual path is bypassed, and the results from the previous step are utilized instead. The router $\boldsymbol{\beta}$ smoothly controls the transition between two endpoints $\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_s, s)$ and $\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_m, m)$.

3.1 Preliminary

The forward diffusion process starts at $\boldsymbol{x}_0$, where $\boldsymbol{x}_0$ is sampled from the data distribution $q(\boldsymbol{x}_0)$ to be learned. $\boldsymbol{x}_0$ is then degraded by gradually added Gaussian noise:

$$\boldsymbol{x}_t \sim q(\boldsymbol{x}_t | \boldsymbol{x}_0) = \mathcal{N}\left(\boldsymbol{x}_t; \alpha_t \boldsymbol{x}_0, \sigma_t^2 \mathbf{I}\right) \tag{1}$$

where $\alpha_t$ and $\sigma_t$ are the noise coefficients. We can quickly sample $\boldsymbol{x}_t$ at an arbitrary timestep using the reparameterization trick. For the reverse process, given two timesteps $s$ and $t$, where $s > 0$ and $t < s$, $\boldsymbol{x}_t$ is calculated as [30]:

$$\boldsymbol{x}_t = \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s - \alpha_t \int_{\lambda_s}^{\lambda_t} e^{-\lambda}\, \hat{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{x}_{t_\lambda(\lambda)}, t_\lambda(\lambda)\right) \mathrm{d}\lambda \tag{2}$$

where $\lambda_t = \log\left(\alpha_t / \sigma_t\right)$, and $t_\lambda(\lambda)$ is the inverse function of $\lambda_t$ that satisfies $t_\lambda(\lambda_t) = t$. $\boldsymbol{\epsilon}_\theta(\cdot)$ often represents the learned model, which, in our case, is the diffusion transformer. Previous methods show that this integral term can be approximated by a Taylor expansion at $\lambda_s$, using either a first-order [50] or a higher-order [30] approximation. Taking the first-order case as an example, the update of $\boldsymbol{x}_t$ would be:

$$\boldsymbol{x}_t = \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s - \sigma_t\left(e^{\lambda_t - \lambda_s} - 1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) \tag{3}$$
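As a minimal numerical sketch of Eq. 1 and Eq. 3, the snippet below samples $\boldsymbol{x}_s$ by reparameterization and then takes one first-order step from $s$ to $t$. The function names (`forward_sample`, `log_snr`, `first_order_step`) and the coefficient values are hypothetical and chosen only for illustration. A real sampler would obtain the noise prediction from the diffusion transformer; here the true noise stands in for it, so the step should land exactly on the forward sample at $t$.

```python
import numpy as np

def forward_sample(x0, alpha_t, sigma_t, noise):
    """Eq. 1 via reparameterization: x_t = alpha_t * x_0 + sigma_t * noise."""
    return alpha_t * x0 + sigma_t * noise

def log_snr(alpha, sigma):
    """lambda_t = log(alpha_t / sigma_t)."""
    return np.log(alpha / sigma)

def first_order_step(x_s, eps_s, alpha_s, sigma_s, alpha_t, sigma_t):
    """Eq. 3: one first-order update from timestep s to timestep t."""
    lam_s = log_snr(alpha_s, sigma_s)
    lam_t = log_snr(alpha_t, sigma_t)
    # np.expm1(z) computes exp(z) - 1 accurately for small z.
    return (alpha_t / alpha_s) * x_s - sigma_t * np.expm1(lam_t - lam_s) * eps_s

# Illustrative coefficients (assumed, with alpha^2 + sigma^2 = 1); t is less noisy than s.
alpha_s, sigma_s = 0.6, 0.8
alpha_t, sigma_t = 0.8, 0.6

rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)
noise = rng.standard_normal(4)

x_s = forward_sample(x0, alpha_s, sigma_s, noise)
# Stand-in for eps_theta(x_s, s): the true noise, i.e. a perfect prediction.
x_t = first_order_step(x_s, noise, alpha_s, sigma_s, alpha_t, sigma_t)
```

With a perfect noise prediction, `x_t` coincides with `forward_sample(x0, alpha_t, sigma_t, noise)`; with a learned model the step is only an approximation of the integral in Eq. 2.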

3.2 Approximating $\boldsymbol{\epsilon}_\theta(\cdot)$ with a lightweight substitute

The question then becomes how to efficiently calculate the term $\int_{\lambda_s}^{\lambda_t} e^{-\lambda}\, \hat{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{x}_{t_\lambda(\lambda)}, t_\lambda(\lambda)\right) \mathrm{d}\lambda$. Our core idea is to keep more updates between $s$ and $t$ without increasing the overall inference time too much. Suppose that we have three timesteps: $s$, $t$, and one step $m$ between $s$ and $t$. The calculation of $\boldsymbol{x}_t$, in the case of Eq. 3, would become:

$$\boldsymbol{x}_t = \frac{\alpha_t}{\alpha_m}\boldsymbol{x}_m - \sigma_t\left(e^{\lambda_t - \lambda_m} - 1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right) \text{, where } \boldsymbol{x}_m = \frac{\alpha_m}{\alpha_s}\boldsymbol{x}_s - \sigma_m\left(e^{\lambda_m - \lambda_s} - 1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) \tag{4}$$

If we directly set $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right) = \boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$, the result is equivalent to that of Equation 3 when taking a step directly from $s$ to $t$ (see the derivation in Appendix A.1). This approach results in faster computation, as it eliminates the need to compute $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$; however, it compromises the quality of the resulting image. In contrast, the time-consuming but optimal alternative is to calculate $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$ as usual, which necessitates a full model evaluation but yields superior image quality.
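As a brief sketch of why this equivalence holds (an illustration, not a reproduction of Appendix A.1), write $\boldsymbol{\epsilon}$ for the shared prediction $\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_s, s)$ and use $\sigma_t = \alpha_t e^{-\lambda_t}$, which follows from $\lambda_t = \log(\alpha_t/\sigma_t)$. Substituting $\boldsymbol{x}_m$ into the update for $\boldsymbol{x}_t$ in Eq. 4 gives:

```latex
\begin{aligned}
\boldsymbol{x}_t
&= \frac{\alpha_t}{\alpha_m}\left[\frac{\alpha_m}{\alpha_s}\boldsymbol{x}_s
   - \sigma_m\left(e^{\lambda_m-\lambda_s}-1\right)\boldsymbol{\epsilon}\right]
   - \sigma_t\left(e^{\lambda_t-\lambda_m}-1\right)\boldsymbol{\epsilon} \\
&= \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s
   - \alpha_t\left[e^{-\lambda_m}\left(e^{\lambda_m-\lambda_s}-1\right)
   + e^{-\lambda_t}\left(e^{\lambda_t-\lambda_m}-1\right)\right]\boldsymbol{\epsilon} \\
&= \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s
   - \alpha_t\left(e^{-\lambda_s}-e^{-\lambda_t}\right)\boldsymbol{\epsilon} \\
&= \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s
   - \sigma_t\left(e^{\lambda_t-\lambda_s}-1\right)\boldsymbol{\epsilon}
\end{aligned}
```

The terms in $e^{-\lambda_m}$ cancel, so the intermediate step $m$ leaves no trace and the last line is exactly Eq. 3.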

Recognizing that $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$ represents a rapid yet suboptimal solution and $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$ represents a slower but optimal solution when calculating $\boldsymbol{x}_t$, we want to find a model $\tilde{\boldsymbol{\epsilon}}_\theta(\boldsymbol{x}_m, m)$ that interpolates between these two models. We first define the interpolation as follows:

$$\tilde{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{x}_m, m; \boldsymbol{\beta}\right) = \mathcal{I}\left(\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right), \boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right), \boldsymbol{\beta}\right) \tag{5}$$

where $\tilde{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{x}_m, m\right)$ is controlled by a set of variables $\boldsymbol{\beta}$, functioning as a slider that can smoothly transition between the two endpoints $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$ and $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$. $\tilde{\boldsymbol{\epsilon}}_\theta(\boldsymbol{x}_m, m)$ needs to meet two criteria: it should approximate the output of $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$ and be faster for inference compared to $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$.
By creating the interpolation $\mathcal{I}$, we generate a large collection of models, allowing us to search within this set to find whether there exists an $\tilde{\boldsymbol{\epsilon}}_\theta(\boldsymbol{x}_m, m)$ that satisfies our requirements.

3.3 Caching the Layer: A Feasible Choice for the Interpolation $\mathcal{I}$

In this section, we specifically define an interpolation $\mathcal{I}$ and explore whether a suitable $\tilde{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{x}_m, m\right)$ exists within it. Given the transformer architecture, we propose an interpolation schema by leveraging the layers of the transformer model. Here we take the computation of DiT [40] as an illustrative example. The transformer model can be decomposed into a sequence of basic layers $\{L_i(h, t)\}_{i=1}^{D}$, where $L_i(h, t) = h + g(t) * f_i(h, t)$, consisting of a residual connection. Here, $h$ is the input feature, $D$ denotes the depth of the model, and $t$ is the time condition. $f_i(h, t)$ can represent either a multi-head self-attention (MHSA) block or a pointwise feedforward block, and $g(t)$ is a time-conditioned scalar. We omit the condition $t$ in $f_i(h, t)$ for simplicity. Then we can construct a linear interpolation within the layers, and this layer interpolation satisfies the model interpolation $\mathcal{I}$ (see Appendix A.2):

$$\tilde{L}_i(h_i^m, m; \alpha_i, \beta_i) = h_i^m - (1 - \alpha_i) \cdot (h_i^m - h_i^s) + g(m)\left(\beta_i \cdot f(h_i^m) + (1 - \beta_i) \cdot f(h_i^s)\right) \tag{6}$$

where $h_i^s$ and $h_i^m$ are the inputs to the block $L_i$ at timesteps $s$ and $m$, respectively. $\beta_i$ is a coefficient in layer $i$ that controls the proximity to $f(h_i^m)$ or $f(h_i^s)$, and $\alpha_i$ is used as a control for the input. Both of these variables are constrained within the range $[0, 1]$.

This interpolation provides a special mechanism for inference. If $\beta_i$ in layer $i$ is set to 0, the output can be taken directly from the same layer at the previous timestep, so the computation of this layer is skipped. A non-zero $\beta_i$ triggers the original computation of layer $i$. A discretized $\beta_i$ can therefore be seen as a router that selects which layers are activated or disabled. $\alpha_i$ can be set to any value, since combining $h_i^m$ and $h_i^s$ incurs almost no computation cost; we choose $\alpha_i=0$. Setting $\beta_i$ to 0 in more layers makes the acceleration cumulative.
Therefore, the total computational cost can be calculated from the number of non-zero $\beta_i$, and our goal for $\boldsymbol{\tilde{\epsilon}_{\theta}}(\boldsymbol{x}_m, m)$ can be interpreted as finding as many zeros in $\{\beta_i\}_{i=1}^{D}$ as possible while keeping the approximation error between $\tilde{\boldsymbol{\epsilon}}_{\theta}(\boldsymbol{x}_m, m)$ and $\boldsymbol{\epsilon}_{\theta}(\boldsymbol{x}_m, m)$ minimal.
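The routing behaviour described above can be sketched in a few lines of Python. The sketch assumes a residual block of the form $h + g(m)\cdot(\beta f(h^m) + (1-\beta) f(h^s))$, which is consistent with Eq. 7 but is not a reproduction of Eq. 6; the input interpolation controlled by $\alpha_i$ is left out, and the names `cached_block`, `gate` and `expensive_branch` are illustrative only.

```python
from typing import Callable, List

Vector = List[float]


def cached_block(
    h_m: Vector,
    f: Callable[[Vector], Vector],
    cached_f: Vector,
    beta: float,
    gate: float = 1.0,
) -> Vector:
    """Residual block whose branch output is interpolated with a cached value.

    h_m      : block input at the current step m
    f        : the expensive residual branch (attention or feed-forward)
    cached_f : f(h_s), stored during the previous full step s
    beta     : router value in [0, 1]; 0 means "reuse the cache"
    gate     : stands in for g(m), the scale applied to the branch output
    """
    if beta == 0.0:
        # The branch is never evaluated here, which is where the saving comes from.
        branch = cached_f
    else:
        fresh = f(h_m)
        branch = [beta * a + (1.0 - beta) * b for a, b in zip(fresh, cached_f)]
    return [h + gate * b for h, b in zip(h_m, branch)]


call_log: List[int] = []  # records every evaluation of the expensive branch


def expensive_branch(h: Vector) -> Vector:
    call_log.append(1)
    return [2.0 * v for v in h]


h_s, h_m = [1.0, 2.0], [1.1, 2.1]
cache = expensive_branch(h_s)  # the full step s fills the cache
out_cached = cached_block(h_m, expensive_branch, cache, beta=0.0)  # no new call
out_full = cached_block(h_m, expensive_branch, cache, beta=1.0)    # one new call
print(len(call_log), out_cached, out_full)
```

With `beta=0.0` the branch function is not called at all, so the cost of a step scales with the number of non-zero router entries.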

One key observation.

One greedy way to find $\beta_i$ for each layer is to take the approximation error of each layer into account:

$$E = \|\tilde{L}(\cdot) - L(\cdot)\|_2^2 = (1-\beta_i) \cdot |g(m)| \cdot \|f(h_i^m) - f(h_i^s)\|_2^2 \tag{7}$$

and setting $\beta_i$ to 0 in the layers with the smallest $|g(m)| \cdot \|f(h_i^m) - f(h_i^s)\|_2^2$. In Figure 3, we analyze $\|f(h_i^m) - f(h_i^s)\|_2^2$ for two types of models: DiT and U-ViT. We find that this error varies significantly across timesteps, even within the same layer. In the DiT model in particular, the error is markedly higher in the later denoising steps than in the early ones. Additionally, the error of multi-head self-attention differs substantially from that of feed-forward layers. Based on this, we assign each timestep its own $\{\beta_i\}_{i=1}^{D}$.
Thus, $\boldsymbol{\beta}$ becomes time-variant, where $\boldsymbol{\beta} = \{\beta_{ij} \mid i = 1, 2, \ldots, T;\ j = 1, 2, \ldots, D\}$ and $T$ is the total number of denoising steps.
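The greedy criterion of Eq. 7 can be written down directly. The following is a minimal sketch that assumes the fresh and cached branch outputs of every layer are already available as plain lists; the function names and the toy numbers are hypothetical. It scores each layer independently, so it does not account for interactions between layers.

```python
from typing import List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared L2 distance between two equally long vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def greedy_router(
    fresh: Sequence[Sequence[float]],
    cached: Sequence[Sequence[float]],
    gates: Sequence[float],
    num_to_cache: int,
) -> List[int]:
    """Return one beta per layer: 0 for the layers with the smallest score.

    The score of layer i is |g| * ||f(h_i^m) - f(h_i^s)||^2, as in Eq. 7.
    """
    scores = [abs(g) * squared_l2(fm, fs) for fm, fs, g in zip(fresh, cached, gates)]
    order = sorted(range(len(scores)), key=scores.__getitem__)
    cached_layers = set(order[:num_to_cache])
    return [0 if i in cached_layers else 1 for i in range(len(scores))]


# Toy data for three layers (made up for illustration).
fresh = [[1.0, 1.0], [2.0, 0.0], [0.5, 0.5]]
cached = [[1.0, 1.1], [0.0, 0.0], [0.5, 0.4]]
gates = [1.0, 1.0, 3.0]
betas = greedy_router(fresh, cached, gates, num_to_cache=2)
print(betas)
```

Here the second layer changes the most between the two steps, so it is the only one that stays active.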

In addition, we directly use this metric as the criterion for $\beta_{ij}$ and employ it during inference. From the experimental results in Section 4, we observe that it cannot effectively handle a combination of layers. This limitation arises because the approximation error of each layer is influenced by changes in the preceding layers. However, exhaustively evaluating all possible configurations is impractical, as the number of trials grows exponentially with the depth of the model.

Figure 3: Approximation Error for DiT and U-ViT in different timesteps and different layers

3.4 Learning to Cache

To address this, we propose Learning to Cache. Recall that our goal is to find a $\tilde{\boldsymbol{\epsilon}}_{\theta}(\boldsymbol{x}_m, m)$ that is (1) as close as possible to $\boldsymbol{\epsilon}_{\theta}(\boldsymbol{x}_m, m)$ and (2) obtained with minimal computation cost. We can formulate this as an optimization problem:

$$\arg\min_{\boldsymbol{\beta}} \|\tilde{\epsilon}(\boldsymbol{x}_m, m; \boldsymbol{\beta}) - \epsilon(\boldsymbol{x}_m, m)\|_2^2 \quad \text{s.t. } \sum_{i=1}^{D} \delta_{\beta_{ij}1} \leq C \tag{8}$$

where $C$ is the constraint on the total cost and $\delta_{\beta_{ij}1}$ is the Kronecker delta, which is 1 if $\beta_{ij}=1$ and 0 otherwise. Although $\beta_{ij}$ needs to be discrete in the final solution, it is kept continuous during optimization so that the computation is differentiable. At inference, a threshold $\theta$ is used to discretize $\beta_{ij}$ to either 0 or 1, so that $\beta_{ij}$ acts as a router. The only trained variables in our algorithm are $\boldsymbol{\beta}$; the parameters of the diffusion model remain unchanged. Using Lagrange duality to transform the problem into an unconstrained one, the loss becomes:

$$\mathcal{L}(\tilde{\boldsymbol{\epsilon}}, \boldsymbol{\epsilon}, \boldsymbol{x}_m, m; \boldsymbol{\beta}) = \|\tilde{\epsilon}(\boldsymbol{x}_m, m; \boldsymbol{\beta}) - \epsilon(\boldsymbol{x}_m, m)\|_2^2 + \lambda \cdot \sum_{i=1}^{D} \beta_{ij} \tag{9}$$

where $\lambda$ is the Lagrangian multiplier that governs the regularization. The training and inference procedures are shown in Algorithms 1 and 2. For simplicity, the image encoder and decoder are omitted from these algorithms. Additionally, to ensure that $\boldsymbol{\beta}$ remains within the range $[0,1]$, a sigmoid is applied before $\boldsymbol{\beta}$ is passed into the model.

Algorithm 1 Training
1:  Input: Data distribution $p(\boldsymbol{x}_0)$, diffusion model $\epsilon_{\theta}(\cdot)$, learning rate $\eta$, ODE solver $\Psi(\cdot)$, total steps $T$ and the step schedule $\{t_i\}_{i=1}^{T}$ in $\Psi(\cdot)$
2:  $\boldsymbol{\beta} \sim \mathcal{N}(0,1)$
3:  repeat
4:     $\boldsymbol{x}_0 \sim p(\boldsymbol{x}_0)$, $n \sim \mathcal{U}[2,T]$
5:     $s \leftarrow t_n$, $m \leftarrow t_{n-1}$
6:     $\boldsymbol{x}_s \sim \mathcal{N}(\boldsymbol{x}_s; \alpha_s \boldsymbol{x}_0, \sigma_s^2 \mathbf{I})$
7:     $\boldsymbol{x}_m \leftarrow \Psi(\epsilon_{\theta}(\boldsymbol{x}_s, s), s, m)$ and cache $\{f(\cdot)\}_{i=1}^{D}$ when calculating $\epsilon_{\theta}(\boldsymbol{x}_s, s)$
8:     $\beta_m \leftarrow \operatorname{Sigmoid}(\boldsymbol{\beta}_m)$
9:     Calculate $\tilde{\epsilon}(\boldsymbol{x}_m, m; \beta_m)$ by Eq. 6
10:     $\mathcal{L} \leftarrow \|\tilde{\epsilon}(x_m, m) - \epsilon_{\theta}(x_m, m)\|_2^2 + \lambda \sum \beta_m$
11:     $\boldsymbol{\beta}_m \leftarrow \boldsymbol{\beta}_m - \eta \nabla_{\boldsymbol{\beta}_m} \mathcal{L}$
12:  until converged
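Algorithm 1 can be illustrated on a deliberately tiny problem. The following is a toy sketch, not the authors' training code: each layer contributes a single scalar, layers do not interact, the gradient is written out by hand instead of using automatic differentiation, and the logits start at zero rather than being sampled from $\mathcal{N}(0,1)$ so that the run is deterministic. The function name `train_router` and all numbers are made up for illustration.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def train_router(
    fresh: Sequence[float],
    cached: Sequence[float],
    lam: float,
    lr: float = 0.5,
    steps: int = 4000,
) -> List[float]:
    """Gradient descent on router logits for an additive toy model.

    Approximate output: sum_i beta_i * fresh_i + (1 - beta_i) * cached_i
    Reference output  : sum_i fresh_i
    Loss              : squared error + lam * sum_i beta_i   (cf. Eq. 9)
    """
    logits = [0.0] * len(fresh)
    target = sum(fresh)
    for _ in range(steps):
        beta = [sigmoid(b) for b in logits]
        approx = sum(b * f + (1.0 - b) * c for b, f, c in zip(beta, fresh, cached))
        err = approx - target
        for i, (b, f, c) in enumerate(zip(beta, fresh, cached)):
            grad_beta = 2.0 * err * (f - c) + lam            # dL / d beta_i
            logits[i] -= lr * grad_beta * b * (1.0 - b)      # chain rule through sigmoid
    return logits


fresh = [1.0, 1.0, 1.0]      # branch outputs recomputed at step m
cached = [1.01, 0.98, 3.0]   # branch outputs stored at step s
logits = train_router(fresh, cached, lam=0.05)
router = [1 if sigmoid(b) > 0.5 else 0 for b in logits]
print(router)
```

In this toy setting the two layers whose fresh and cached outputs nearly coincide are routed to the cache, while the layer whose output changed substantially stays active.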

Algorithm 2 Sampling
1:  Input: diffusion model $\epsilon_{\theta}(\cdot)$, router $\boldsymbol{\beta}$, ODE solver $\Psi(\cdot)$, threshold $\theta$, total steps $T$ and the step schedule $\{t_i\}_{i=1}^{T}$ in $\Psi(\cdot)$
2:  $\boldsymbol{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
3:  for $n = T, \dotsc, 1$ do
4:     $h_1^{t_n} \leftarrow \boldsymbol{x}_n$
5:     for $i = 1, \dotsc, D$ do
6:        $\beta_i \leftarrow 1$ if $\operatorname{Sigmoid}(\boldsymbol{\beta}_{t_n i}) > \theta$ else $0$
7:        $h_{i+1}^{t_n} \leftarrow \tilde{L}_i(h_i^{t_n}, t_n; 0, \beta_i)$ by Eq. 6
8:     end for
9:     $\tilde{\epsilon}(x_n, t_n) \leftarrow h_{D+1}^{t_n}$
10:     $\boldsymbol{x}_{n-1} \leftarrow \Psi(\tilde{\epsilon}(\boldsymbol{x}_n, t_n), t_n, t_{n-1})$
11:  end for
12:  return $\boldsymbol{x}_0$
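Because the router depends only on the timestep and the layer index, line 6 of Algorithm 2 can be evaluated once before sampling, which yields a fixed schedule. The sketch below shows this discretization under the assumption that the learned logits are stored as a nested list indexed by cache step and layer; `discretize`, `remove_ratio` and the logit values are hypothetical.

```python
import math
from typing import List


def discretize(logits: List[List[float]], threshold: float) -> List[List[int]]:
    """Turn learned logits into a 0/1 schedule.

    logits[t][j] is the logit for cache step t and layer j; an entry of 0
    means that the layer reuses its cached output at that step.
    """
    return [
        [1 if 1.0 / (1.0 + math.exp(-b)) > threshold else 0 for b in row]
        for row in logits
    ]


def remove_ratio(schedule: List[List[int]]) -> float:
    """Fraction of (step, layer) pairs that are skipped in the cache steps."""
    total = sum(len(row) for row in schedule)
    skipped = sum(row.count(0) for row in schedule)
    return skipped / total


logits = [[-2.0, 0.3, 4.0], [-1.0, -0.5, 2.0]]  # two cache steps, three layers
for theta in (0.5, 0.6):
    schedule = discretize(logits, theta)
    print(theta, schedule, round(remove_ratio(schedule), 3))
```

Raising the threshold from 0.5 to 0.6 flips the borderline entry (logit 0.3) to the cached state, so a single trained router can be read out at several acceleration ratios.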
Table 1: Accelerating image generation on ImageNet for the DiT model family.

| Methods | NFE | MACs (T) | Latency (s) | Speedup | IS ↑ | FID ↓ | sFID ↓ | Precision ↑ | Recall ↑ |
|---|---|---|---|---|---|---|---|---|---|
| DiT-XL/2 (ImageNet 256×256) (cfg=1.5) | | | | | | | | | |
| DDPM | 250 | 28.61 | 36.55 | - | 280.1 | 2.27 | 4.54 | 82.73 | 57.95 |
| DDIM | 250 | 28.61 | 36.45 | - | 243.4 | 2.14 | 4.55 | 80.70 | 60.57 |
| DDIM | 50 | 5.72 | 7.25 | 1.00× | 238.6 | 2.26 | 4.29 | 80.16 | 59.89 |
| DDIM | 40 | 4.57 | 5.82 | 1.24× | 239.8 | 2.39 | 4.28 | 80.36 | 59.13 |
| Ours | 50 | 4.36 | 5.57 | 1.30× | 244.1 | 2.27 | 4.23 | 80.94 | 58.76 |
| DDIM | 20 | 2.29 | 2.87 | 1.00× | 223.5 | 3.48 | 4.89 | 78.76 | 57.07 |
| DDIM | 16 | 1.83 | 2.30 | 1.25× | 210.9 | 4.68 | 5.71 | 76.78 | 56.20 |
| Ours | 20 | 1.78 | 2.26 | 1.27× | 227.0 | 3.46 | 4.64 | 79.15 | 55.62 |
| DDIM | 10 | 1.14 | 1.43 | 1.00× | 158.3 | 12.38 | 11.22 | 66.78 | 52.82 |
| DDIM | 9 | 1.03 | 1.29 | 1.11× | 140.9 | 16.57 | 14.21 | 62.28 | 49.98 |
| Ours | 10 | 1.04 | 1.30 | 1.10× | 156.3 | 12.79 | 10.42 | 66.21 | 52.15 |
| DiT-XL/2 (ImageNet 512×512) (cfg=1.5) | | | | | | | | | |
| DDIM | 50 | 22.85 | 37.73 | 1.00× | 204.1 | 3.28 | 4.50 | 83.33 | 54.80 |
| DDIM | 30 | 13.71 | 22.51 | 1.68× | 198.3 | 3.85 | 4.92 | 83.01 | 56.00 |
| Ours | 50 | 14.19 | 22.57 | 1.67× | 202.1 | 3.69 | 5.03 | 82.90 | 54.60 |
| DiT-L/2 (ImageNet 256×256) (cfg=1.5) | | | | | | | | | |
| DDIM | 50 | 3.88 | 5.06 | 1.00× | 167.6 | 4.82 | 4.40 | 78.72 | 54.66 |
| DDIM | 40 | 3.10 | 4.06 | 1.25× | 168.2 | 4.99 | 4.43 | 79.01 | 54.71 |
| Ours | 50 | 2.95 | 4.01 | 1.26× | 168.3 | 4.82 | 4.41 | 78.97 | 54.73 |
| DDIM | 20 | 1.55 | 2.01 | 1.00× | 160.16 | 6.45 | 5.26 | 77.13 | 53.65 |
| DDIM | 16 | 1.24 | 1.63 | 1.23× | 151.70 | 7.91 | 6.24 | 75.93 | 51.71 |
| Ours | 20 | 1.20 | 1.60 | 1.26× | 160.53 | 6.55 | 5.08 | 77.47 | 52.22 |

Table 2: Results with U-ViT-H/2 on the ImageNet dataset at a resolution of 256×256. We adopt DPM-Solver-2, which has 2 function evaluations per step. The total NFE (instead of the number of steps) is reported below. Guidance strength is set to 0.4.

| Methods | NFE | MACs | Latency | Speedup | FID ↓ | NFE | MACs | Latency | Speedup | FID ↓ |
|---|---|---|---|---|---|---|---|---|---|---|
| DPM-Solver | 50 | 6.44 | 19.37 | 1.00× | 2.3728 | 20 | 2.58 | 7.69 | 1.00× | 2.5739 |
| DPM-Solver | 30 | 3.86 | 11.55 | 1.68× | 2.4644 | 16 | 2.06 | 6.08 | 1.26× | 2.7005 |
| Ours | 50 | 3.79 | 11.16 | 1.74× | 2.3625 | 20 | 1.92 | 5.64 | 1.35× | 2.5809 |

4 Experiments

4.1 Experimental Setup

Models and Datasets.

We explore our methods on two commonly used transformer architectures in diffusion models: DiT [40] and U-ViT [2]. Specifically, we use DiT-XL/2 (256×256), DiT-XL/2 (512×512), DiT-L/2 and U-ViT-H/2. Except for DiT-L/2, we use the officially released models. We trained DiT-L/2 ourselves for one million steps in order to investigate whether layer redundancy also exists in smaller models that may not be fully converged. Most results are reported at a resolution of 256×256, and we also show results on models that generate high-resolution 512×512 images.

Implementations.

Since the parameters of the diffusion model are not updated, the only parameters that require optimization are $\boldsymbol{\beta}$, which amounts to a very small number of variables. For example, for DiT-XL/2 with 20 denoising steps, the number of trainable variables is 560. We train $\boldsymbol{\beta}$ on the ImageNet training set for 1 epoch, using the AdamW optimizer with a learning rate of 0.01. Training is conducted on 8 A5000 GPUs with a global batch size of 64. To train with classifier-free guidance, we randomly drop some labels and assign a null token to the label. The dropping rates for labels follow the original training pipeline.
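The figure of 560 can be reproduced with a short calculation, under assumptions that are not spelled out in this paragraph: DiT-XL/2 has 28 transformer blocks, each block exposes two cacheable modules (MHSA and FFN, the two module types listed in Table 4), and only every second denoising step is a cache step with its own router entries. The helper name `num_router_params` is hypothetical.

```python
def num_router_params(
    num_blocks: int,
    modules_per_block: int,
    denoising_steps: int,
    cache_every: int = 2,
) -> int:
    """Number of router entries when one of every `cache_every` steps is a cache step."""
    cache_steps = denoising_steps // cache_every
    return num_blocks * modules_per_block * cache_steps


# Assumed configuration for DiT-XL/2 with 20 denoising steps.
print(num_router_params(num_blocks=28, modules_per_block=2, denoising_steps=20))
```

Under these assumptions the count is 28 × 2 × 10 = 560, which also matches the denominators of the remove ratios reported in Table 5.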

Evaluation.

We tested our method with two samplers, DDIM [50] and DPM-Solver [30], with sampling steps ranging from 10 to 50. For the DiT models we use the DDIM sampler, and for U-ViT we use DPM-Solver-2. All experiments use classifier-free guidance. To evaluate image quality, 50k images are generated per trial. We measure image quality with Fréchet Inception Distance (FID) [37], sFID [37], Inception Score [44], and Precision and Recall [21]. In addition, we report the total MACs and the latency to compare acceleration ratios. MACs are measured with pytorch-OpCounter (https://github.com/Lyken17/pytorch-OpCounter), and latency is measured when generating a batch of 8 images with classifier-free guidance on a single A5000, averaged over five runs.

4.2 Main Results

We present the results for DiT in Table 1 and for U-ViT in Table 2, comparing our algorithm with samplers of comparable inference speed. Our method uses more denoising steps, but each step takes less time on average; the baseline samplers use fewer steps, but each step takes more time. Our experiments demonstrate that our method significantly outperforms DDIM and DPM-Solver at matched speed. For instance, starting from 20-step DDIM on DiT-XL/2, our method achieves an FID of 3.46, nearly identical to the unaccelerated 3.48, whereas DDIM with 16 steps, which has a comparable latency, achieves an FID of 4.68. When generating high-resolution images, sampling with fewer steps, or using a relatively smaller model, our method still outperforms the baselines. However, we observe that achieving nearly lossless compression under these conditions is challenging. We argue that this difficulty arises because layer redundancy is less apparent in these scenarios.
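The notion of "comparable inference speed" can be made concrete with the 20-step DiT-XL/2 rows of Table 1. The short sketch below recomputes the speedup column from the reported latencies, assuming that speedup is defined as baseline latency divided by method latency; the dictionary layout and the function name are illustrative.

```python
def speedup(baseline_latency: float, latency: float) -> float:
    """Speedup relative to the baseline, computed from wall-clock latencies."""
    return baseline_latency / latency


# (latency in seconds, FID) copied from the DiT-XL/2 256x256 rows of Table 1.
rows = {
    "DDIM, 20 steps": (2.87, 3.48),
    "DDIM, 16 steps": (2.30, 4.68),
    "Ours, 20 steps": (2.26, 3.46),
}
baseline = rows["DDIM, 20 steps"][0]
for name, (latency, fid) in rows.items():
    print(f"{name}: {speedup(baseline, latency):.2f}x, FID {fid}")
```

The two accelerated rows land at roughly the same speedup, so the difference in FID between them is a comparison at matched cost.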

Figure 4: Speed-Quality Tradeoff for DiT-XL/2 and U-ViT-H/2 with 20 denoising steps as the basis.

Quality-Latency Tradeoff.

We show the trade-off curve between FID and latency in Figure 4. These figures offer a more comprehensive comparison with two types of baselines. (1) Heuristic methods for selecting layers. We designed several methods for selecting layers to cache, including rule-based approaches such as caching from top to bottom or from bottom to top, random layer selection, and metric-based selection as described in Eq. 7. We found that these heuristics fail to select the optimal layers once the dependency between layers has to be taken into account, leading to a degradation in image quality. In contrast, our method consistently achieves better quality across various acceleration ratios. (2) Samplers with fewer steps. Our method significantly outperforms DDIM and DPM-Solver run with fewer steps at comparable latency.

Maximum Cacheable Layers for diffusion transformer.

From the trade-off curve, we found that there exists an upper limit on the number of cacheable layers. Below this limit, image quality remains almost unaffected, as indicated by an FID degradation of less than 0.01. This limit is detailed in Table 4. Notably, caching does not occur at every step: step $s$ involves full model inference, while only step $m$ uses cached layers. Although a significant proportion of layers can be cached, and their computation saved, in both architectures, notable differences emerge between the U-ViT and DiT models. For instance, in U-ViT, up to 94% of the layers can be skipped in the cache steps of the denoising process, whereas this proportion is considerably lower for DiT. Furthermore, we observed that the cacheable ratios for FFN and MHSA differ.
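Since only every second step is a cache step, the ratios in Table 4 refer to the cache steps alone, and the share of layer computation removed over the whole sampling trajectory is half of that. A small sketch of the conversion, with the Table 4 values hard-coded and a hypothetical helper name:

```python
def overall_remove_ratio(cache_step_ratio: float, cache_every: int = 2) -> float:
    """Removed share over all steps when one of every `cache_every` steps caches."""
    return cache_step_ratio / cache_every


# Remove ratios (in percent) within the cache steps, copied from Table 4.
table4 = {
    ("DiT-XL/2", 50): 47.43,
    ("DiT-XL/2", 20): 44.29,
    ("U-ViT-H/2", 50): 93.68,
    ("U-ViT-H/2", 20): 63.79,
}
for (model, nfe), ratio in table4.items():
    print(f"{model}, NFE {nfe}: {ratio:.2f}% of cache steps, "
          f"{overall_remove_ratio(ratio):.2f}% of all steps")
```

For U-ViT-H/2 with 50 NFEs this gives 46.84% over all steps, the figure quoted in the abstract.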

Comparison with other cache-based methods

We also compared our method with other cache-based methods. Notably, previous cache-based methods are strongly coupled to the U-Net structure and cannot be applied to models without the U-structure, such as DiT. To ensure a fair comparison, we selected U-ViT, which incorporates both the U-structure and transformers, and implemented these methods on it as baselines alongside our method. Table 3 presents the comparison results. The findings demonstrate that our method achieves better quality than the baselines at matched speedup.

Table 3: Comparison with other cache-based methods on U-ViT

| Methods | NFE | Latency | Speedup | FID ↓ |
|---|---|---|---|---|
| DPM-Solver | 20 | 7.69 | 1.00× | 2.57 |
| DeepCache [34] | 20 | 4.68 | 1.64× | 2.70 |
| Ours | 20 | 4.62 | 1.67× | 2.64 |
| Faster Diffusion [23] | 20 | 5.95 | 1.29× | 2.82 |
| Ours | 20 | 5.93 | 1.30× | 2.57 |

Table 4: Maximum cacheable layers for DiT and U-ViT with different steps.

| Model | DiT-XL/2 | DiT-XL/2 | U-ViT-H/2 | U-ViT-H/2 |
|---|---|---|---|---|
| NFE | 50 | 20 | 50 | 20 |
| Remove Ratio | 47.43% | 44.29% | 93.68% | 63.79% |
| FFN Remove Ratio | 47.85% | 44.64% | 94.11% | 60.54% |
| MHSA Remove Ratio | 47.00% | 43.93% | 93.25% | 67.05% |

4.3 Analysis

The Learned Pattern of $\boldsymbol{\beta}$

We present the learned pattern in Figure 5. The two architectures produce distinct patterns. For U-ViT, almost the entire middle section is cacheable and can be replaced with the results computed at the previous step, whereas the computations at both ends of the model are crucial and cannot be discarded. This observation explains why DeepCache outperforms Faster Diffusion on U-ViT: the learned pattern resembles the manually designed pattern of DeepCache. This phenomenon is not clearly observed in DiT-XL. Additionally, we found a consistent tendency across models to retain more computation in the later steps while discarding computation in the earlier steps. This aligns with our findings in Figure 3: when comparing different steps within the same layer, removing computation at smaller timesteps has a greater effect on the output.

Figure 5: Learned Router $\boldsymbol{\beta}$ for DiT-XL/2 (Top) and U-ViT-H/2 (Bottom).
Figure 6: Effect of threshold $\theta$ on the image quality.

Comparison between Layer Cache and Layer Dropout

Layer dropout involves directly removing $f_i(\cdot)$, retaining only the computation in the skip path. We compare our method with layer dropout, where the layers are either randomly dropped or optimized using our algorithm (named Learn-to-Drop). The results, presented in Table 5, indicate that layer caching significantly outperforms layer dropout. Interestingly, when we learn the layers to be dropped, the models still produce acceptable images, although the quality is not as high. Illustrative examples are provided in Appendix B.2.

Table 5: Comparison with Layer Dropout

| Methods | Remove Ratio | Latency (s) | Speedup | IS ↑ | FID ↓ | sFID ↓ | Precision ↑ | Recall ↑ |
|---|---|---|---|---|---|---|---|---|
| Random Drop | 170/560 | 2.439 | 1.18× | 3.36 | 277.42 | 171.83 | 1.23 | 0.24 |
| Learn-to-Drop | 179/560 | 2.421 | 1.19× | 113.93 | 17.35 | 28.46 | 60.25 | 52.68 |
| Learn-to-Cache | 176/560 | 2.438 | 1.18× | 226.13 | 3.47 | 4.58 | 79.19 | 56.47 |

Choice of threshold

We investigated the effect of different thresholds on image quality. Results are shown in Figure 6, where the router is trained with six different values of $\lambda$ (corresponding to the six points on each curve). The effect of different $\lambda$ is shown in Appendix B.3. Our results reveal that for higher acceleration ratios a larger threshold yields better image quality, whereas for lower acceleration ratios a smaller threshold is more effective. These findings also suggest that ranking layers by importance is not a reliable approach, since the selection of layers does not follow a strict sequential order; otherwise, a single threshold would be best everywhere.

5 Limitation

The primary limitation of this work arises from its dependence on the trained diffusion model. For instance, when applied to DiT-XL/2 at a resolution of 512, our method incurs a slight drop in FID. Although it still surpasses the baseline, this indicates that lossless caching of layers does not exist uniformly across all models. There are significant variations between models, and our method is therefore strongly dependent on the structural design of the trained diffusion model. Another limitation is that the acceleration is capped at 2× because every two steps consist of one full model inference step and one cheaper step. This inherently restricts the maximum achievable acceleration ratio. However, we believe that the approach can be extended to span more than two steps, potentially improving the overall efficiency.
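The 2× cap follows from a simple cost model. The sketch below assumes that all layers cost the same, that everything outside the cached layers is free, and that one of every `period` steps is a full step; the `period` parameter generalizes the two-step scheme and is a hypothetical extension, not something evaluated in this paper.

```python
def ideal_speedup(cache_ratio: float, period: int = 2) -> float:
    """Upper bound on speedup when 1 of every `period` steps is a full step.

    cache_ratio is the fraction of per-step computation skipped in a cache step.
    """
    # Cost of one period, in units of a full step: one full step plus
    # (period - 1) cache steps that each cost (1 - cache_ratio).
    cost = 1.0 + (period - 1) * (1.0 - cache_ratio)
    return period / cost


print(ideal_speedup(1.0))                # every layer cached in the cache step
print(round(ideal_speedup(0.9368), 3))   # U-ViT-H/2 remove ratio from Table 4
print(ideal_speedup(1.0, period=3))      # hypothetical three-step scheme
```

Even if a cache step were free, the full step in each pair keeps the bound at 2×; longer periods raise the bound only if the cached outputs stay accurate over more steps.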

6 Summary

In this paper, we propose a novel acceleration method for diffusion transformers. By interpolating between a computationally inexpensive but suboptimal model and an optimal but expensive model, we find that there exist models that run much faster while still producing high-fidelity images. To find them, we train a router that is continuous during training and discretized at inference. Experiments show that our method largely outperforms baselines such as DDIM, DPM-Solver and other cache-based methods.

References

  • [1] Fan Bao, Chongxuan Li, Jun Zhu, and Bo Zhang. Analytic-dpm: an analytic estimate of the optimal reverse variance in diffusion probabilistic models. arXiv preprint arXiv:2201.06503, 2022.
  • [2] Fan Bao, Shen Nie, Kaiwen Xue, Yue Cao, Chongxuan Li, Hang Su, and Jun Zhu. All are worth words: A vit backbone for diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 22669–22679, 2023.
  • [3] Tim Brooks, Bill Peebles, Connor Holmes, Will DePue, Yufei Guo, Li Jing, David Schnurr, Joe Taylor, Troy Luhman, Eric Luhman, Clarence Ng, Ricky Wang, and Aditya Ramesh. Video generation models as world simulators. 2024.
  • [4] Ziang Cao, Fangzhou Hong, Tong Wu, Liang Pan, and Ziwei Liu. Large-vocabulary 3d diffusion model with transformer. arXiv preprint arXiv:2309.07920, 2023.
  • [5] Thibault Castells, Hyoung-Kyu Song, Bo-Kyeong Kim, and Shinkook Choi. Ld-pruner: Efficient pruning of latent diffusion models using task-agnostic insights. arXiv preprint arXiv:2404.11936, 2024.
  • [6] Shoufa Chen, Mengmeng Xu, Jiawei Ren, Yuren Cong, Sen He, Yanping Xie, Animesh Sinha, Ping Luo, Tao Xiang, and Juan-Manuel Perez-Rua. Gentron: Delving deep into diffusion transformers for image and video generation. arXiv preprint arXiv:2312.04557, 2023.
  • [7] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in neural information processing systems, 34:8780–8794, 2021.
  • [8] Tim Dockhorn, Arash Vahdat, and Karsten Kreis. Score-based generative modeling with critically-damped langevin diffusion. arXiv preprint arXiv:2112.07068, 2021.
  • [9] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • [10] Angela Fan, Edouard Grave, and Armand Joulin. Reducing transformer depth on demand with structured dropout. In International Conference on Learning Representations, 2020.
  • [11] Gongfan Fang, Xinyin Ma, and Xinchao Wang. Structural pruning for diffusion models. Advances in neural information processing systems, 36, 2024.
  • [12] Shanghua Gao, Pan Zhou, Ming-Ming Cheng, and Shuicheng Yan. Masked diffusion transformer is a strong image synthesizer. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 23164–23173, 2023.
  • [13] Yefei He, Luping Liu, Jing Liu, Weijia Wu, Hong Zhou, and Bohan Zhuang. Ptqd: Accurate post-training quantization for diffusion models. In A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine, editors, Advances in Neural Information Processing Systems, volume 36, pages 13237–13249. Curran Associates, Inc., 2023.
  • [14] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in neural information processing systems, 33:6840–6851, 2020.
  • [15] Rosco Hunter, Łukasz Dudziak, Mohamed S Abdelfattah, Abhinav Mehrotra, Sourav Bhattacharya, and Hongkai Wen. Fast inference through the reuse of attention maps in diffusion models. arXiv preprint arXiv:2401.01008, 2023.
  • [16] Xin Jing, Yi Chang, Zijiang Yang, Jiangjian Xie, Andreas Triantafyllopoulos, and Bjoern W Schuller. U-dit tts: U-diffusion vision transformer for text-to-speech. In Speech Communication; 15th ITG Conference, pages 56–60. VDE, 2023.
  • [17] Alexia Jolicoeur-Martineau, Ke Li, Rémi Piché-Taillefer, Tal Kachman, and Ioannis Mitliagkas. Gotta go fast when generating data with score-based models. arXiv preprint arXiv:2105.14080, 2021.
  • [18] Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. Elucidating the design space of diffusion-based generative models. Advances in Neural Information Processing Systems, 35:26565–26577, 2022.
  • [19] Bo-Kyeong Kim, Geonmin Kim, Tae-Ho Kim, Thibault Castells, Shinkook Choi, Junho Shin, and Hyoung-Kyu Song. Shortened llama: A simple depth pruning for large language models. arXiv preprint arXiv:2402.02834, 2024.
  • [20] Bo-Kyeong Kim, Hyoung-Kyu Song, Thibault Castells, and Shinkook Choi. Bk-sdm: Architecturally compressed stable diffusion for efficient text-to-image generation. In Workshop on Efficient Systems for Foundation Models@ ICML2023, 2023.
  • [21] Tuomas Kynkäänniemi, Tero Karras, Samuli Laine, Jaakko Lehtinen, and Timo Aila. Improved precision and recall metric for assessing generative models. Advances in neural information processing systems, 32, 2019.
  • [22] Muyang Li, Tianle Cai, Jiaxin Cao, Qinsheng Zhang, Han Cai, Junjie Bai, Yangqing Jia, Ming-Yu Liu, Kai Li, and Song Han. Distrifusion: Distributed parallel inference for high-resolution diffusion models. arXiv preprint arXiv:2402.19481, 2024.
  • [23] Senmao Li, Taihang Hu, Fahad Shahbaz Khan, Linxuan Li, Shiqi Yang, Yaxing Wang, Ming-Ming Cheng, and Jian Yang. Faster diffusion: Rethinking the role of unet encoder in diffusion models. arXiv preprint arXiv:2312.09608, 2023.
  • [24] Xiuyu Li, Yijiang Liu, Long Lian, Huanrui Yang, Zhen Dong, Daniel Kang, Shanghang Zhang, and Kurt Keutzer. Q-diffusion: Quantizing diffusion models. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pages 17535–17545, October 2023.
  • [25] Yanyu Li, Huan Wang, Qing Jin, Ju Hu, Pavlo Chemerys, Yun Fu, Yanzhi Wang, Sergey Tulyakov, and Jian Ren. Snapfusion: Text-to-image diffusion model on mobile devices within two seconds. Advances in Neural Information Processing Systems, 36, 2024.
  • [26] Shanchuan Lin, Anran Wang, and Xiao Yang. Sdxl-lightning: Progressive adversarial diffusion distillation. arXiv preprint arXiv:2402.13929, 2024.
  • [27] Enshu Liu, Xuefei Ning, Zinan Lin, Huazhong Yang, and Yu Wang. Oms-dpm: Optimizing the model schedule for diffusion probabilistic models. arXiv preprint arXiv:2306.08860, 2023.
  • [28] Huadai Liu, Rongjie Huang, Xuan Lin, Wenqiang Xu, Maozong Zheng, Hong Chen, Jinzheng He, and Zhou Zhao. Vit-tts: visual text-to-speech with scalable diffusion transformer. arXiv preprint arXiv:2305.12708, 2023.
  • [29] Luping Liu, Yi Ren, Zhijie Lin, and Zhou Zhao. Pseudo numerical methods for diffusion models on manifolds. arXiv preprint arXiv:2202.09778, 2022.
  • [30] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver: A fast ode solver for diffusion probabilistic model sampling in around 10 steps. Advances in Neural Information Processing Systems, 35:5775–5787, 2022.
  • [31] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver++: Fast solver for guided sampling of diffusion probabilistic models. arXiv preprint arXiv:2211.01095, 2022.
  • [32] Haoyu Lu, Guoxing Yang, Nanyi Fei, Yuqi Huo, Zhiwu Lu, Ping Luo, and Mingyu Ding. Vdt: General-purpose video diffusion transformers via mask modeling. In The Twelfth International Conference on Learning Representations, 2023.
  • [33] Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao. Latent consistency models: Synthesizing high-resolution images with few-step inference. arXiv preprint arXiv:2310.04378, 2023.
  • [34] Xinyin Ma, Gongfan Fang, and Xinchao Wang. Deepcache: Accelerating diffusion models for free. arXiv preprint arXiv:2312.00858, 2023.
  • [35] Shentong Mo, Enze Xie, Ruihang Chu, Lanqing Hong, Matthias Niessner, and Zhenguo Li. Dit-3d: Exploring plain diffusion transformers for 3d shape generation. Advances in Neural Information Processing Systems, 36, 2024.
  • [36] Taehong Moon, Moonseok Choi, EungGu Yun, Jongmin Yoon, Gayoung Lee, and Juho Lee. Early exiting for accelerated inference in diffusion models. In ICML 2023 Workshop on Structured Probabilistic Inference & Generative Modeling, 2023.
  • [37] Charlie Nash, Jacob Menick, Sander Dieleman, and Peter W Battaglia. Generating images with sparse representations. arXiv preprint arXiv:2103.03841, 2021.
  • [38] Zizheng Pan, Bohan Zhuang, De-An Huang, Weili Nie, Zhiding Yu, Chaowei Xiao, Jianfei Cai, and Anima Anandkumar. T-stitch: Accelerating sampling in pre-trained diffusion models with trajectory stitching. arXiv preprint arXiv:2402.14167, 2024.
  • [39] Zizheng Pan, Bohan Zhuang, De-An Huang, Weili Nie, Zhiding Yu, Chaowei Xiao, Jianfei Cai, and Anima Anandkumar. T-stitch: Accelerating sampling in pre-trained diffusion models with trajectory stitching, 2024.
  • [40] William Peebles and Saining Xie. Scalable diffusion models with transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4195–4205, 2023.
  • [41] David Raposo, Sam Ritter, Blake Richards, Timothy Lillicrap, Peter Conway Humphreys, and Adam Santoro. Mixture-of-depths: Dynamically allocating compute in transformer-based language models. arXiv preprint arXiv:2404.02258, 2024.
  • [42] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 10684–10695, 2022.
  • [43] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. In Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18, pages 234–241. Springer, 2015.
  • [44] Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training gans. Advances in neural information processing systems, 29, 2016.
  • [45] Tim Salimans and Jonathan Ho. Progressive distillation for fast sampling of diffusion models. arXiv preprint arXiv:2202.00512, 2022.
  • [46] Yuzhang Shang, Zhihang Yuan, Bin Xie, Bingzhe Wu, and Yan Yan. Post-training quantization on diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 1972–1981, 2023.
  • [47] Andy Shih, Suneel Belkhale, Stefano Ermon, Dorsa Sadigh, and Nima Anari. Parallel sampling of diffusion models. arXiv preprint arXiv:2305.16317, 2023.
  • [48] Alan Jay Smith. Cache memories. ACM Computing Surveys (CSUR), 14(3):473–530, 1982.
  • [49] Junhyuk So, Jungwon Lee, and Eunhyeok Park. Frdiff: Feature reuse for exquisite zero-shot acceleration of diffusion models. arXiv preprint arXiv:2312.03517, 2023.
  • [50] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. arXiv preprint arXiv:2010.02502, 2020.
  • [51] Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. Consistency models. arXiv preprint arXiv:2303.01469, 2023.
  • [52] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. Advances in neural information processing systems, 32, 2019.
  • [53] Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456, 2020.
  • [54] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • [55] Kafeng Wang, Jianfei Chen, He Li, Zhenpeng Mi, and Jun Zhu. Sparsedm: Toward sparse efficient diffusion models, 2024.
  • [56] Felix Wimbauer, Bichen Wu, Edgar Schoenfeld, Xiaoliang Dai, Ji Hou, Zijian He, Artsiom Sanakoyeu, Peizhao Zhang, Sam Tsai, Jonas Kohler, et al. Cache me if you can: Accelerating diffusion models through block caching. arXiv preprint arXiv:2312.03209, 2023.
  • [57] Shuai Yang, Yukang Chen, Luozhou Wang, Shu Liu, and Yingcong Chen. Denoising diffusion step-aware models. arXiv preprint arXiv:2310.03337, 2023.
  • [58] Xingyi Yang and Xinchao Wang. Hash3d: Training-free acceleration for 3d generation. arXiv preprint arXiv:2404.06091, 2024.
  • [59] Xingyi Yang, Daquan Zhou, Jiashi Feng, and Xinchao Wang. Diffusion probabilistic model made slim. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 22552–22562, 2023.
  • [60] Xiulong Yang, Sheng-Min Shih, Yinlin Fu, Xiaoting Zhao, and Shihao Ji. Your vit is secretly a hybrid discriminative-generative diffusion model. arXiv preprint arXiv:2208.07791, 2022.
  • [61] Dingkun Zhang, Sijia Li, Chen Chen, Qingsong Xie, and Haonan Lu. Laptop-diff: Layer pruning and normalized distillation for compressing diffusion models. arXiv preprint arXiv:2404.11098, 2024.
  • [62] Qinsheng Zhang and Yongxin Chen. Fast sampling of diffusion models with exponential integrator. arXiv preprint arXiv:2204.13902, 2022.
  • [63] Qinsheng Zhang, Molei Tao, and Yongxin Chen. gddim: Generalized denoising diffusion implicit models. arXiv preprint arXiv:2206.05564, 2022.
  • [64] Wentian Zhang, Haozhe Liu, Jinheng Xie, Francesco Faccio, Mike Zheng Shou, and Jürgen Schmidhuber. Cross-attention makes inference cumbersome in text-to-image diffusion models, 2024.
  • [65] Yang Zhao, Yanwu Xu, Zhisheng Xiao, and Tingbo Hou. Mobilediffusion: Subsecond text-to-image generation on mobile devices. arXiv preprint arXiv:2311.16567, 2023.
  • [66] Hongkai Zheng, Weili Nie, Arash Vahdat, and Anima Anandkumar. Fast training of diffusion models with masked transformers. arXiv preprint arXiv:2306.09305, 2023.

Appendix A Proof

A.1 Two equivalent solutions to obtain $\boldsymbol{x}_t$

To obtain the solution $\boldsymbol{x}_t$, the following two approaches yield equivalent results:

  1. Directly update $\boldsymbol{x}_t$ from $\boldsymbol{x}_s$. By definition, the solution at time $t$ is:

    $$\boldsymbol{x}_t = \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s - \sigma_t\left(e^{\lambda_t-\lambda_s}-1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) \tag{10}$$

  2. First compute $\boldsymbol{x}_m$ from $\boldsymbol{x}_s$, and then compute $\boldsymbol{x}_t$ from $\boldsymbol{x}_m$ with $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right) = \boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$.

Proof. First, we consider the solution of $\boldsymbol{x}_m$ from $\boldsymbol{x}_s$:

$$\boldsymbol{x}_m = \frac{\alpha_m}{\alpha_s}\boldsymbol{x}_s - \sigma_m\left(e^{\lambda_m-\lambda_s}-1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) \tag{11}$$

For the calculation of $\boldsymbol{x}_t$ with $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right) = \boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$, we have

$$\begin{aligned}
\boldsymbol{x}_t &= \frac{\alpha_t}{\alpha_m}\boldsymbol{x}_m - \sigma_t\left(e^{\lambda_t-\lambda_m}-1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right) \\
&= \frac{\alpha_t}{\alpha_m}\left(\frac{\alpha_m}{\alpha_s}\boldsymbol{x}_s - \sigma_m\left(e^{\lambda_m-\lambda_s}-1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)\right) - \sigma_t\left(e^{\lambda_t-\lambda_m}-1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) \\
&= \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s - \left(\frac{\alpha_t}{\alpha_m}\sigma_m\left(e^{\lambda_m-\lambda_s}-1\right) + \sigma_t\left(e^{\lambda_t-\lambda_m}-1\right)\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)
\end{aligned} \tag{12}$$

Note that $\lambda_t = \log\left(\alpha_t/\sigma_t\right)$. We obtain:

$$\begin{aligned}
\boldsymbol{x}_t &= \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s - \left(\frac{\alpha_t}{\alpha_m}\sigma_m\left(\frac{\alpha_m}{\sigma_m}\frac{\sigma_s}{\alpha_s}-1\right) + \sigma_t\left(\frac{\alpha_t}{\sigma_t}\frac{\sigma_m}{\alpha_m}-1\right)\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) \\
&= \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s - \left(\alpha_t\frac{\sigma_s}{\alpha_s}-\sigma_t\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) = \frac{\alpha_t}{\alpha_s}\boldsymbol{x}_s - \sigma_t\left(e^{\lambda_t-\lambda_s}-1\right)\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)
\end{aligned} \tag{13}$$
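The equivalence can also be checked numerically. The sketch below is only an illustration on scalars: the helper names and the cosine/sine schedule are assumptions made for this example, not part of the method. It applies the first-order update of Eq. 10 once from $s$ to $t$, and then twice via an intermediate time $m$ while reusing the same noise prediction.

```python
import math

def half_log_snr(alpha, sigma):
    """lambda = log(alpha / sigma), as used in the proof."""
    return math.log(alpha / sigma)

def first_order_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t):
    """One first-order update from time s to time t (the form of Eq. 10)."""
    h = half_log_snr(alpha_t, sigma_t) - half_log_snr(alpha_s, sigma_s)
    return alpha_t / alpha_s * x_s - sigma_t * math.expm1(h) * eps

def schedule(t):
    """Hypothetical noise schedule with alpha^2 + sigma^2 = 1 (an assumption)."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)

# Arbitrary scalar state and a fixed noise prediction eps_theta(x_s, s).
x_s, eps = 0.7, -1.3
(a_s, s_s), (a_m, s_m), (a_t, s_t) = schedule(0.8), schedule(0.6), schedule(0.4)

# Approach 1: update directly from s to t.
direct = first_order_step(x_s, eps, a_s, s_s, a_t, s_t)

# Approach 2: go from s to m, then from m to t, reusing the same eps.
x_m = first_order_step(x_s, eps, a_s, s_s, a_m, s_m)
two_hop = first_order_step(x_m, eps, a_m, s_m, a_t, s_t)

print(direct, two_hop)
```

The two printed values agree up to floating-point rounding, for any choice of positive $\alpha$ and $\sigma$ values.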

A.2 Layer interpolation and Interpolation $\mathcal{I}$

We next show that the following interpolation of the layer satisfies the interpolation $\mathcal{I}$ between $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$ and $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$ as we define it:

$$\tilde{L}_i(h_i^m, m) = h_i^m - (1-\alpha_i)\cdot(h_i^m - h_i^s) + g(m)\left(\beta_i\cdot f(h_i^m) + (1-\beta_i)\cdot f(h_i^s)\right) \tag{14}$$

To prove this, we need to show three things: (1) the interpolation condition, where the function passes through the two given model outputs $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$ and $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$; (2) continuity, where the interpolation function is continuous; and (3) differentiability, where the function is differentiable. Since $\alpha_i$ and $\beta_i$ are continuous and the model itself is continuous and differentiable, only the first property remains to be proved.

Proof. We show that Eq. 14 satisfies the interpolation condition of $\mathcal{I}$.

  • With $\{\alpha_i\}_{i=1}^{D}$ and $\{\beta_i\}_{i=1}^{D}$ set to 0, the output of the transformer is $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$.

    If $\alpha_i = 0$ and $\beta_i = 0$ for all $i \in \{1, \dots, D\}$, then

    $$\tilde{L}_i(h_i^m, m) = h_i^s + g(m)\cdot f(h_i^s) \tag{15}$$

    The output of the transformer after $D$ layers is given by:

    $$\tilde{L}_D\left(\tilde{L}_{D-1}\left(\ldots\tilde{L}_1\left(\boldsymbol{x}_s, s\right)\ldots\right)\right) = \boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right) \tag{16}$$

    Therefore, we get $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_s, s\right)$, one of the endpoints of the interpolation $\mathcal{I}$.

  • With $\{\alpha_i\}_{i=1}^{D}$ and $\{\beta_i\}_{i=1}^{D}$ set to 1, the output of the transformer is $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$.

    If $\alpha_i = 1$ and $\beta_i = 1$ for all $i \in \{1, \dots, D\}$, then

    $$\tilde{L}_i(h_i^m, m) = h_i^m + g(m)\cdot f(h_i^m) \tag{17}$$

    As above, we get $\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_m, m\right)$, the other endpoint of the interpolation $\mathcal{I}$.
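The two endpoint cases can be checked on a single layer with a few lines of code. This is a minimal sketch on scalars: the residual branch `f` (here `math.tanh`) and the gate value `g_m` are placeholder assumptions standing in for the real transformer block and its timestep-dependent scaling.

```python
import math

def interpolated_layer(h_m, h_s, alpha, beta, f, g_m):
    """Scalar version of the interpolated layer in Eq. 14."""
    mixed_input = h_m - (1.0 - alpha) * (h_m - h_s)
    mixed_branch = beta * f(h_m) + (1.0 - beta) * f(h_s)
    return mixed_input + g_m * mixed_branch

f = math.tanh        # hypothetical stand-in for the layer's residual branch
g_m = 0.9            # hypothetical value of the gate g(m)
h_m, h_s = 0.50, 0.42  # toy layer inputs at timesteps m and s

# alpha = beta = 0 reduces to Eq. 15; alpha = beta = 1 reduces to Eq. 17.
at_zero = interpolated_layer(h_m, h_s, 0.0, 0.0, f, g_m)
at_one = interpolated_layer(h_m, h_s, 1.0, 1.0, f, g_m)
halfway = interpolated_layer(h_m, h_s, 0.5, 0.5, f, g_m)

print(at_zero, h_s + g_m * f(h_s))
print(at_one, h_m + g_m * f(h_m))
```

Intermediate values of `alpha` and `beta` move the output continuously between these two endpoints, which is what makes the relaxation usable for gradient-based optimization.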

Appendix B Additional Experiments

B.1 Shifted cache step for DPM-Solver

Table 6: DPM-Solver with and without shifted cache steps. Here we cache all the layers.

| Method | NFE | Latency | Speedup | IS | FID | sFID | Precision | Recall |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DPM-Solver-2 | 20 | 7.69 | 1.00× | 263.76 | 2.57 | 5.01 | 82.77 | 55.71 |
| Cache | 20 | 4.25 | 1.81× | 222.64 | 5.30 | 7.87 | 76.17 | 54.59 |
| Cache - shifted | 20 | 4.54 | 1.70× | 254.48 | 2.80 | 4.70 | 81.14 | 55.48 |

One important trick in our experiments with DPM-Solver is shifting the cache steps. Specifically, when employing DPM-Solver-2, the cache steps (a step here is one model evaluation) are shifted from [2,4,6,8,10,…] to [3,5,7,9,11,…]. This adjustment is necessary because DPM-Solver-2 requires the first-order derivative of the model $\boldsymbol{\epsilon}_{\theta}(\cdot)$ at the current timestep, which is computed by subtracting the output at timestep $i$ from the output at timestep $i+1$. If the cache steps were taken at timestep $i+1$, the derivative would be estimated incorrectly. Shifting the cache steps ensures that the derivative of $\boldsymbol{\epsilon}_{\theta}(\cdot)$ is computed accurately. This adjustment significantly affects the results, as demonstrated in Table 6.
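The two step patterns can be generated with a short helper. This is a minimal sketch rather than the authors' implementation: `cache_steps` is a hypothetical name, and it assumes that model evaluations are numbered from 1 and that every second evaluation is a cache step, as the two listed sequences suggest.

```python
def cache_steps(nfe, shifted=False):
    """Return the 1-indexed model evaluations that reuse cached layers.

    Unshifted: 2, 4, 6, ...   Shifted: 3, 5, 7, ...
    The upper bound `nfe` is an assumption for illustration.
    """
    start = 3 if shifted else 2
    return list(range(start, nfe + 1, 2))


unshifted = cache_steps(12)               # [2, 4, 6, 8, 10, 12]
shifted = cache_steps(12, shifted=True)   # [3, 5, 7, 9, 11]
```

Under these assumptions the shifted pattern contains one fewer cache step for an even number of evaluations.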

B.2 Layer Dropout vs. Layer Cache

Here we present further comparisons between layer dropout and layer caching. As illustrated in Figure 7, layer caching significantly outperforms layer dropout, maintaining pixel-wise consistency with the original pipeline. When the layers to be dropped are selected by our algorithm, the model can still generate images with correct semantics, whereas randomly dropping layers severely compromises the model’s ability to produce acceptable images. Table 7 demonstrates that even a small proportion of layer dropout (around 10%) results in a substantial performance degradation.
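A toy scalar example gives some intuition for why caching can stay closer to the original output than dropping. It is a sketch under stated assumptions, not a reproduction of Table 7: the layer function `toy_f`, the values of `h_prev` and `h_now`, and the residual form of the cached layer are all hypothetical, and the example presumes that features change only slightly between adjacent steps.

```python
def toy_f(h):
    # Hypothetical layer function; any smooth function would do.
    return 0.5 * h


def full_layer(h, g):
    # Reference output with the layer fully computed.
    return h + g * toy_f(h)


def dropped_layer(h):
    # Layer dropout: the layer's contribution is removed entirely.
    return h


def cached_layer(h, g, cached_f):
    # Layer caching: the contribution from the previous step is reused.
    return h + g * cached_f


g = 1.0
h_prev, h_now = 1.0, 1.1      # features at the previous and current step
cached_f = toy_f(h_prev)      # stored when the layer was last computed

reference = full_layer(h_now, g)
err_drop = abs(reference - dropped_layer(h_now))
err_cache = abs(reference - cached_layer(h_now, g, cached_f))
```

Dropping loses the whole contribution of the layer, while caching only loses the change in that contribution between the two steps.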

Figure 7: Qualitative results for layer dropping and layer caching in Section 4.3. (1) DDIM pipeline with 20 NFE. (2) Our method L2C with 20 NFE. (3) Layers to drop learned by our algorithm. (4) Randomly dropped layers. All results except the first row, which is the baseline, speed up inference by around 1.18× to 1.19×.

Table 7: Comparison with Layer Dropout

| Methods | Remove Ratio | Latency (s) | Speedup | IS ↑ | FID ↓ | sFID ↓ | Precision ↑ | Recall ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Random Drop | 60/560 | 2.718 | 1.06× | 9.66 | 112.93 | 153.48 | 10.56 | 65.57 |
| Random Drop | 170/560 | 2.439 | 1.18× | 3.36 | 277.42 | 171.83 | 1.23 | 0.24 |
| Learn-to-Drop | 179/560 | 2.421 | 1.19× | 113.93 | 17.35 | 28.46 | 60.25 | 52.68 |
| Learn-to-Cache | 176/560 | 2.438 | 1.18× | 226.13 | 3.47 | 4.58 | 79.19 | 56.47 |

Table 8: $\lambda$ and $\theta$ for training the router

| Model | DiT-XL/2 | DiT-XL/2 | DiT-XL/2 | DiT-XL/2 | DiT-L/2 | DiT-L/2 | U-ViT-H/2 | U-ViT-H/2 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| NFE | 50 | 20 | 10 | 50 | 50 | 20 | 50 | 20 |
| Resolution | 256 | 256 | 256 | 512 | 256 | 256 | 256 | 256 |
| Sampler | DDIM | DDIM | DDIM | DDIM | DDIM | DDIM | DPM-Solver-2 | DPM-Solver-2 |
| $\lambda$ for train | 1e-6 | 5e-6 | 1e-6 | 5e-6 | 1e-6 | 5e-6 | 0.1 | 0.1 |
| $\theta$ for inference | 0.1 | 0.1 | 0.1 | 0.9 | 0.1 | 0.1 | 0.9 | 0.9 |

Table 9: Performance with different $\lambda$. Threshold $\theta$ is set to 0.1.

| $\lambda$ | Remove Ratio | Latency (s) | Speedup | IS ↑ | FID ↓ | sFID ↓ | Precision ↑ | Recall ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 0/560 | 2.87 | 1.00× | 223.49 | 3.48 | 4.89 | 78.76 | 57.07 |
| 5e-7 | 129/560 | 2.55 | 1.13× | 222.15 | 3.49 | 4.79 | 78.47 | 57.36 |
| 1e-6 | 176/560 | 2.45 | 1.17× | 226.13 | 3.47 | 4.58 | 79.19 | 56.47 |
| 5e-6 | 248/560 | 2.28 | 1.26× | 226.95 | 3.45 | 4.64 | 79.20 | 55.82 |
| 1e-5 | 300/560 | 2.15 | 1.33× | 223.41 | 3.70 | 4.91 | 78.88 | 56.36 |
| 5e-5 | 404/560 | 1.92 | 1.49× | 200.60 | 5.43 | 6.55 | 75.06 | 57.54 |
| 1e-4 | 460/560 | 1.79 | 1.60× | 193.75 | 6.51 | 7.71 | 73.55 | 56.55 |

B.3 Effect of the hyper-parameters $\lambda$ and $\theta$

We find in our experiments that the learned router is not sensitive to most hyper-parameters, including the learning rate, the number of training epochs, and the hyper-parameters of the optimizer. The only ones that affect the results are $\lambda$ for training and the threshold $\theta$ for inference. Table 8 lists the $\lambda$ values that reproduce the results in Table 1. The difference in $\lambda$ between DiT and U-ViT stems from differences in implementation.

The results of using different $\lambda$ values are presented in Table 9. Note that $\lambda$ serves as the regularization strength that controls the sparsity of the router, so no single $\lambda$ is optimal for all settings. It acts as a trade-off between latency and quality, balancing inference speed against the fidelity of the generated images.
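The role of $\theta$ at inference can be illustrated by thresholding a small table of router scores into a static schedule. This is a sketch under assumptions, not the released code: the function names, the example scores, and the comparison direction (a layer is recomputed when its score exceeds $\theta$, otherwise it is served from the cache) are hypothetical choices consistent with a router value of 1 meaning full computation.

```python
def static_schedule(router, theta):
    """Turn router scores into a fixed compute/cache plan.

    router[t][i] is the score of layer i at cache step t, in [0, 1].
    True means "recompute the layer"; False means "reuse the cache".
    The strict '>' comparison is an assumption for illustration.
    """
    return [[score > theta for score in row] for row in router]


def remove_ratio(schedule):
    """Count (removed, total) layer computations in the cache steps."""
    total = sum(len(row) for row in schedule)
    removed = sum(1 for row in schedule for keep in row if not keep)
    return removed, total


# Two cache steps, three layers each (toy scores).
router = [
    [0.95, 0.05, 0.40],
    [0.00, 0.20, 0.08],
]
schedule = static_schedule(router, theta=0.1)
removed, total = remove_ratio(schedule)
```

A larger $\lambda$ during training would push more scores below the threshold, so more entries of the schedule become cached, which matches the growing remove ratios reported in Table 9.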

Appendix C Social Impact

The acceleration of diffusion transformers provides several positive social impacts, such as reducing the latency and resources required for deploying diffusion models. This enhancement improves the real-time applicability of diffusion transformers and promotes environmental sustainability. By making diffusion models more efficient, our method reduces the computational power needed for both training and inference, leading to lower energy consumption and a reduced carbon footprint. However, it is important to note that our method does not address privacy concerns, nor does it mitigate issues related to bias and fairness in diffusion models. These challenges remain when applying our method.