Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation

Clément Chadebec, Onur Tasar, Eyal Benaroche, Benjamin Aubin
Jasper Research
Corresponding author - [email protected]
Contributed during an internship at Jasper Research
Abstract

In this paper, we propose an efficient, fast, and versatile distillation method to accelerate the generation of pre-trained diffusion models: Flash Diffusion. The method reaches state-of-the-art performance in terms of FID and CLIP-Score for few-step image generation on the COCO2014 and COCO2017 datasets, while requiring only several GPU hours of training and fewer trainable parameters than existing methods. In addition to its efficiency, the versatility of the method is also demonstrated across several tasks such as text-to-image, inpainting, face-swapping and super-resolution, and across different backbones such as UNet-based denoisers (SD1.5, SDXL) or DiT (Pixart-α), as well as adapters. In all cases, the method drastically reduces the number of sampling steps while maintaining very high-quality image generation. The official implementation is available at https://github.com/gojasper/flash-diffusion.

[Figure 2: image panels (a) to (h) omitted. Labeled panels: (e) Super-resolution, (f) Inpainting, (g) Face-swapping, (h) Adapters.]
Figure 2: Inputs (left columns) and generated samples (right columns) using the proposed method for different teacher models and tasks (super-resolution, inpainting, face-swapping and adapters). Samples are generated using 4 Neural Function Evaluations (NFEs).

1 Introduction

Diffusion Models (DM) [64, 17, 67] have proven to be one of the most efficient classes of generative models for many tasks, ranging from image synthesis [8, 55, 57, 47] and video generation [19, 10, 3, 2] to audio [29] and 3D [41, 40, 74]. They have raised particular interest and enthusiasm for text-to-image applications [54, 55, 57, 59, 18, 11, 50, 5, 6] where they outperform other approaches. However, their usability for real-time applications remains limited by the intrinsically iterative nature of their sampling mechanism. At inference time, these models iteratively denoise a sample drawn from a Gaussian distribution to finally create a sample belonging to the data distribution. Such a denoising process requires multiple evaluations of a potentially very computationally costly neural function, so faster counterparts such as Generative Adversarial Networks (GANs) [12] and Variational AutoEncoders (VAEs) [27] still tend to be preferred.

Recently, more efficient solvers [37, 38, 80, 82] and diffusion distillation methods [60, 68, 33, 76, 36, 56, 43, 42, 61, 62, 77] that aim to reduce the number of sampling steps required to generate satisfying samples from a trained diffusion model have emerged to tackle this issue. Nonetheless, solvers typically require at least 10 Neural Function Evaluations (NFEs) to produce satisfying samples, while distillation methods may require extensive training resources [36, 77, 45] or an iterative training procedure that updates the teacher model throughout training [60, 33, 30], limiting their applications and reach. Moreover, most of the existing distillation methods are tailored for a specific task such as text-to-image. It is still unclear how they would perform on other tasks, using different conditionings and diffusion model architectures. In addition, the most efficient approaches rely on an adversarial training procedure [76, 61, 33], potentially leading to unstable training and requiring extensive hyper-parameter tuning [62].

In this paper, we present Flash Diffusion, a fast, robust, and versatile diffusion distillation method that drastically reduces the number of sampling steps while maintaining very high image generation quality. The proposed method trains a student model to predict, in a single step, a denoised multiple-step teacher prediction of a corrupted input sample. The method also drives the student distribution towards the real input sample manifold with an adversarial objective [12] and ensures that it does not drift too much from the learned teacher distribution using distribution matching [9, 31].

The method is compatible with LoRA [20] and is able to generate high-quality samples in only a few steps while requiring only several GPU hours of training and fewer trainable parameters than existing methods. We show that the method reaches State-of-the-Art (SOTA) performance in terms of FID and CLIP score for few-step image generation on the COCO2014 and COCO2017 datasets. In addition to its efficiency, the versatility of the method is also demonstrated across several tasks such as text-to-image, inpainting, face-swapping and image upscaling, and across different diffusion model backbones (SD1.5 [57], SDXL [50] and Pixart-α [5]) as well as adapters [46]. In all cases, the method drastically reduces the number of sampling steps while maintaining very high-quality image generation.

The main contributions of the paper are as follows:

  • We propose an efficient, fast, versatile, and LoRA-compatible distillation method aiming at reducing the number of sampling steps required to generate high-quality samples from a trained diffusion model.

  • We validate the method for a text-to-image task and show that it is able to produce SOTA results for few-step image generation on standard benchmark datasets with only two NFEs, which is equivalent to a single step with classifier-free guidance, while having far fewer trainable parameters than competitors and requiring only a few GPU hours of training.

  • We conduct an extensive ablation study to show the impact of the different components of the method and demonstrate its robustness and reliability.

  • We emphasize the versatility of the method through an extensive experimental study across various tasks (text-to-image, image inpainting, super-resolution, face-swapping), diffusion model architectures (SD1.5, SDXL and Pixart-α) and illustrate its compatibility with adapters [46].

  • Finally, we share our implementation that showcases the method’s reproducibility and stability, and potentially opens the door to a wider range of practical applications of Flash Diffusion. See https://github.com/gojasper/flash-diffusion.

2 Related Works

Diffusion Models

Diffusion models artificially corrupt input data according to a given noise schedule [64, 17, 67] such that the data distribution eventually resembles a standard Gaussian one. They are then trained to estimate the amount of noise added in order to learn a reverse diffusion process allowing them, once trained, to generate new samples from Gaussian noise. Those models can be conditioned with respect to various inputs such as images [57], depth maps, edges, poses [79, 46] or text [8, 55, 57, 47, 11, 18, 50], where they have demonstrated very impressive results. However, the need to resort to a large number of sampling steps (typically 50) at inference time to generate high-quality samples has limited their use in real-time applications and narrowed their usability and reach.

Diffusion Distillation

In order to tackle this limitation, several methods have recently emerged to reduce the number of function evaluations required at inference time. On the one hand, several papers tried to build more efficient solvers to speed up the generation process [37, 38, 80, 82], but these methods still require several steps (typically 10) to generate satisfying samples. On the other hand, several approaches relying on model distillation [15] proposed to train a student network that learns to match the samples generated by a teacher model, but in fewer steps. A simple approach consists in building pairs of noise/teacher samples and training a student model to match the teacher predictions from the same noise in a single step with a regression loss [39, 83]. Nonetheless, this approach remains quite limited and struggles to match the quality of the teacher model since there is no underlying useful information to be learned by the student in full noise. Building upon this idea, several methods were proposed to first apply the forward diffusion process to an input sample and then pass it to the student network. The student prediction is then compared to the learned distribution of the teacher model using either a regression loss [28, 77], an adversarial objective [76, 61, 62, 78] or distribution matching [77, 78].

Progressive distillation [60, 45] has also proven to be quite promising. It consists in training a student model to predict, in a single step, a two-step teacher denoising of a noisy sample, theoretically halving the number of required sampling steps. The teacher is then replaced by the new student and the process is repeated several times. This approach was also enriched with a GAN-based objective that further reduces the number of sampling steps needed from 4-8 to a single pass [33]. InstaFlow [36] proposed instead to rely on rectified flows [35] to ease the one-step distillation process. However, this approach may require a significant number of trainable parameters and a long training procedure, making it computationally intensive.

Consistency models [68, 65, 42, 24] are also a promising and effective family of distillation methods, and among the most versatile proposed in the literature. The main idea is to train a model to map any point lying on the Probability Flow Ordinary Differential Equation (PF-ODE) to its origin, theoretically unlocking single-step generation. Luo et al. [43] combined Latent Consistency Models (LCM) and LoRAs [20] and showed that it is possible to train a strong student with a very limited number of trainable parameters and a few GPU hours of training. Nonetheless, those models still struggle to achieve single-step generation and to reach the sampling quality of their peers.

In a recent parallel study, the authors of [78] also introduced the combined use of a distribution matching loss and an adversarial loss, which we also employ in our paper. Nonetheless, they do not rely on a distillation loss, which proved highly efficient in our experiments, and do not compute the adversarial loss with respect to the same inputs. Moreover, their approach still necessitates training another denoiser to assess the score of the fake samples, significantly increasing the number of trainable parameters and, consequently, the computational burden of the method. Furthermore, the ability of their method to generalize and perform effectively across different tasks and diffusion model architectures remains unclear.

3 Background

3.1 Diffusion Models

Let $x_0 \in \mathcal{X}$ be a set of input data such that $x_0 \sim p(x_0)$ where $p(x_0)$ is an unknown distribution. Diffusion models are a class of generative models that define a Markovian process $(x_t)_{t \in [0,T]}$ consisting in creating a noisy version $x_t$ of $x_0$ by iteratively injecting Gaussian noise to the data $x_0$. This process is such that as $t$ increases the distribution of the noisy samples $x_t$ eventually becomes equivalent to an isotropic Gaussian distribution. The noise schedule is controlled by two differentiable functions $\alpha(t)$, $\sigma(t)$ for any $t \in [0,T]$ such that the log signal-to-noise ratio $\log[\alpha(t)^2/\sigma(t)^2]$ is decreasing over time.
Given any $t \in [0,T]$, the distribution of the noisy samples given the input $q(x_t|x_0)$ is called the forward process and is defined by $q(x_t|x_0) = \mathcal{N}\left(x_t; \alpha(t) \cdot x_0, \sigma(t)^2 \cdot \mathbf{I}\right)$ from which we can sample as follows:

$$x_t = \alpha(t) \cdot x_0 + \sigma(t) \cdot \varepsilon \quad \text{with} \quad \varepsilon \sim \mathcal{N}\left(0, \mathbf{I}\right)\,. \tag{1}$$
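As an illustration of Eq. (1), the following minimal sketch draws a noisy sample from a clean one. The cosine schedule $\alpha(t) = \cos(\pi t/2)$, $\sigma(t) = \sin(\pi t/2)$ on $[0, 1]$ and the helper names `forward_sample` and `log_snr` are assumptions made for this example only; the text only requires that the log signal-to-noise ratio decreases over time.

```python
import math
import random


def alpha(t: float) -> float:
    """Signal scale of an assumed cosine schedule on t in [0, 1]."""
    return math.cos(0.5 * math.pi * t)


def sigma(t: float) -> float:
    """Noise scale of the same assumed schedule."""
    return math.sin(0.5 * math.pi * t)


def log_snr(t: float) -> float:
    """Log signal-to-noise ratio log[alpha(t)^2 / sigma(t)^2], defined for 0 < t < 1."""
    return math.log(alpha(t) ** 2 / sigma(t) ** 2)


def forward_sample(x0, t, rng):
    """Draw x_t = alpha(t) * x0 + sigma(t) * eps with eps ~ N(0, I), as in Eq. (1).

    x0 is a list of floats; the noise eps is returned alongside x_t so that
    it can serve as a regression target.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [alpha(t) * x + sigma(t) * e for x, e in zip(x0, eps)]
    return x_t, eps


rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]
x_t, eps = forward_sample(x0, t=0.3, rng=rng)
print(x_t)
```

With this assumed schedule, $t = 0$ returns the clean sample unchanged and the signal vanishes as $t$ approaches 1.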

The main idea of diffusion models is to learn to denoise a noisy sample $x_t \sim q(x_t|x_0)$ in order to learn the reverse process, which ultimately makes it possible to create samples $\tilde{x}_0$ directly from pure noise. In practice, training a diffusion model consists in learning a parametrized function $x_\theta$ conditioned on the timestep $t$ and taking as input the noisy sample $x_t$ such that it predicts a denoised version of the original sample $x_0$. The parameters $\theta$ are then learned via denoising score matching [69, 66]:

$$\mathcal{L} = \mathbb{E}_{x_0 \sim p(x_0),\, t \sim \pi(t),\, \varepsilon \sim \mathcal{N}\left(0, \mathbf{I}\right)}\left[\lambda(t) \left\| x_\theta(x_t, t) - x_0 \right\|^2\right]\,, \tag{2}$$

where $\lambda(t)$ is a scaling factor that depends on the timestep $t \in [0,1]$ and $\pi(t)$ is a distribution over the timesteps. Note that Eq. (2) is actually equivalent to learning a function $\varepsilon_\theta$ estimating the amount of noise $\varepsilon$ added to the original sample using the reparametrization $\varepsilon_\theta(x_t, t) = \big(x_t - \alpha(t) \cdot x_\theta(x_t, t)\big)/\sigma(t)$. Song et al. [67] showed that $\varepsilon_\theta$ can be used to generate new data points from Gaussian noise by solving the following PF-ODE [67, 60, 25, 37]:

$$\mathrm{d}x_t = \left[f(x_t, t) - \frac{1}{2} g^2(t) \nabla \log p_\theta(x_t)\right]\mathrm{d}t\,, \tag{3}$$

where $f(x_t, t)$ and $g(t)$ are respectively the drift and diffusion functions of the PF-ODE defined as follows:

$$f(x_t, t) = \frac{\mathrm{d}\log\alpha(t)}{\mathrm{d}t} x_t, \quad g^2(t) = \frac{\mathrm{d}\sigma(t)^2}{\mathrm{d}t} - 2\frac{\mathrm{d}\log\alpha(t)}{\mathrm{d}t}\sigma^2(t)\,.$$

Here, $\nabla \log p_\theta(x_t) = -\frac{\varepsilon_\theta(x_t, t)}{\sigma(t)}$ is called the score function of $p_\theta(x_t)$. The PF-ODE can be solved using a neural ODE integrator [7] consisting in iteratively applying the learned function $\varepsilon_\theta$ according to given update rules such as the Euler [67] or the Heun solver [23].
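The chain from the denoiser $x_\theta$ to the noise estimate $\varepsilon_\theta$, to the score, and finally to an Euler integration of the PF-ODE in Eq. (3) can be sketched on a one-dimensional toy problem. Everything specific here is an assumption made for illustration: the data follow $\mathcal{N}(0, s^2)$ with $s = 0.5$, the schedule is $\alpha(t) = 1$, $\sigma(t) = t$ (so that $f = 0$ and $g^2(t) = 2t$), and the "network" is replaced by the closed-form ideal denoiser of this toy problem. All function names are hypothetical.

```python
import math

S_DATA = 0.5  # assumed std of the toy data distribution p(x0) = N(0, S_DATA^2)


def alpha(t):
    return 1.0  # assumed schedule: no signal scaling


def sigma(t):
    return t  # assumed schedule: noise std grows linearly with t


def x_theta(x_t, t):
    """Closed-form E[x0 | x_t] for the toy Gaussian data (stands in for a trained denoiser)."""
    return S_DATA ** 2 / (S_DATA ** 2 + sigma(t) ** 2) * x_t


def eps_from_x(x_t, x_pred, t):
    """Reparametrization eps = (x_t - alpha(t) * x_pred) / sigma(t)."""
    return (x_t - alpha(t) * x_pred) / sigma(t)


def score(x_t, t):
    """Score estimate -eps_theta(x_t, t) / sigma(t)."""
    return -eps_from_x(x_t, x_theta(x_t, t), t) / sigma(t)


def drift(x_t, t):
    """f(x_t, t) = dlog(alpha)/dt * x_t, which is zero because alpha is constant."""
    return 0.0


def g_squared(t):
    """g^2(t) = d(sigma^2)/dt - 2 * dlog(alpha)/dt * sigma^2 = 2t for this schedule."""
    return 2.0 * t


def euler_pf_ode(x_T, T=1.0, num_steps=1000):
    """Integrate Eq. (3) backwards from t = T towards t = 0 with explicit Euler steps."""
    dt = T / num_steps
    x = x_T
    for k in range(num_steps):
        t = T - k * dt  # current time, always strictly positive
        dx_dt = drift(x, t) - 0.5 * g_squared(t) * score(x, t)
        x = x - dt * dx_dt  # step from t to t - dt
    return x


x_0_hat = euler_pf_ode(1.0)
exact = 1.0 * S_DATA / math.sqrt(S_DATA ** 2 + 1.0)  # analytic solution for this toy ODE
print(x_0_hat, exact)
```

For this toy setup the ODE reduces to $\mathrm{d}x_t/\mathrm{d}t = t\,x_t/(s^2 + t^2)$, whose exact solution at $t = 0$ is $x_T \cdot s/\sqrt{s^2 + T^2}$, so the Euler result can be compared against it. Each Euler step costs one function evaluation, which is why many steps translate directly into many NFEs.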

A conditional diffusion model can be trained to generate samples from a conditional distribution $p(x_0|c)$ by learning conditional denoising functions $\varepsilon_\theta(x_t, t, c)$ or $x_\theta(x_t, t, c)$ [54, 55, 57, 59, 18, 11, 50, 5, 6]. In that particular setting, Classifier-Free Guidance (CFG) [16] has proven to be a very efficient way to make the model better respect the conditioning and so improve the sampling quality. CFG is a technique that consists in dropping the conditioning $c$ with a certain probability during training and replacing the conditional noise estimate $\varepsilon_\theta(x_t, t, c)$ with a linear combination at inference time as follows:

$$\varepsilon_\theta(x_t, t, c) = \omega \cdot \varepsilon_\theta(x_t, t, c) + (1 - \omega) \cdot \varepsilon_\theta(x_t, t, \varnothing)\,, \tag{4}$$

where $\omega > 0$ is called the guidance scale.
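The linear combination in Eq. (4) can be written as a small helper. This is only a sketch: the function name `classifier_free_guidance` is hypothetical, and the two noise estimates are passed in as plain lists standing in for the outputs of the conditional and unconditional network calls.

```python
def classifier_free_guidance(eps_cond, eps_uncond, omega):
    """Combine the two noise estimates as in Eq. (4).

    eps_cond:   estimate computed with the conditioning c
    eps_uncond: estimate computed with the conditioning dropped
    omega:      guidance scale (> 0)
    """
    return [omega * c + (1.0 - omega) * u for c, u in zip(eps_cond, eps_uncond)]


eps_cond = [1.0, 2.0]
eps_uncond = [0.5, 0.0]
print(classifier_free_guidance(eps_cond, eps_uncond, omega=2.0))
```

Setting $\omega = 1$ recovers the purely conditional estimate, while $\omega > 1$ extrapolates away from the unconditional one. Since both estimates have to be computed, one guided sampling step costs two network evaluations.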

3.2 Consistency Models

Since our approach is inspired by the idea exposed in consistency models [68, 42], we recall some elements of those models. Consistency Models (CM) are a new class of generative models designed primarily to learn a consistency function $f_\theta$ that maps any sample $x_t$ lying on a trajectory of the PF-ODE given in Eq. (3) directly to the original sample $x_0$ while ensuring the self-consistency property for any $t \in [\varepsilon, T]$, $\varepsilon > 0$ [68, 42, 65]:

$$f_\theta(x_t, t) = f_\theta(x_{t'}, t'), \quad \forall (t, t') \in [\varepsilon, T]^2\,. \tag{5}$$

In order to ensure the self-consistency property, the authors of [68] proposed to parametrize $f_\theta$ as follows:

$$f_\theta(x_t, t) = c_{\mathrm{skip}}(t) \cdot x_t + c_{\mathrm{out}}(t) \cdot F_\theta(x_t, t)\,,$$

where $F_\theta$ is parametrized using a neural network and $c_{\mathrm{skip}}$ and $c_{\mathrm{out}}$ are differentiable functions [68, 42]. A consistency model can be trained either from scratch (Consistency Training) or can be used to distill an existing DM (Consistency Distillation) [68, 42]. In both cases, the objective of the model is to learn $f_\theta$ such that it matches the output of a target function $f_{\theta^-}$, the weights of which are updated using Exponential Moving Average (EMA), for any given points $(x_t, x_{t'})$ lying on a trajectory of the PF-ODE:

$$\mathcal{L} = \mathbb{E}_{x_0 \sim p(x_0),\, t \sim \pi(t),\, \varepsilon \sim \mathcal{N}\left(0, \mathbf{I}\right)}\left[\left\| f_\theta(x_t, t) - f_{\theta^-}(x_{t'}, t') \right\|^2\right]\,.$$

In other words, given a noisy sample $x_t$ obtained with Eq. (1), the idea is to enforce that $f_\theta(x_t, t) = f_{\theta^-}(x_{t'}, t')$ where $x_{t'}$ is obtained using either Eq. (1) with the same noise $\varepsilon$ and input $x_0$ for Consistency Training [68, 65] or using a trained diffusion model $\varepsilon_\phi^{\mathrm{teacher}}$ and an ODE solver $\Psi$ for Consistency Distillation [68, 65]. Once the model is trained, one may theoretically generate a sample $\tilde{x}_0$ in a single step by first drawing a noisy sample $x_T \sim \mathcal{N}\left(0, \mathbf{I}\right)$ and then applying the learned function $f_\theta$ to it.
In practice, several iterations are required to generate a satisfying sample and so the estimated sample $\tilde{x}_0$ is iteratively re-noised and denoised several times using the learned function $f_\theta$.
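The skip parametrization and the re-noise/denoise sampling loop described above can be sketched as follows. The concrete forms of $c_{\mathrm{skip}}$ and $c_{\mathrm{out}}$ below follow a choice commonly associated with consistency models [68] and are an assumption here, since the text only requires them to be differentiable; what matters is that $c_{\mathrm{skip}}(\varepsilon) = 1$ and $c_{\mathrm{out}}(\varepsilon) = 0$, so that $f_\theta$ is the identity at the smallest timestep. The constants, the schedule $\alpha(t) = 1$, $\sigma(t) = t$ used for re-noising, and the placeholder `toy_network` are likewise hypothetical.

```python
import math
import random

SIGMA_DATA = 0.5  # assumed standard deviation of the data
T_MIN = 0.002     # assumed smallest timestep (the epsilon of the text)
T_MAX = 1.0       # largest timestep T


def c_skip(t):
    """Weight of the skip connection; equals 1 at t = T_MIN."""
    return SIGMA_DATA ** 2 / ((t - T_MIN) ** 2 + SIGMA_DATA ** 2)


def c_out(t):
    """Weight of the network output; equals 0 at t = T_MIN."""
    return SIGMA_DATA * (t - T_MIN) / math.sqrt(SIGMA_DATA ** 2 + t ** 2)


def consistency_fn(network, x_t, t):
    """f(x_t, t) = c_skip(t) * x_t + c_out(t) * F(x_t, t)."""
    return c_skip(t) * x_t + c_out(t) * network(x_t, t)


def multistep_sample(network, timesteps, rng):
    """Denoise pure noise once, then alternately re-noise and denoise.

    timesteps: decreasing intermediate times in (T_MIN, T_MAX).
    """
    x0_hat = consistency_fn(network, rng.gauss(0.0, 1.0), T_MAX)
    for t in timesteps:
        # Re-noise with Eq. (1) under the assumed schedule alpha(t) = 1, sigma(t) = t.
        x_t = x0_hat + t * rng.gauss(0.0, 1.0)
        x0_hat = consistency_fn(network, x_t, t)
    return x0_hat


def toy_network(x_t, t):
    """Hypothetical stand-in for a trained F_theta; always predicts zero."""
    return 0.0


print(multistep_sample(toy_network, [0.5, 0.2], random.Random(0)))
```

Passing an empty list of timesteps corresponds to single-step generation; each additional timestep adds one function evaluation.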

4 Proposed Method

In this section, we present the proposed method, which builds upon several ideas proposed in the literature. In the following, we place ourselves in the context of Latent Diffusion Models [57] for image generation and refer to the teacher model as $\varepsilon_\phi^{\mathrm{teacher}}$, the student model as $\varepsilon_\theta^{\mathrm{student}}$, the training images as $x_0$ and their unknown distribution as $p(x_0)$. We refer to $z_0 = \mathcal{E}(x_0)$ as the associated latent variables obtained with an encoder $\mathcal{E}$. We denote by $\pi$ the probability density function of the timesteps, and set $T = 1$. Note that the presented approach also applies straightforwardly to pixel-space diffusion models.

4.1 Distilling a Pretrained Diffusion Model

The proposed method is mainly driven by the desire to end up with a fast, robust, and reliable approach that can be easily transposed to different use cases. Given a set of data $x_0 \in \mathcal{X}$ such that $x_0 \sim p(x_0)$ and $z_0 = \mathcal{E}(x_0)$ the associated latent variables, the main idea of the proposed approach is quite similar to diffusion models. Given a noise schedule defined by $\alpha(t)$ and $\sigma(t)$, we propose to create a noisy latent sample $z_t$ with $t \sim \pi(t)$ as specified in Eq. (1) and train a function $f_\theta^{\mathrm{student}}$ to predict a denoised version $\tilde{z}_0$ of the original sample $z_0$.
The main difference with a diffusion model is that instead of using $z_0$ as a target, we propose to leverage the knowledge of the teacher model and use a sample belonging to the data distribution learned by the teacher model $p_\phi^{\mathrm{teacher}}(z_0)$. In other words, we use the teacher model and an ODE solver to generate a denoised latent sample $\tilde{z}_0^{\mathrm{teacher}}(z_t)$ that belongs to the learned data distribution and use it as a target for the student model. The main distillation loss writes as follows:

distil=𝔼z0,t,ε[fθstudent(zt,t)z~0teacher(zt)2],subscriptdistilsubscript𝔼subscript𝑧0𝑡𝜀delimited-[]superscriptnormsuperscriptsubscript𝑓𝜃studentsubscript𝑧𝑡𝑡superscriptsubscript~𝑧0teachersubscript𝑧𝑡2\mathcal{L}_{\mathrm{distil}}=\mathbb{E}_{z_{0},t,\varepsilon}\left[\left\|f_{% \theta}^{\mathrm{student}}(z_{t},t)-\tilde{z}_{0}^{\mathrm{teacher}}(z_{t})% \right\|^{2}\right]\,,caligraphic_L start_POSTSUBSCRIPT roman_distil end_POSTSUBSCRIPT = blackboard_E start_POSTSUBSCRIPT italic_z start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT , italic_t , italic_ε end_POSTSUBSCRIPT [ ∥ italic_f start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT roman_student end_POSTSUPERSCRIPT ( italic_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_t ) - over~ start_ARG italic_z end_ARG start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT roman_teacher end_POSTSUPERSCRIPT ( italic_z start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) ∥ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ] , (6)

where $\pi(t)$ denotes the distribution over the timesteps, and $\tilde{z}_0^{\mathrm{teacher}}(z_t)$ is obtained by running several steps of an ODE solver $\Psi$ on a teacher model $\varepsilon_{\phi}^{\mathrm{teacher}}$ starting from $z_t = \alpha(t) \cdot z_0 + \sigma(t) \cdot \varepsilon$. A similar idea was employed in [62], but there the authors generate fully synthetic samples, meaning that the samples $z_t$ are pure noise, $z_t \sim \mathcal{N}(0, \mathbf{I})$. In contrast, we hypothesize that allowing $z_t$ to retain some information from the ground-truth encoded sample $z_0$ could enhance the distillation process. As in [42], when distilling a conditional DM, we also perform Classifier-Free Guidance (CFG) [16] with the teacher to better enforce adherence to the conditioning. This technique significantly improves the quality of the samples generated by the student, as shown in Sec. 5.
Additionally, it eliminates the need for CFG during inference with the student, further decreasing the method's computational cost by halving the NFEs for each step. The value of the guidance scale $\omega$ used during training is part of the ablations presented in Sec. 5.2, but in practice $\omega$ is uniformly sampled in $[\omega_{\min}, \omega_{\max}]$ where $0 \leq \omega_{\min} \leq \omega_{\max}$. As mentioned in Sec. 3.2, our approach bears resemblance to existing consistency models [68, 42]. However, rather than relying on a previous instance of the student model to estimate the origin of the PF-ODE, we directly employ the teacher model coupled with an ODE solver to generate the target. We observed that these ingredients enhance the stability of the training procedure.
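To make the construction of the distillation target concrete, here is a minimal sketch in pure Python on scalar latents. It is not the official implementation: the noise schedule (`alpha`, `sigma`), the DDIM-style update standing in for the ODE solver $\Psi$, and the toy teacher are all assumptions made for illustration. The guidance combination follows the convention used in Algorithm 1, $\tilde{\varepsilon} = \omega \cdot \varepsilon(\cdot, c) + (1 - \omega) \cdot \varepsilon(\cdot, \varnothing)$.

```python
import math

# Hypothetical noise schedule (an assumption, not taken from the paper).
def alpha(t):
    return math.cos(0.5 * math.pi * 0.99 * t)

def sigma(t):
    return math.sin(0.5 * math.pi * 0.99 * t)

def cfg(eps_cond, eps_uncond, omega):
    """Classifier-free guidance: omega * eps(c) + (1 - omega) * eps(no condition)."""
    return omega * eps_cond + (1.0 - omega) * eps_uncond

def ddim_step(eps, t, s, z_t):
    """Deterministic DDIM-style update from time t to an earlier time s (stand-in for Psi)."""
    z0_hat = (z_t - sigma(t) * eps) / alpha(t)
    return alpha(s) * z0_hat + sigma(s) * eps

def teacher_target(teacher_eps, z_t, i, K, cond, omega):
    """Run the guided teacher from t_i = i/K down to t_0 = 0 on the K-step grid."""
    z = z_t
    for j in range(i - 1, -1, -1):
        t, s = (j + 1) / K, j / K
        eps = cfg(teacher_eps(z, t, cond), teacher_eps(z, t, None), omega)
        z = ddim_step(eps, t, s, z)
    return z

def distil_loss(student_pred, target):
    """Squared error between the one-step student prediction and the teacher target."""
    return (student_pred - target) ** 2

# Toy teacher: the exact noise predictor when all data equals `cond` (0.0 if unconditional).
def toy_teacher(z, t, cond):
    mean = 0.0 if cond is None else cond
    return (z - alpha(t) * mean) / sigma(t)

z0, eps, K, i = 2.0, 0.3, 32, 16
z_t = alpha(i / K) * z0 + sigma(i / K) * eps  # noisy latent that retains information about z0
target = teacher_target(toy_teacher, z_t, i, K, cond=2.0, omega=1.0)
loss = distil_loss(1.5, target)  # 1.5 is a hypothetical one-step student prediction
```

In this sketch each solver step costs two teacher evaluations (conditional and unconditional), whereas the student is evaluated once without guidance, which is where the halving of NFEs at inference comes from.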

Figure 3: Illustration of the evolution of the proposed timesteps distribution $\pi$ throughout training, with panels (a) Warm-up, (b) Phase 1, (c) Phase 2 and (d) Phase 3. $t = 0$ corresponds to no noise injection while $t = 1$ corresponds to the maximum noise injection (i.e. the noisy latent sample is equivalent to a sample drawn from a standard Gaussian distribution). For each phase except the warm-up, 4 timesteps are over-sampled out of the $K = 32$ selected ones. As the training progresses, the probability mass is shifted towards full noise to favor single-step generation.

4.2 Timesteps Sampling

The cornerstone of our approach is the choice of the timestep probability density function, denoted $\pi(t)$. According to the continuous modeling presented in [67], DMs are trained to remove noise from a latent sample $z_t$ for any given continuous time $t$. However, since we aim at achieving few-step data generation (typically 1-4 steps) at inference time, the learned function $\varepsilon_{\theta}$ will only be evaluated at a few discrete timesteps $\{t_i\}_{i=1}^{K}$.

In order to tackle this issue and force the distillation process to focus on the most relevant timesteps, we propose to select $K$ (typically 16, 32, or 64) uniformly spaced timesteps in the range $[0, 1]$ and assign a probability to each of them according to a probability mass function $\pi(t)$. We choose $\pi(t)$ as a mixture of Gaussian distributions controlled by a series of weights $\{\beta_i\}_{i=1}^{K}$:

$$\pi(t) = \frac{1}{\sqrt{2\pi\sigma^2}} \sum_{i=1}^{K} \beta_i \exp\left(-\frac{(t - \mu_i)^2}{2\sigma^2}\right), \tag{7}$$

where the mean of each Gaussian is set to $\{\mu_i = i/K\}_{i=1}^{K}$ and the variance is fixed to $\sigma^2 = 0.5/K^2$, i.e. $\sigma = \sqrt{0.5/K^2}$. With this approach, only a small number $K$ of discrete timesteps is sampled when distilling the teacher, instead of the continuous range $[0, 1]$ (in practice, when training a DM, the range $[0, 1]$ is itself discretized, typically into 1000 timesteps, for computational purposes). Moreover, the distribution $\pi$ is defined such that, out of the $K$ selected timesteps, the 4 timesteps used at inference for 1, 2 and 4 steps generation are over-sampled (typically we set $\beta_i > 0$ if $i \in \{\frac{K}{4}, \frac{K}{2}, \frac{3K}{4}, K\}$ and $\beta_i = 0$ otherwise). Unlike other methods [61, 62], we do not focus only on those 4 timesteps, since we noticed that doing so can reduce the diversity of the generated samples. This is emphasized in particular in the ablation study presented in Sec. 5.2. In practice, we notice that a warm-up phase is beneficial to the training process.
Therefore, we start by imposing a higher probability on the timesteps corresponding to the least amount of added noise by setting $\beta_{K/4} = \beta_{K/2} = 0.5$ and $\beta_i = 0$ otherwise. We then progressively shift the probability mass towards full noise to favor single-step generation while still over-sampling the targeted 4 timesteps by setting a strictly positive value for $\beta_i$ where $i \equiv 0 \pmod{K/4}$, and $\beta_i = 0$ otherwise. An example for $\pi$ with $K = 32$ is illustrated in Figure 3. As pictured in the figure, the $[0, 1]$ interval is split into 32 timesteps. During the warm-up phase, the probability mass function allocates a higher probability to the timesteps $\{0.25, 0.5\}$ to ease the distillation process. As the training progresses, the probability mass function is then shifted towards full noise to favor single-step generation while always allocating a higher probability to the 4 timesteps $\{0.25, 0.5, 0.75, 1\}$. The impact of the timesteps distribution is further discussed in Sec. 5.2.
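The timestep distribution of Eq. (7) can be sketched in a few lines of pure Python. This is an illustration under stated assumptions rather than the authors' code: the grid is taken as $t_k = k/K$ for $k = 1, \dots, K$ (matching the means $\mu_i$), the mixture density is evaluated on that grid and renormalised to obtain a probability mass function, and the weights used for the later phase are hypothetical values.

```python
import math
import random

def timestep_pmf(K, betas):
    """Evaluate Eq. (7) on the grid t_k = k/K, k = 1..K, and normalise to a pmf.

    `betas` maps a 1-based component index i to its weight beta_i
    (indices that are absent are treated as beta_i = 0).
    """
    std = math.sqrt(0.5 / K**2)  # sigma = sqrt(0.5 / K^2)
    grid = [k / K for k in range(1, K + 1)]
    dens = []
    for t in grid:
        total = 0.0
        for i, beta in betas.items():
            mu = i / K
            total += beta * math.exp(-((t - mu) ** 2) / (2 * std**2))
        dens.append(total / math.sqrt(2 * math.pi * std**2))
    norm = sum(dens)
    return grid, [d / norm for d in dens]

K = 32
# Warm-up: all the weight on t = 0.25 and t = 0.5.
warmup = {K // 4: 0.5, K // 2: 0.5}
# A later phase: hypothetical weights that put more mass near full noise.
later = {K // 4: 0.1, K // 2: 0.2, 3 * K // 4: 0.3, K: 0.4}

grid, pmf_warmup = timestep_pmf(K, warmup)
_, pmf_later = timestep_pmf(K, later)

rng = random.Random(0)
t_batch = rng.choices(grid, weights=pmf_later, k=8)  # timesteps for one training batch
```

With $\sigma^2 = 0.5/K^2$, a grid point adjacent to an over-sampled timestep receives $e^{-1} \approx 0.37$ times the probability of that timestep, so the neighbouring timesteps are still visited during training even though their own $\beta_i$ is zero.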

Figure 4: Flash Diffusion training method: the student is trained by using a distillation loss between multiple-step teacher and single-step student denoised samples. The student predictions are then re-noised and denoised with the teacher and student before evaluating the adversarial and distribution matching losses.

4.3 Adversarial Objective

To further enhance the quality of the samples, and since it has proved very effective in several works on few-step image generation [76, 61, 33, 28, 62], we also incorporate an adversarial objective. The core idea is to train the student model to generate samples that are indistinguishable from the true data distribution $p(x_0)$. To do so, we propose to train a discriminator $D_{\nu}$ to distinguish the generated samples $\tilde{x}_0$ from the real samples $x_0 \sim p(x_0)$. As proposed in [33, 62], we also apply the discriminator directly within the latent space. This approach circumvents the need to decode the samples using the VAE, a process outlined in [61] that proves expensive and hampers the method's scalability to high-resolution images.

Drawing inspiration from [33, 62], we propose an approach where both the one-step student prediction $\tilde{z}_0$ and the input latent sample $z_0$ are re-noised following the teacher noise schedule. This process uses a timestep $t'$, which is uniformly chosen from the set $\{0.01, 0.25, 0.5, 0.75\}$. The samples are first passed through the frozen teacher model, followed by the discriminator, to yield a real or fake prediction. When employing a UNet architecture [58] for the teacher model, we use only the encoder portion of the UNet, which yields an even more compressed latent representation and further reduces the parameter count of the discriminator. We carefully select specific timesteps to enable the discriminator to effectively differentiate between samples based on both high- and low-frequency details, as discussed in [33]. Note that only the discriminator is trained here, while the teacher model remains frozen. The adversarial loss $\mathcal{L}_{\mathrm{adv}}$ and discriminator loss $\mathcal{L}_{\mathrm{discriminator}}$ read:

$$\begin{aligned}
\mathcal{L}_{\mathrm{adv}} &= \frac{1}{2}\, \mathbb{E}_{z_0, t', \varepsilon}\left[\left\| D_{\nu}\big(f_{\theta}^{\mathrm{student}}(z_{t'}, t')\big) - 1 \right\|^2\right], \\
\mathcal{L}_{\mathrm{discriminator}} &= \frac{1}{2}\, \mathbb{E}_{z_0, t', \varepsilon}\left[\left\| D_{\nu}(z_0) - 1 \right\|^2 + \left\| D_{\nu}\big(f_{\theta}^{\mathrm{student}}(z_{t'}, t')\big) - 0 \right\|^2\right],
\end{aligned} \tag{8}$$

where $\nu$ denotes the discriminator parameters. We opt for these particular losses due to their reliability and stability during training, as observed in our experiments. Additionally, we assess the impact of the chosen adversarial loss $\mathcal{L}_{\mathrm{adv}}$ through an ablation study detailed in Section 5.2. In practical terms, the discriminator is a simple Convolutional Neural Network (CNN) with a stride of 2, a kernel size of 4, SiLU activations [13, 53] and group normalization [75].
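The two least-squares objectives of Eq. (8) can be illustrated with a small batch estimate in pure Python. This is only a sketch: it assumes the discriminator returns one scalar per sample, and the values in `d_real` and `d_fake` are hypothetical outputs rather than results from the paper.

```python
def adversarial_loss(d_fake):
    """Student-side loss of Eq. (8): push D(generated) towards 1."""
    return 0.5 * sum((d - 1.0) ** 2 for d in d_fake) / len(d_fake)

def discriminator_loss(d_real, d_fake):
    """Discriminator loss: push D(real) towards 1 and D(generated) towards 0."""
    assert len(d_real) == len(d_fake)
    total = sum((r - 1.0) ** 2 + (f - 0.0) ** 2 for r, f in zip(d_real, d_fake))
    return 0.5 * total / len(d_real)

# Hypothetical discriminator outputs for 3 real and 3 generated latents.
d_real = [0.9, 1.0, 0.8]
d_fake = [0.1, 0.0, 0.3]

loss_student = adversarial_loss(d_fake)            # large: the student is not fooling D yet
loss_disc = discriminator_loss(d_real, d_fake)     # small: D separates the two sets well
```

The two losses pull in opposite directions on the same discriminator outputs: the student is rewarded when $D_{\nu}$ scores its samples close to 1, while the discriminator is rewarded for scoring them close to 0.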

4.4 Distribution Matching

Inspired by the work of [77], we also introduce a Distribution Matching Distillation (DMD) loss to ensure that the generated samples closely mirror the data distribution learned by the teacher [12, 9, 31]. Specifically, this involves minimizing the Kullback–Leibler (KL) divergence between the distribution of samples from the student model $p^{\mathrm{student}}_{\theta}$ and $p^{\mathrm{teacher}}_{\phi}$, the data distribution learned by the teacher model [73]:

$$\begin{aligned}
\mathcal{L}_{\mathrm{DMD}} &= D_{KL}\big(p^{\mathrm{student}}_{\theta} \,\|\, p^{\mathrm{teacher}}_{\phi}\big) \\
&= \mathbb{E}_{z_0, t, \varepsilon}\bigg[-\bigg(\log p^{\mathrm{teacher}}_{\phi}\big(f_{\theta}^{\mathrm{student}}(z_t, t)\big) - \log p^{\mathrm{student}}_{\theta}\big(f_{\theta}^{\mathrm{student}}(z_t, t)\big)\bigg)\bigg].
\end{aligned} \tag{9}$$

Taking the gradient of the KL divergence with respect to the student model parameters $\theta$ leads to the following update rule:

$$\nabla_{\theta}\mathcal{L}_{\mathrm{DMD}} = \mathbb{E}_{z_0, t, \varepsilon}\bigg[-\bigg(s^{\mathrm{teacher}}\big(f_{\theta}^{\mathrm{student}}(z_t, t)\big) - s^{\mathrm{student}}\big(f_{\theta}^{\mathrm{student}}(z_t, t)\big)\bigg)\nabla f_{\theta}^{\mathrm{student}}(z_t, t)\bigg],$$

where $s^{\mathrm{teacher}}$ and $s^{\mathrm{student}}$ are the score functions of the teacher and student distributions respectively. Inspired by [77], the one-step student prediction is re-noised using a uniformly sampled timestep $t'' \sim \mathcal{U}([0, 1])$ and the teacher noise schedule. The new noisy sample is passed through the frozen teacher model to get the score function for the teacher distribution $s^{\mathrm{teacher}}(f_{\theta}^{\mathrm{student}}(z_{t''}, t'')) = -(\varepsilon_{\phi}^{\mathrm{teacher}}(x_{t''}, t'') / \sigma(t''))$.
In our approach, we utilize the student model for the score function of the student distribution, instead of a dedicated diffusion model as in [77]. This choice significantly reduces the number of trainable parameters and computational costs. Moreover, the use of the distribution matching loss, examined in the ablation study (Sec. 5.2), has proven to enhance the quality of generated samples and their prompt adherence.
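The relation between a noise prediction and a score, and the per-sample term that multiplies the student Jacobian in the gradient above, can be sketched as follows. This is a scalar toy example with hypothetical values, not the official implementation: `eps_teacher` and `eps_student` stand for noise predictions evaluated on the same re-noised sample, and the Gaussian setting used to produce them is an assumption chosen so that the result can be checked by hand.

```python
def score_from_eps(eps, sigma_t):
    """Score implied by a noise prediction at noise level sigma(t): -eps / sigma(t)."""
    return -eps / sigma_t

def dmd_direction(eps_teacher, eps_student, sigma_t):
    """Per-sample term -(s_teacher - s_student) from the gradient of the DMD loss."""
    s_teacher = score_from_eps(eps_teacher, sigma_t)
    s_student = score_from_eps(eps_student, sigma_t)
    return -(s_teacher - s_student)

# Toy setting: both noised distributions are unit-variance Gaussians,
# the teacher centred at 1.0 and the student at 0.0, evaluated at x = 0.0.
sigma_t, x = 0.5, 0.0
eps_teacher = sigma_t * (x - 1.0)  # optimal noise prediction for N(1, 1)
eps_student = sigma_t * (x - 0.0)  # optimal noise prediction for N(0, 1)

direction = dmd_direction(eps_teacher, eps_student, sigma_t)
x_new = x - 0.1 * direction  # one gradient-descent step on the sample itself
```

In this toy case the descent step moves the sample from 0.0 towards the teacher mean at 1.0, which is the intended effect of the loss: samples are pushed towards regions where the teacher assigns more probability than the student.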

Algorithm 1 Flash Diffusion
Input: A trained teacher DM $\varepsilon_{\phi}^{\mathrm{teacher}}$, a trainable student DM $\varepsilon_{\theta}^{\mathrm{student}}$, an ODE solver $\Psi$, the number of sampling teacher steps $K$, a timesteps distribution $\pi(t)$, the guidance scale range $[\omega_{\min}, \omega_{\max}]$, the loss weights $\lambda_{\mathrm{adv}}$ and $\lambda_{\mathrm{dmd}}$
Initialisation: $\theta \leftarrow \phi$ ▷ Initialise the student with the teacher's weights
while not converged do
     $(z, c) \sim \mathcal{Z} \times \mathcal{C}$, $\omega \sim \mathcal{U}([\omega_{\min}, \omega_{\max}])$ ▷ Draw a sample and guidance scale
     $t_i \sim \pi(t)$, $\varepsilon \sim \mathcal{N}(0, \mathbf{I})$ ▷ Sample a timestep and noise
     $\tilde{z}_{t_i} \leftarrow \alpha(t_i) \cdot z_0 + \sigma(t_i) \cdot \varepsilon$
     for $j = i - 1 \rightarrow 0$ do
         $\tilde{\varepsilon} = \omega \cdot \varepsilon_{\phi}^{\mathrm{teacher}}(\tilde{z}_{t_{j+1}}, t_{j+1}, c) + (1 - \omega) \cdot \varepsilon_{\phi}^{\mathrm{teacher}}(\tilde{z}_{t_{j+1}}, t_{j+1}, \varnothing)$ ▷ CFG
         $\tilde{z}_{t_j} \leftarrow \Psi(\tilde{\varepsilon}, t_{j+1}, \tilde{z}_{t_{j+1}})$ ▷ ODE solver update
     end for
     $\tilde{z}_0^{\mathrm{teacher}} \leftarrow \tilde{z}_{t_0}$
     $\tilde{z}_0^{\mathrm{student}} \leftarrow \big(\tilde{z}_{t_i} - \sigma(t_i) \cdot \varepsilon_{\theta}^{\mathrm{student}}(\tilde{z}_{t_i})\big) / \alpha(t_i)$
     $\mathcal{L} \leftarrow \mathcal{L}_{\mathrm{distil}}(\tilde{z}_0^{\mathrm{student}}, \tilde{z}_0^{\mathrm{teacher}}) + \lambda_{\mathrm{adv}} \cdot \mathcal{L}_{\mathrm{adv}}(\tilde{z}_0^{\mathrm{student}}, z_0) + \lambda_{\mathrm{dmd}} \cdot \mathcal{L}_{\mathrm{DMD}}(\tilde{z}_0^{\mathrm{student}})$
end while

4.5 Model Training

While striving for robustness and versatility, we also aimed to design a method with a minimal number of trainable parameters, since training already requires loading two computationally intensive networks (teacher and student). To do so, we propose to rely on the parameter-efficient method LoRA [20] and apply it to our student model. This way, we drastically reduce the number of trainable parameters and speed up the training process.
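As a rough illustration of why LoRA reduces the number of trainable parameters, the sketch below compares a full weight matrix with a rank-$r$ update $BA$, where $B$ has shape $d_{\mathrm{out}} \times r$ and $A$ has shape $r \times d_{\mathrm{in}}$. The layer size and rank are hypothetical values chosen for the example, not settings reported in the paper.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters of a rank-`rank` LoRA update B @ A for a d_out x d_in weight."""
    return rank * (d_in + d_out)

def full_param_count(d_in, d_out):
    """Parameters of the dense weight matrix that LoRA leaves frozen."""
    return d_in * d_out

# Hypothetical layer size and rank, chosen only for illustration.
d_in, d_out, rank = 1280, 1280, 64
trainable = lora_param_count(d_in, d_out, rank)
ratio = trainable / full_param_count(d_in, d_out)
```

For this hypothetical layer, the LoRA update trains 163,840 parameters instead of 1,638,400, i.e. one tenth of the dense weight.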

In a nutshell, our student model is trained to minimize a weighted combination of the distillation Eq. (6), the adversarial Eq. (8), and the distribution matching Eq. (9) losses:

$$\mathcal{L} = \mathcal{L}_{\mathrm{distil}} + \lambda_{\mathrm{adv}} \mathcal{L}_{\mathrm{adv}} + \lambda_{\mathrm{DMD}} \mathcal{L}_{\mathrm{DMD}}. \tag{10}$$

The training process is detailed in Alg. 1 and illustrated in Figure 4. In more detail, we first pick a random sample $x_0 \sim p(x_0)$ belonging to the unknown data distribution. This sample is then encoded with an encoder $\mathcal{E}$ to get the corresponding latent sample $z_0$. A timestep $t$ is drawn according to the timestep probability mass function $\pi$ detailed in Sec. 4.2 to create a noisy sample $z_t$ using Eq. (1). The teacher model $\varepsilon_{\phi}^{\mathrm{teacher}}$ and the ODE solver $\Psi$ are then used to solve the PF-ODE and thus generate a synthetic sample $\tilde{z}_0^{\mathrm{teacher}}$ belonging to the distribution learned by the teacher model. At the same time, the student model $f_{\theta}^{\mathrm{student}}$ is used to generate a denoised sample $\tilde{z}_0^{\mathrm{student}} = f_{\theta}^{\mathrm{student}}(z_t, t)$ in a single step. The distillation loss is then computed according to Eq. (6). Then, we re-noise the one-step student prediction $\tilde{z}_0^{\mathrm{student}}$ as well as the input latent sample $z_0$ and compute the adversarial loss as explained in Sec. 4.3. Finally, for distribution matching, we again take the one-step student prediction $\tilde{z}_0^{\mathrm{student}}$ and re-noise it using a uniformly sampled timestep $t \sim \mathcal{U}([0,1])$. The new noisy sample is passed through the teacher model to get the teacher score function $s^{\mathrm{teacher}}$, while we use the student model (and not a dedicated diffusion model as in [77]) to get the student score function $s^{\mathrm{student}}$. The distribution matching loss is then computed as explained in Sec. 4.
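The data flow of one training step can be sketched in plain Python. This is only an illustrative sketch, not the official implementation: latents are plain lists of floats, the callables `add_noise`, `teacher_solve`, `student`, `adv_loss` and `dmd_loss` are hypothetical stand-ins for Eq. (1), the teacher with its ODE solver, the student, Eq. (8) and Eq. (9), and MSE is used for the distillation term as selected in the ablation study.

```python
import random
from typing import Callable, Dict, List

Latent = List[float]


def mse(a: Latent, b: Latent) -> float:
    """Mean squared error between two latents of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def flash_step_losses(
    z0: Latent,
    t: float,
    add_noise: Callable[[Latent, float], Latent],
    teacher_solve: Callable[[Latent, float], Latent],
    student: Callable[[Latent, float], Latent],
    adv_loss: Callable[[Latent, Latent], float],
    dmd_loss: Callable[[Latent], float],
    lam_adv: float,
    lam_dmd: float,
) -> Dict[str, float]:
    """Combine the three loss terms of Eq. (10) for one latent sample."""
    z_t = add_noise(z0, t)               # noisy latent (stand-in for Eq. (1))
    z0_teacher = teacher_solve(z_t, t)   # multi-step teacher estimate
    z0_student = student(z_t, t)         # one-step student estimate
    l_distil = mse(z0_student, z0_teacher)
    l_adv = adv_loss(z0_student, z0)     # student prediction vs. real latent
    l_dmd = dmd_loss(z0_student)
    total = l_distil + lam_adv * l_adv + lam_dmd * l_dmd
    return {"distil": l_distil, "adv": l_adv, "dmd": l_dmd, "total": total}


# Toy stand-ins (hypothetical, for illustration only).
rng = random.Random(0)


def toy_add_noise(z: Latent, t: float) -> Latent:
    # Simple linear interpolation with Gaussian noise; NOT the paper's Eq. (1).
    return [(1.0 - t) * v + t * rng.gauss(0.0, 1.0) for v in z]


def toy_denoiser(z_t: Latent, t: float) -> Latent:
    return [0.9 * v for v in z_t]


losses = flash_step_losses(
    z0=[0.5, -1.0, 2.0],
    t=0.5,
    add_noise=toy_add_noise,
    teacher_solve=toy_denoiser,
    student=toy_denoiser,  # student identical to teacher: distil loss is 0
    adv_loss=lambda pred, real: mse(pred, real),  # placeholder, not a GAN loss
    dmd_loss=lambda pred: 0.0,                    # placeholder
    lam_adv=0.3,
    lam_dmd=0.7,
)
```

In the toy call the student and the teacher share the same function, so the distillation term vanishes and the total reduces to the weighted placeholder terms.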

Overall, our proposed method relies on the training of only a small number of parameters. This is achieved by applying LoRA to the student model, using a frozen teacher model for the adversarial approach, and employing the student denoiser directly rather than introducing a new diffusion model to compute the fake scores for the distribution matching loss. This approach not only drastically cuts down on the number of parameters but also accelerates the training process.
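To give a feel for why LoRA shrinks the trainable parameter count, the following sketch counts the parameters of the two low-rank factors added to a single linear layer. The layer size and the rank below are hypothetical values chosen for illustration; only the 26.4M and 900M figures come from Sec. 5.1.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer.

    LoRA learns a rank-`rank` update B @ A with A of shape (rank, d_in)
    and B of shape (d_out, rank); the original weight stays frozen.
    """
    return rank * (d_in + d_out)


def full_params(d_in: int, d_out: int) -> int:
    """Parameters of the dense weight matrix (bias ignored)."""
    return d_in * d_out


# Hypothetical layer: 1024 -> 1024 with rank 64.
added = lora_params(1024, 1024, 64)
dense = full_params(1024, 1024)
layer_fraction = added / dense

# Fraction of trainable parameters reported for the SD1.5 student (Sec. 5.1).
reported_percent = round(100 * 26.4 / 900, 1)
```

For the hypothetical layer, the low-rank factors hold one eighth of the dense weight count, and the reported SD1.5 setting corresponds to roughly 2.9% of the teacher parameters being trained.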

5 Experiments

In this section, we assess the effectiveness of our proposed method across various tasks and datasets. First, as it is common in the literature, we quantitatively compare the method with several approaches in the context of text-to-image generation. Then, we conduct an extensive ablation study to assess the importance and impact of each of the components proposed in the method. Finally, we highlight the versatility of our method across several tasks, conditioning, and denoiser architectures.

5.1 Text-to-Image Quantitative Evaluation

First, we evaluate the proposed method against existing distillation approaches for text-to-image generation. In this section, we apply our distillation approach to the publicly available SD1.5 model [57] and report both FID [14] and CLIP score [51] on the COCO2014 and COCO2017 datasets [34]. The model is trained on the LAION dataset [63] where we select samples with aesthetic scores above 6 and re-caption the samples using CogVLM [71]. For COCO2017, we rely on the evaluation approach proposed in [45] and we pick 5,000 prompts from the validation set to generate synthetic images. For COCO2014, we employ the evaluation protocol proposed in [22] and pick 30,000 prompts from the validation set. We then compute the FID against the real images in the respective validation sets composed of 5,000 images for COCO2017 and 40,504 images for COCO2014. The model is trained for only 20k iterations on 2 H100-80GB GPUs (amounting to a total of 26 H100 equivalent hours of training) with a batch size of 4 and a learning rate of $10^{-5}$ together with the Adam optimizer [26] for both the student and the discriminator. We use the timestep distribution $\pi(t)$ detailed in Sec. 4.2 with $K=32$ and shift phases every 5000 iterations. We also start with both $\lambda_{\mathrm{adv}}=0$ and $\lambda_{\mathrm{DMD}}=0$ and progressively increase them each time we change the timestep distribution until they reach their final values of $0.3$ and $0.7$, respectively. The guidance scale $\omega$ is sampled from $\mathcal{U}([3,13])$. The weights of the student model are all initialized with the teacher's. See Appendix A.1 for more details on the training procedure.
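The phase-wise warm-up of the loss weights can be sketched as follows. The text only states the starting values (0), the final values (0.3 and 0.7), and the phase length (5000 iterations); the linear interpolation across phases used here is an assumption made for illustration, and the function names are hypothetical.

```python
import random
from typing import Tuple


def loss_weights(
    iteration: int,
    phase_length: int = 5000,
    num_phases: int = 4,      # 20k iterations / 5000 iterations per phase
    final_adv: float = 0.3,
    final_dmd: float = 0.7,
) -> Tuple[float, float]:
    """Return (lambda_adv, lambda_dmd) for a given training iteration.

    Assumption: the weights grow linearly with the phase index, from 0 in
    the first phase to their final values in the last phase.
    """
    phase = min(iteration // phase_length, num_phases - 1)
    frac = phase / (num_phases - 1)
    return final_adv * frac, final_dmd * frac


def sample_guidance_scale(rng: random.Random) -> float:
    """Draw the teacher guidance scale from U([3, 13]) as stated in the text."""
    return rng.uniform(3.0, 13.0)


start = loss_weights(0)
end = loss_weights(19_999)
omega = sample_guidance_scale(random.Random(0))
```

Under this assumed schedule, the first phase trains with the distillation loss only, and the adversarial and distribution matching terms reach full weight in the last 5000 iterations.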

We report the results in Tables 1 and 2. As illustrated in the tables, our method achieves an FID of 22.6 and 12.27 on COCO2017 and COCO2014 respectively with only 2 NFEs, corresponding to SOTA results for few-step image generation. On COCO2017, our approach also achieves a CLIP score of 0.306 and 0.311 for 2 and 4 NFEs respectively. Importantly, our method only requires the training of 26.4M parameters (out of the 900M teacher parameters) and merely 26 H100 GPU hours of training time. This is in stark contrast with many competitors that depend on training the entire UNet architecture of the student, which involves hundreds of millions of parameters. Beyond quantitative assessments, we also offer a visual overview of the generated samples in Figure 5 for 1, 2, and 4 Neural Function Evaluations (NFEs). In addition to those results, we also provide a quantitative analysis with different backbones (SDXL and Pixart-$\alpha$) in Sec. 5.3.1.

Table 1: FID-5k and CLIP score on the MS COCO2017 validation set for SD1.5 as teacher. We follow the procedure of [45] and pick 5,000 prompts from the validation set.

| Method | # Train. Param. (M) | NFE | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|---|---|
| Teacher (SD1.5) | N/A | 50 | 20.1 | 0.318 |
| | | 16 | 31.7 | 0.320 |
| Prog. Distil. [45] | 900 | 2 | 37.3 | 0.270 |
| | | 4 | 26.0 | 0.300 |
| | | 8 | 26.9 | 0.300 |
| InstaFlow [36] | 900 | 1 | 23.4 | 0.304 |
| CFG Distil. [30] | 850 | 16 | 24.2 | 0.300 |
| Ours | 26.4 | 2 | 22.6 | 0.306 |
| | | 4 | 22.5 | 0.311 |

Table 2: FID-30k on the MS COCO2014 validation set for SD1.5 as teacher. We follow the evaluation procedure of [22]. Results extracted from [77].

| Method | # Train. Param. (M) | NFE | FID $\downarrow$ |
|---|---|---|---|
| DPM++ [38] | N/A | 8 | 22.44 |
| UniPC [82] | N/A | 8 | 23.30 |
| UFOGen [76] | 1,700 | 1 | 12.78 |
| InstaFlow [36] | 900 | 1 | 13.10 |
| DMD [77] | 1,700 | 1 | 14.93 |
| LCM-LoRA [43] | 67.5 | 1 | 77.90 |
| | | 2 | 24.28 |
| | | 4 | 23.62 |
| Ours | 26.4 | 2 | 12.27 |
| | | 4 | 12.41 |

[Figure: three panels of generated samples: (a) 1 NFE, (b) 2 NFEs, (c) 4 NFEs]
Figure 5: Qualitative evaluation of the sample quality as the number of NFEs increases for the proposed method applied to the SD1.5 model. As expected, the quality of the samples increases with the number of NFEs, in line with the quantitative results.

5.2 Ablation Study

In the following section, we conduct a comprehensive ablation study to assess the influence of the main parameters and choices made in the proposed method. The ablated parameters include the distillation loss choice, the GAN loss choice, the impact of each term in the loss given in Eq. (10), the timestep sampling, and the guidance scale used during training. For all the ablations, we train the model for 20k iterations with the SD1.5 model as teacher and employ the same hyper-parameters as in Sec. 5.1 unless stated otherwise. All the results are reported on the COCO2017 validation set using 2 NFEs.

(a) Loss terms

| Loss | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|
| $\mathcal{L}_{\mathrm{distil.}}$ | 27.12 | 29.85 |
| $\mathcal{L}_{\mathrm{distil.}} + \mathcal{L}_{\mathrm{DMD}}$ | 26.88 | 30.45 |
| $\mathcal{L}_{\mathrm{distil.}} + \mathcal{L}_{\mathrm{adv}}$ | 23.41 | 30.14 |
| $\mathcal{L}_{\mathrm{distil.}} + \mathcal{L}_{\mathrm{DMD}} + \mathcal{L}_{\mathrm{adv}}$ | 22.64 | 30.61 |

(b) GAN loss choice

| $\mathcal{L}_{\mathrm{adv.}}$ | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|
| Hinge | 25.02 | 30.17 |
| WGAN | 24.58 | 30.36 |
| LSGAN | 22.64 | 30.61 |

(c) Distillation loss choice

| $\mathcal{L}_{\mathrm{distil.}}$ | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|
| LPIPS | 24.89 | 30.56 |
| MSE | 22.64 | 30.61 |

(d) [Plot: FID and CLIP score as a function of the guidance scale used to generate with the teacher]

(e) Number of reference timesteps

| $K$ | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|
| 16 | 23.35 | 30.52 |
| 32 | 22.64 | 30.61 |
| 64 | 22.87 | 30.41 |

(f) Timestep sampling

| $\pi(t)$ | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|
| $\pi^{\mathrm{uniform}}(t)$ | 24.25 | 30.11 |
| $\pi^{\mathrm{gaussian}}(t)$ | 35.89 | 28.15 |
| $\pi^{\mathrm{sharp}}(t)$ | 23.35 | 30.58 |
| $\pi^{\mathrm{ours}}(t)$ | 22.64 | 30.61 |

Figure 6: From left to right and top to bottom: influence of the loss terms, the GAN loss choice, the distillation loss choice, the guidance scale used to generate with the teacher, the number of reference timesteps $K$, and the timestep sampling $\pi(t)$.
Influence of the loss terms

First, we assess the influence of each of the loss terms in the method’s loss given in Eq. (10). In this ablation study, we train the model using various combinations of loss terms: solely the distillation loss, the distillation loss combined with the DMD loss, the distillation loss paired with the adversarial loss, or incorporating all loss terms simultaneously. We report the results in the top left table in Figure 6. As highlighted in the table, the distillation loss alone leads to higher FID and lower CLIP score compared to the other configurations showcasing the importance of the adversarial and DMD losses. Moreover, both losses have a noticeable impact on the final performance since the adversarial loss seems to allow reaching a better image quality, as indicated by lower FID, while the DMD loss improves prompt adherence, reflected in higher CLIP scores. Experiments conducted using only adversarial and DMD losses revealed notable inconsistencies and even divergence in outcomes, emphasizing the crucial contribution of the distillation loss to the method’s stability and reliability. This highlights the importance of a balanced approach that incorporates all three losses for optimal results.

In the top middle and top right tables in Figure 6, we also report results obtained with a model trained with different choices for the adversarial loss and the distillation loss. For the distillation loss, we compare the use of LPIPS [81] and MSE. As highlighted in the table, MSE achieves better results in terms of FID and CLIP score than LPIPS. For the GAN loss, we compare a Hinge loss [32], the WGAN approach [1], and LSGAN [44]. As illustrated in the table, LSGAN seems to be the best-suited choice. Moreover, we observed more stable training with this loss than with its counterparts.
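For reference, the three adversarial objectives compared above can be written in a few lines. The sketch below uses the standard textbook forms of the LSGAN, Hinge, and WGAN losses on lists of discriminator outputs; these are assumptions for illustration and may differ in detail from the paper's Eq. (8), which additionally operates on re-noised latents.

```python
from typing import Iterable, List


def _mean(values: Iterable[float]) -> float:
    vals: List[float] = list(values)
    return sum(vals) / len(vals)


# real / fake are discriminator outputs on real and generated samples.
def lsgan_d_loss(real: List[float], fake: List[float]) -> float:
    return 0.5 * _mean((r - 1.0) ** 2 for r in real) + 0.5 * _mean(f ** 2 for f in fake)


def lsgan_g_loss(fake: List[float]) -> float:
    return 0.5 * _mean((f - 1.0) ** 2 for f in fake)


def hinge_d_loss(real: List[float], fake: List[float]) -> float:
    return _mean(max(0.0, 1.0 - r) for r in real) + _mean(max(0.0, 1.0 + f) for f in fake)


def hinge_g_loss(fake: List[float]) -> float:
    return -_mean(fake)


def wgan_d_loss(real: List[float], fake: List[float]) -> float:
    # Critic loss; the Lipschitz constraint (clipping or penalty) is omitted.
    return _mean(fake) - _mean(real)


def wgan_g_loss(fake: List[float]) -> float:
    return -_mean(fake)


# A discriminator that scores real samples 1 and generated samples 0.
real_scores = [1.0, 1.0]
fake_scores = [0.0, 0.0]
```

One commonly cited property of the least-squares form is that it penalizes the squared distance to the target label, so its gradients stay bounded for confidently classified samples, which is consistent with the more stable training reported above.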

Influence of the timestep sampling

As explained in Sec. 4.2, we select $K=32$ uniformly spaced timesteps in the range $[0,1]$ and assign a probability to each of them according to a probability mass function $\pi(t)$. In this section, we highlight the influence of the choice of $\pi(t)$. We compare the proposed timestep distribution described in Sec. 4.2 to a uniform distribution across those 32 timesteps, $\pi^{\mathrm{uniform}}(t)=\mathcal{U}\left(\{i/K\}_{i\in\{1,\dots,K\}}\right)$, a normal distribution $\pi^{\mathrm{gaussian}}(t)$ centered on $t=0.5$, and $\pi^{\mathrm{sharp}}(t)$, a sharp version of our proposed distribution that only allows sampling 4 distinct timesteps. We report the results in the bottom right table in Figure 6. These distributions are represented in Appendix A.1. The first outcome of this experiment is that, as expected, the choice of the timestep sampling distribution has a noticeable impact on the final performance. As shown by the two bottom lines, the proposed distribution significantly improves the performance compared to the uniform and Gaussian distributions. Moreover, allowing more than 4 distinct timesteps to be sampled seems to be beneficial to the final performance since a noticeable decrease in FID is observed. This can be explained by the fact that the student model can distill more useful information from the teacher model by sampling a wider range of timesteps and does not overfit on the 4 selected ones.
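The two baseline distributions can be sketched as probability mass functions over the grid $\{i/K\}_{i \in \{1,\dots,K\}}$. The standard deviation of the Gaussian variant is not given in this section, so the value below is a hypothetical choice; the proposed distribution of Sec. 4.2 is deliberately not reproduced here.

```python
import math
import random
from typing import List


def timestep_grid(K: int) -> List[float]:
    """K uniformly spaced timesteps i/K for i = 1..K."""
    return [i / K for i in range(1, K + 1)]


def uniform_pmf(K: int) -> List[float]:
    return [1.0 / K] * K


def gaussian_pmf(K: int, mean: float = 0.5, std: float = 0.15) -> List[float]:
    """Discretized normal density on the grid, renormalized to sum to 1.

    `std` is a hypothetical value chosen for illustration.
    """
    weights = [math.exp(-0.5 * ((t - mean) / std) ** 2) for t in timestep_grid(K)]
    total = sum(weights)
    return [w / total for w in weights]


def sample_timestep(pmf: List[float], rng: random.Random) -> float:
    """Draw one timestep from the grid according to `pmf`."""
    return rng.choices(timestep_grid(len(pmf)), weights=pmf, k=1)[0]


K = 32
pmf = gaussian_pmf(K)
t = sample_timestep(pmf, random.Random(0))
```

Any other probability mass function over the same grid can be plugged into `sample_timestep` in the same way, which is how the four variants in the table would differ in practice.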

In addition, we also assess the influence of the number of selected timesteps $K$ used to create the timestep distribution $\pi(t)$ described in Figure 3. We compare 3 cases: $K=16$, $K=32$ and $K=64$ and report the results in the bottom middle table in Figure 6. In this experiment, $K=32$ proves to be a good trade-off between performance and computational cost. Indeed, $K=16$ seems to lead to a slight decrease in the final performance, while $K=64$ involves a higher computational cost: it doubles the number of NFEs used to produce a denoised sample with the teacher during training, while the performance is roughly similar to that of $K=32$.
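The cost argument can be made explicit with a short calculation. Assume, for illustration only, that denoising from the grid point $t_i = i/K$ with the teacher takes $i$ solver steps (one per grid point down to $0$). The expected number of teacher steps per training sample is then

$$\mathbb{E}_{t \sim \pi}[\text{steps}] = \sum_{i=1}^{K} \pi\!\left(\tfrac{i}{K}\right) i \,, \qquad \text{which for } \pi^{\mathrm{uniform}} \text{ gives } \frac{1}{K}\sum_{i=1}^{K} i = \frac{K+1}{2}\,.$$

Under this assumption, the uniform case yields $8.5$, $16.5$ and $32.5$ expected steps for $K=16$, $32$ and $64$ respectively, so doubling $K$ roughly doubles the teacher cost. When the teacher uses classifier-free guidance, each step counts for 2 NFEs, which doubles these figures without changing the ratio.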

Influence of the guidance scale

We also assess the influence of the guidance scale used during training to denoise the input samples with the teacher. For this ablation, unlike in Sec. 5.1, we generate samples from the teacher model using a fixed guidance scale $\omega$ set to either $1, 3, 5, 7, 10, 13$ or $15$. We report the evolution of the FID and CLIP score according to the guidance scale in the bottom left graph in Figure 6. As pictured in the figure and in line with the behavior observed with the teacher, the choice of the guidance scale has a strong impact on the final performance. While the CLIP score measuring prompt adherence tends to increase with the guidance scale, there exists a trade-off with the FID, which eventually increases with the guidance scale, resulting in a potential loss of image quality. The red dot represents the setting that we propose, which consists in uniformly sampling a guidance scale within a given range $\omega\sim\mathcal{U}([\omega_{\min},\omega_{\max}])$ where $\omega_{\min}$ is set to 3 and $\omega_{\max}$ to 13. This approach proves to be an interesting choice since it seems to allow the student model to generate samples with a wider diversity.
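A minimal sketch of how a sampled guidance scale enters the teacher prediction is given below. It uses one common classifier-free guidance convention, in which $\omega = 1$ recovers the purely conditional prediction; the convention actually used by the paper is defined in an earlier section and may differ, and the function names are hypothetical.

```python
import random
from typing import List


def cfg_combine(eps_uncond: List[float], eps_cond: List[float], omega: float) -> List[float]:
    """Classifier-free guidance: eps_uncond + omega * (eps_cond - eps_uncond)."""
    return [u + omega * (c - u) for u, c in zip(eps_uncond, eps_cond)]


def sample_guidance(rng: random.Random, omega_min: float = 3.0, omega_max: float = 13.0) -> float:
    """Draw omega uniformly from [omega_min, omega_max], as in the proposed setting."""
    return rng.uniform(omega_min, omega_max)


rng = random.Random(0)
omega = sample_guidance(rng)

# Toy noise predictions (hypothetical values).
guided = cfg_combine([0.0, 0.5], [1.0, 1.5], omega)
conditional_only = cfg_combine([0.0, 0.5], [1.0, 1.5], 1.0)
```

Drawing a fresh `omega` for each training sample exposes the student to teacher targets generated at many guidance strengths instead of a single fixed one.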

5.3 On the Method’s Versatility

To highlight the versatility of the proposed method, we apply the same approach to diffusion models trained with different conditionings, different backbones, or using adapters [46]. In the following section, all the models are trained on the LAION dataset [63] with samples having an aesthetic score above 6 and re-captioned using CogVLM [71]. The student models are LoRAs [20] and share the same architecture as the teacher.

5.3.1 Backbones’ Study

Table 3: FID and CLIP score computed on the first 10k prompts of the MS COCO2014 validation set for SDXL as teacher.

| Method | NFE | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|---|
| Teacher (SDXL) | 40 | 18.42 | 0.339 |
| LCM [42] | 8 | 21.73 | 0.327 |
| Turbo [61] | 4 | 23.69 | 0.337 |
| SDXL-lightning [33] | 4 | 24.61 | 0.329 |
| SDXL-lightning-LoRA [33] | 4 | 25.13 | 0.328 |
| Hyper-SD-LoRA [56] | 4 | 27.76 | 0.333 |
| Ours | 4 | 21.62 | 0.327 |

Table 4: FID and CLIP score computed on the first 10k prompts of the MS COCO2014 validation set for Pixart-$\alpha$ as teacher.

| Method | NFE | FID $\downarrow$ | CLIP $\uparrow$ |
|---|---|---|---|
| Teacher (Pixart-$\alpha$) | 40 | 28.09 | 0.316 |
| Ours | 4 | 29.30 | 0.303 |

Flash SDXL

In this section, we illustrate the ability of the proposed method to adapt to another, more recent teacher model. To do so, we select the publicly available SDXL model [50]. We train a LoRA student model (108M trainable parameters) sharing the same architecture as the teacher for 20k iterations on 4 H100-80GB GPUs (amounting to a total of 176 H100 hours of training) with a batch size of 2 and a learning rate of $10^{-5}$ together with the Adam optimizer [26] for both the student and the discriminator. We provide in Table 3 the FID and CLIP score computed on the first 10k prompts of the COCO2014 validation set. We compare the proposed approach to several distillation methods proposed in the literature. The results for the competitors are computed using publicly available checkpoints. Our method outperforms its peers in terms of FID, reaching 21.62, while maintaining good prompt alignment capabilities. In addition, we also provide a visual overview of the generated samples in Figure 7 for the teacher, the trained student model and LoRA-compatible approaches proposed in the literature (LCM [42], SDXL-lightning [33] and Hyper-SD [56]). Teacher samples are generated with a guidance scale of 5. For a fair comparison with competitors, we include prompts used in [33] for this qualitative evaluation. The proposed approach appears to be able to generate samples that are visually closer to the learned teacher distribution. In particular, Hyper-SD and SDXL-lightning seem to struggle to generate realistic samples despite creating sharp ones. See the appendices for the comprehensive experimental setup used for this section and additional comparisons.

[Figure: image grid with columns (a) Teacher (40 NFEs), (b) LCM (4 NFEs), (c) Lightning (4 NFEs), (d) HyperSD (4 NFEs), (e) Ours (4 NFEs). Each row corresponds to one prompt: "A photograph of a school bus in a magic forest"; "A monkey making latte art"; "A majestic lion stands proudly on a rock, overlooking the vast African savannah".]
Figure 7: Application of Flash Diffusion to SDXL. Teacher samples are generated with a guidance scale of 5. The proposed approach is compared to LoRA-based competitors and appears to be able to generate samples that are visually closer to the learned teacher distribution. Best viewed zoomed in.
[Figure: image grid with columns (a) Teacher (8 NFEs), (b) Teacher (40 NFEs), (c) LCM (4 NFEs), (d) Ours (4 NFEs). Each row corresponds to one prompt: "A whale with a big mouth and a rainbow on its back jumping out of the water"; "A small cactus with a happy face in the Sahara desert"; "A close-up of a person with a shaved head, gazing downwards, with a hand resting on their forehead".]
Figure 8: Application of Flash Diffusion to a DiT-based diffusion model (Pixart-$\alpha$). Generations from the proposed method using 4 NFEs are compared to the teacher generations using 8 NFEs and 40 NFEs as well as Pixart-LCM [43] with 4 steps. Teacher samples are generated with a guidance scale of 3.
Flash Pixart (DiT)

In addition to UNet-based architectures, we apply in this section the proposed method to a DiT denoiser backbone [49]. This appeared important to us since only a few distillation methods have been shown to efficiently distill DiT models [43, 77], while these models have demonstrated very promising performance for text-to-image generation [5, 6, 11] and video generation [4]. To do so, we select the Pixart-$\alpha$ model [5] as teacher. We train a LoRA student model (66.5M trainable parameters) sharing the same architecture as the teacher for 40k iterations on 4 H100-80GB GPUs (amounting to a total of 188 H100 hours of training) with a batch size of 2 and a learning rate of $10^{-5}$ together with the Adam optimizer [26] for both the student and the discriminator. We compare the student generations using 4 NFEs to the teacher generations using 8 NFEs (4 steps) and 40 NFEs (20 steps) as well as Pixart-LCM [43] using 4 steps in Figure 8. We use a guidance scale of 3 to generate with the teacher. As illustrated in the figure, the proposed method can generate high-quality samples that sometimes seem even more visually appealing than the teacher's. Moreover, driven by the adversarial approach, the student model trained with our method generates images with more vivid colors and sharper details than LCM. Notably, the student model does not lose the capability of the teacher to generate samples that are coherent with the prompt. In addition to those qualitative results, we provide in Table 4 the FID and CLIP score computed on the first 10k prompts of the COCO2014 validation set for our model and the teacher. See the appendices for the comprehensive experimental setup and additional samples as well as a discussion on the variability of the output samples with respect to the prompt.
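The step and NFE counts quoted above follow a simple accounting rule, sketched below. The teacher runs two network evaluations per solver step because it uses classifier-free guidance; treating each student step as a single evaluation is an assumption consistent with the reported 4 NFEs, and the helper name is hypothetical.

```python
def nfes(steps: int, uses_cfg: bool) -> int:
    """Number of network evaluations for a sampler run.

    With classifier-free guidance each step needs a conditional and an
    unconditional forward pass, hence 2 evaluations per step.
    """
    return steps * (2 if uses_cfg else 1)


teacher_short = nfes(4, uses_cfg=True)    # teacher, 4 steps
teacher_long = nfes(20, uses_cfg=True)    # teacher, 20 steps
student = nfes(4, uses_cfg=False)         # student, 4 steps (assumed 1 NFE each)
reduction = teacher_long / student
```

With these counts, the 4-NFE student uses ten times fewer network evaluations than the 20-step teacher it is compared against.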

5.3.2 Conditionings’ Study

[Figure: two image grids. Left (inpainting), columns: (a) Original, (b) Masked image, (c) Teacher (8 NFEs), (d) Teacher (40 NFEs), (e) Ours (4 NFEs). Right (super-resolution), columns: (f) LR image, (g) Teacher (8 NFEs), (h) Teacher (40 NFEs), (i) Ours (4 NFEs).]
Figure 9: Application of Flash Diffusion to an in-house diffusion-based inpainting model (left) and an in-house diffusion-based super-resolution model (right). Note that a single step for the teacher involves 2 NFEs since it uses CFG. Best viewed zoomed in.
Inpainting

First, we consider an in-house inpainting model that shares its backbone with SDXL [50]. The model is conditioned on a masked image, a mask, and a prompt and is trained to reconstruct an input image from a masked one. Again, we initialize a student model with the weights of the teacher and train the student for 20k iterations. The main hyper-parameters are the same as in Sec. 5.1, except that the guidance scale is sampled from $\mathcal{U}([5.0, 10.0])$ and $K$ is set to 16 to speed up the training. We show some samples in Figure 9 (left) and compare the samples generated by the student model using 4 NFEs to the teacher generations using 4 steps (i.e. 8 NFEs) and 20 steps (i.e. 40 NFEs). In all cases, the models take as input the masked image and an associated prompt to generate the output image. As highlighted in the figure, the proposed method is able to generate samples that are visually close to the teacher generations while using far fewer NFEs, demonstrating the ability of the method to adapt to different conditionings and tasks. See the appendices for the comprehensive experimental setup and additional samples.

Super-Resolution

To stress the method in cases where there is no text conditioning, we propose to consider an in-house diffusion model trained to upscale input images by a factor of 4. The model is conditioned on a low-resolution input image and is trained to generate an upscaled image. The degradation process used to artificially create the low-resolution image is similar to the one proposed in [72]. We show some generated samples using either the teacher model with 8 NFEs or 40 NFEs and the student model with 4 NFEs in Figure 9 (right). As highlighted in the figure, the proposed method can generate samples that are visually on par with the teacher generations while using far fewer NFEs. For this use case, tiling can be applied at inference time to generate images of resolution higher than the one used during training ($1024 \times 1024$). Nonetheless, the associated computational cost then scales quadratically with the image resolution, becoming prohibitively expensive for the teacher. This emphasizes the usefulness of the proposed approach that reduces drastically the number of NFEs required to generate samples and so unlocks the possibility of generating high-resolution images. See the supplementary material for the comprehensive experimental setup and additional samples.
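The quadratic scaling of tiled inference can be illustrated with a small sketch. It assumes non-overlapping square tiles of the training resolution, which is a simplification (practical pipelines often overlap and blend tiles); the function names are hypothetical.

```python
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # (left, top, right, bottom) in pixels


def tile_boxes(width: int, height: int, tile: int = 1024) -> List[Box]:
    """Split an output image into non-overlapping tiles of at most `tile` pixels."""
    boxes = []
    for top in range(0, height, tile):
        for left in range(0, width, tile):
            boxes.append((left, top, min(left + tile, width), min(top + tile, height)))
    return boxes


def total_nfes(width: int, height: int, nfes_per_tile: int, tile: int = 1024) -> int:
    """Total network evaluations when every tile is denoised independently."""
    return len(tile_boxes(width, height, tile)) * nfes_per_tile


# Doubling the side length quadruples the number of tiles.
tiles_2k = len(tile_boxes(2048, 2048))
tiles_4k = len(tile_boxes(4096, 4096))

# Cost at 4096 x 4096 for a 40-NFE teacher versus a 4-NFE student.
teacher_cost = total_nfes(4096, 4096, nfes_per_tile=40)
student_cost = total_nfes(4096, 4096, nfes_per_tile=4)
```

Since the per-tile cost is multiplied by a tile count that grows with the square of the side length, the tenfold reduction in NFEs per tile carries over directly to the full image.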

Face-Swapping
[Figure: image grid with columns (a) Source image, (b) Target image, (c) Teacher (8 NFEs), (d) Teacher (40 NFEs), (e) Ours (4 NFEs).]
Figure 10: Application of Flash Diffusion to an in-house diffusion-based face-swapping model. Best viewed zoomed in.

In addition to text-to-image, inpainting, and super-resolution, we also propose to apply the proposed method to an in-house face-swapping model. The model is conditioned on a source image and is trained to replace the face of the person in the target image with the one in the source image. In Figure 10, samples obtained using the teacher model with 8 NFEs or 40 NFEs are compared to the student model samples obtained with 4 NFEs. Interestingly, for this task, the proposed method can generate samples that seem to qualitatively outperform the teacher generation quality. The visual inspection of the results seems to indicate that the student appears to better respect the identity of the person in the source image and create more realistic output samples. This again highlights the ability of the method to adapt to different use cases. See the appendices for the comprehensive experimental setup.

Adapters

Finally, we show the compatibility of the proposed approach with adapters. For this experiment, we consider the canny and depth SDXL T2I adapters [46]. In these cases, the student model is trained to output samples conditioned on both a prompt and an additional conditioning given either by the edges of an input image or a depth map. The student model is trained using the proposed method and the same hyper-parameters as in Sec. 5.1, except that the guidance scale is sampled from $\mathcal{U}([3.0, 7.0])$ and $K$ is set to 16 to speed up the training. For both adapters, we use a conditioning scale of 0.8 to generate the samples with the student model. We show some samples in Figure 11.

[Figure: pairs of conditioning inputs and generated samples for (a) Canny T2I adapters and (b) Depth map T2I adapters.]
Figure 11: Application of Flash Diffusion to canny and depth T2I adapters [46]. Samples are generated using 4 NFEs. Best viewed zoomed in.

6 Conclusion

In this paper, we proposed a new versatile, fast, and efficient distillation method for diffusion models. The proposed method relies on the training of a student model to generate samples that are close to the data distribution learned by a teacher model using a combination of a distillation loss, an adversarial loss, and a distribution matching loss. We also proposed to rely on the LoRA method to reduce the number of trainable parameters and speed up the training process. We evaluated the proposed method on a text-to-image task and showed that it can achieve SOTA results on the COCO2014 and COCO2017 datasets. We also stressed and illustrated the versatility of the method by applying it to several tasks (inpainting, super-resolution, face-swapping), different denoiser architectures (UNet, DiT), and adapters, where the trained student model was able to produce high-quality samples using only a few NFEs. Future work will consist in trying to further reduce the number of NFEs or to enhance the quality of the samples by applying Direct Preference Optimization [52, 70] directly to the student model.

References

  • Arjovsky et al. [2017] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In International conference on machine learning, pages 214–223. PMLR, 2017.
  • Blattmann et al. [2023a] Andreas Blattmann, Tim Dockhorn, Sumith Kulal, Daniel Mendelevitch, Maciej Kilian, Dominik Lorenz, Yam Levi, Zion English, Vikram Voleti, Adam Letts, et al. Stable video diffusion: Scaling latent video diffusion models to large datasets. arXiv preprint arXiv:2311.15127, 2023a.
  • Blattmann et al. [2023b] Andreas Blattmann, Robin Rombach, Huan Ling, Tim Dockhorn, Seung Wook Kim, Sanja Fidler, and Karsten Kreis. Align your latents: High-resolution video synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 22563–22575, 2023b.
  • Brooks et al. [2024] T Brooks, B Peebles, C Homes, W DePue, Y Guo, L Jing, D Schnurr, J Taylor, T Luhman, E Luhman, et al. Video generation models as world simulators, 2024.
  • Chen et al. [2023] Junsong Chen, YU Jincheng, GE Chongjian, Lewei Yao, Enze Xie, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li. Pixart-$\alpha$: Fast training of diffusion transformer for photorealistic text-to-image synthesis. In The Twelfth International Conference on Learning Representations, 2023.
  • Chen et al. [2024] Junsong Chen, Chongjian Ge, Enze Xie, Yue Wu, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, and Zhenguo Li. Pixart-$\sigma$: Weak-to-strong training of diffusion transformer for 4k text-to-image generation. arXiv preprint arXiv:2403.04692, 2024.
  • Chen et al. [2018] Ricky TQ Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differential equations. Advances in neural information processing systems, 31, 2018.
  • Dhariwal and Nichol [2021] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in neural information processing systems, 34:8780–8794, 2021.
  • Dziugaite et al. [2015] Gintare Karolina Dziugaite, Daniel M Roy, and Zoubin Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In Proceedings of the Thirty-First Conference on Uncertainty in Artificial Intelligence, pages 258–267, 2015.
  • Esser et al. [2023] Patrick Esser, Johnathan Chiu, Parmida Atighehchian, Jonathan Granskog, and Anastasis Germanidis. Structure and content-guided video synthesis with diffusion models. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 7346–7356, 2023.
  • Esser et al. [2024] Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Müller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, et al. Scaling rectified flow transformers for high-resolution image synthesis. arXiv preprint arXiv:2403.03206, 2024.
  • Goodfellow et al. [2014] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems, pages 2672–2680, 2014.
  • Hendrycks and Gimpel [2016] Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (gelus). arXiv preprint arXiv:1606.08415, 2016.
  • Heusel et al. [2017] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in neural information processing systems, 30, 2017.
  • Hinton et al. [2015] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • Ho and Salimans [2021] Jonathan Ho and Tim Salimans. Classifier-free diffusion guidance. In NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications, 2021.
  • Ho et al. [2020] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in neural information processing systems, 33:6840–6851, 2020.
  • Ho et al. [2022a] Jonathan Ho, William Chan, Chitwan Saharia, Jay Whang, Ruiqi Gao, Alexey Gritsenko, Diederik P Kingma, Ben Poole, Mohammad Norouzi, David J Fleet, et al. Imagen video: High definition video generation with diffusion models. arXiv preprint arXiv:2210.02303, 2022a.
  • Ho et al. [2022b] Jonathan Ho, William Chan, Chitwan Saharia, Jay Whang, Ruiqi Gao, Alexey Gritsenko, Diederik P Kingma, Ben Poole, Mohammad Norouzi, David J Fleet, et al. Imagen video: High definition video generation with diffusion models. arXiv preprint arXiv:2210.02303, 2022b.
  • Hu et al. [2021] Edward J Hu, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen, et al. Lora: Low-rank adaptation of large language models. In International Conference on Learning Representations, 2021.
  • Ilharco et al. [2021] Gabriel Ilharco, Mitchell Wortsman, Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, Hongseok Namkoong, John Miller, Hannaneh Hajishirzi, Ali Farhadi, and Ludwig Schmidt. Openclip. 2021. URL https://doi.org/10.5281/zenodo.5143773.
  • Kang et al. [2023] Minguk Kang, Jun-Yan Zhu, Richard Zhang, Jaesik Park, Eli Shechtman, Sylvain Paris, and Taesung Park. Scaling up gans for text-to-image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 10124–10134, 2023.
  • Karras et al. [2022] Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. Elucidating the design space of diffusion-based generative models. Advances in Neural Information Processing Systems, 35:26565–26577, 2022.
  • Kim et al. [2023] Dongjun Kim, Chieh-Hsin Lai, Wei-Hsiang Liao, Naoki Murata, Yuhta Takida, Toshimitsu Uesaka, Yutong He, Yuki Mitsufuji, and Stefano Ermon. Consistency trajectory models: Learning probability flow ode trajectory of diffusion. In The Twelfth International Conference on Learning Representations, 2023.
  • Kingma et al. [2021] Diederik Kingma, Tim Salimans, Ben Poole, and Jonathan Ho. Variational diffusion models. Advances in neural information processing systems, 34:21696–21707, 2021.
  • Kingma and Ba [2014] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kingma and Welling [2014] Diederik P. Kingma and Max Welling. Auto-encoding variational bayes. arXiv:1312.6114 [cs, stat], 2014.
  • Kohler et al. [2024] Jonas Kohler, Albert Pumarola, Edgar Schönfeld, Artsiom Sanakoyeu, Roshan Sumbaly, Peter Vajda, and Ali Thabet. Imagine flash: Accelerating emu diffusion models with backward distillation. arXiv preprint arXiv:2405.05224, 2024.
  • Kong et al. [2020] Zhifeng Kong, Wei Ping, Jiaji Huang, Kexin Zhao, and Bryan Catanzaro. Diffwave: A versatile diffusion model for audio synthesis. In International Conference on Learning Representations, 2020.
  • Li et al. [2024] Yanyu Li, Huan Wang, Qing Jin, Ju Hu, Pavlo Chemerys, Yun Fu, Yanzhi Wang, Sergey Tulyakov, and Jian Ren. Snapfusion: Text-to-image diffusion model on mobile devices within two seconds. Advances in Neural Information Processing Systems, 36, 2024.
  • Li et al. [2015] Yujia Li, Kevin Swersky, and Rich Zemel. Generative moment matching networks. In International conference on machine learning, pages 1718–1727. PMLR, 2015.
  • Lim and Ye [2017] Jae Hyun Lim and Jong Chul Ye. Geometric gan. arXiv preprint arXiv:1705.02894, 2017.
  • Lin et al. [2024] Shanchuan Lin, Anran Wang, and Xiao Yang. Sdxl-lightning: Progressive adversarial diffusion distillation. arXiv preprint arXiv:2402.13929, 2024.
  • Lin et al. [2014] Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C Lawrence Zitnick. Microsoft coco: Common objects in context. In Computer Vision–ECCV 2014: 13th European Conference, Zurich, Switzerland, September 6-12, 2014, Proceedings, Part V 13, pages 740–755. Springer, 2014.
  • Liu et al. [2022] Xingchao Liu, Chengyue Gong, et al. Flow straight and fast: Learning to generate and transfer data with rectified flow. In The Eleventh International Conference on Learning Representations, 2022.
  • Liu et al. [2023] Xingchao Liu, Xiwen Zhang, Jianzhu Ma, Jian Peng, et al. Instaflow: One step is enough for high-quality diffusion-based text-to-image generation. In The Twelfth International Conference on Learning Representations, 2023.
  • Lu et al. [2022a] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver: A fast ode solver for diffusion probabilistic model sampling in around 10 steps. Advances in Neural Information Processing Systems, 35:5775–5787, 2022a.
  • Lu et al. [2022b] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver++: Fast solver for guided sampling of diffusion probabilistic models. arXiv preprint arXiv:2211.01095, 2022b.
  • Luhman and Luhman [2021] Eric Luhman and Troy Luhman. Knowledge distillation in iterative generative models for improved sampling speed. arXiv preprint arXiv:2101.02388, 2021.
  • Luo and Hu [2021a] Shitong Luo and Wei Hu. Diffusion probabilistic models for 3d point cloud generation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2837–2845, 2021a.
  • Luo and Hu [2021b] Shitong Luo and Wei Hu. Score-based point cloud denoising. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4583–4592, 2021b.
  • Luo et al. [2023a] Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao. Latent consistency models: Synthesizing high-resolution images with few-step inference. arXiv preprint arXiv:2310.04378, 2023a.
  • Luo et al. [2023b] Simian Luo, Yiqin Tan, Suraj Patil, Daniel Gu, Patrick von Platen, Apolinário Passos, Longbo Huang, Jian Li, and Hang Zhao. Lcm-lora: A universal stable-diffusion acceleration module. arXiv preprint arXiv:2311.05556, 2023b.
  • Mao et al. [2017] Xudong Mao, Qing Li, Haoran Xie, Raymond YK Lau, Zhen Wang, and Stephen Paul Smolley. Least squares generative adversarial networks. In Proceedings of the IEEE international conference on computer vision, pages 2794–2802, 2017.
  • Meng et al. [2023] Chenlin Meng, Robin Rombach, Ruiqi Gao, Diederik Kingma, Stefano Ermon, Jonathan Ho, and Tim Salimans. On distillation of guided diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 14297–14306, 2023.
  • Mou et al. [2024] Chong Mou, Xintao Wang, Liangbin Xie, Yanze Wu, Jian Zhang, Zhongang Qi, and Ying Shan. T2i-adapter: Learning adapters to dig out more controllable ability for text-to-image diffusion models. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 38, pages 4296–4304, 2024.
  • Nichol et al. [2022] Alexander Quinn Nichol, Prafulla Dhariwal, Aditya Ramesh, Pranav Shyam, Pamela Mishkin, Bob Mcgrew, Ilya Sutskever, and Mark Chen. Glide: Towards photorealistic image generation and editing with text-guided diffusion models. In International Conference on Machine Learning, pages 16784–16804. PMLR, 2022.
  • Parmar et al. [2022] Gaurav Parmar, Richard Zhang, and Jun-Yan Zhu. On aliased resizing and surprising subtleties in gan evaluation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 11410–11420, 2022.
  • Peebles and Xie [2023] William Peebles and Saining Xie. Scalable diffusion models with transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4195–4205, 2023.
  • Podell et al. [2023] Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach. Sdxl: Improving latent diffusion models for high-resolution image synthesis. In The Twelfth International Conference on Learning Representations, 2023.
  • Radford et al. [2021] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In International conference on machine learning, pages 8748–8763. PMLR, 2021.
  • Rafailov et al. [2024] Rafael Rafailov, Archit Sharma, Eric Mitchell, Christopher D Manning, Stefano Ermon, and Chelsea Finn. Direct preference optimization: Your language model is secretly a reward model. Advances in Neural Information Processing Systems, 36, 2024.
  • Ramachandran et al. [2017] Prajit Ramachandran, Barret Zoph, and Quoc V Le. Searching for activation functions. arXiv preprint arXiv:1710.05941, 2017.
  • Ramesh et al. [2021] Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever. Zero-shot text-to-image generation. In International conference on machine learning, pages 8821–8831. Pmlr, 2021.
  • Ramesh et al. [2022] Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, and Mark Chen. Hierarchical text-conditional image generation with clip latents. arXiv preprint arXiv:2204.06125, 2022.
  • Ren et al. [2024] Yuxi Ren, Xin Xia, Yanzuo Lu, Jiacheng Zhang, Jie Wu, Pan Xie, Xing Wang, and Xuefeng Xiao. Hyper-sd: Trajectory segmented consistency model for efficient image synthesis. arXiv preprint arXiv:2404.13686, 2024.
  • Rombach et al. [2022] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 10684–10695, 2022.
  • Ronneberger et al. [2015] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. In Medical image computing and computer-assisted intervention–MICCAI 2015: 18th international conference, Munich, Germany, October 5-9, 2015, proceedings, part III 18, pages 234–241. Springer, 2015.
  • Saharia et al. [2022] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily L Denton, Kamyar Ghasemipour, Raphael Gontijo Lopes, Burcu Karagol Ayan, Tim Salimans, et al. Photorealistic text-to-image diffusion models with deep language understanding. Advances in neural information processing systems, 35:36479–36494, 2022.
  • Salimans and Ho [2021] Tim Salimans and Jonathan Ho. Progressive distillation for fast sampling of diffusion models. In International Conference on Learning Representations, 2021.
  • Sauer et al. [2023] Axel Sauer, Dominik Lorenz, Andreas Blattmann, and Robin Rombach. Adversarial diffusion distillation. arXiv preprint arXiv:2311.17042, 2023.
  • Sauer et al. [2024] Axel Sauer, Frederic Boesel, Tim Dockhorn, Andreas Blattmann, Patrick Esser, and Robin Rombach. Fast high-resolution image synthesis with latent adversarial diffusion distillation. arXiv preprint arXiv:2403.12015, 2024.
  • Schuhmann et al. [2022] Christoph Schuhmann, Romain Beaumont, Richard Vencu, Cade Gordon, Ross Wightman, Mehdi Cherti, Theo Coombes, Aarush Katta, Clayton Mullis, Mitchell Wortsman, et al. Laion-5b: An open large-scale dataset for training next generation image-text models. Advances in Neural Information Processing Systems, 35:25278–25294, 2022.
  • Sohl-Dickstein et al. [2015] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In International conference on machine learning, pages 2256–2265. PMLR, 2015.
  • Song and Dhariwal [2023] Yang Song and Prafulla Dhariwal. Improved techniques for training consistency models. In The Twelfth International Conference on Learning Representations, 2023.
  • Song and Ermon [2019] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. Advances in neural information processing systems, 32, 2019.
  • Song et al. [2020] Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2020.
  • Song et al. [2023] Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. Consistency models. In Proceedings of the 40th International Conference on Machine Learning, pages 32211–32252, 2023.
  • Vincent [2011] Pascal Vincent. A connection between score matching and denoising autoencoders. Neural computation, 23(7):1661–1674, 2011.
  • Wallace et al. [2023] Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik. Diffusion model alignment using direct preference optimization. arXiv preprint arXiv:2311.12908, 2023.
  • Wang et al. [2023] Weihan Wang, Qingsong Lv, Wenmeng Yu, Wenyi Hong, Ji Qi, Yan Wang, Junhui Ji, Zhuoyi Yang, Lei Zhao, Xixuan Song, et al. Cogvlm: Visual expert for pretrained language models. arXiv preprint arXiv:2311.03079, 2023.
  • Wang et al. [2021] Xintao Wang, Liangbin Xie, Chao Dong, and Ying Shan. Real-esrgan: Training real-world blind super-resolution with pure synthetic data. In Proceedings of the IEEE/CVF international conference on computer vision, pages 1905–1914, 2021.
  • Wang et al. [2024] Zhengyi Wang, Cheng Lu, Yikai Wang, Fan Bao, Chongxuan Li, Hang Su, and Jun Zhu. Prolificdreamer: High-fidelity and diverse text-to-3d generation with variational score distillation. Advances in Neural Information Processing Systems, 36, 2024.
  • Wu et al. [2023] Lemeng Wu, Dilin Wang, Chengyue Gong, Xingchao Liu, Yunyang Xiong, Rakesh Ranjan, Raghuraman Krishnamoorthi, Vikas Chandra, and Qiang Liu. Fast point cloud generation with straight flows. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 9445–9454, 2023.
  • Wu and He [2018] Yuxin Wu and Kaiming He. Group normalization. In Proceedings of the European conference on computer vision (ECCV), pages 3–19, 2018.
  • Xu et al. [2023] Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou. Ufogen: You forward once large scale text-to-image generation via diffusion gans. arXiv preprint arXiv:2311.09257, 2023.
  • Yin et al. [2023] Tianwei Yin, Michaël Gharbi, Richard Zhang, Eli Shechtman, Fredo Durand, William T Freeman, and Taesung Park. One-step diffusion with distribution matching distillation. arXiv preprint arXiv:2311.18828, 2023.
  • Yin et al. [2024] Tianwei Yin, Michaël Gharbi, Taesung Park, Richard Zhang, Eli Shechtman, Fredo Durand, and William T Freeman. Improved distribution matching distillation for fast image synthesis. arXiv preprint arXiv:2405.14867, 2024.
  • Zhang et al. [2023] Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. Adding conditional control to text-to-image diffusion models. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 3836–3847, 2023.
  • Zhang and Chen [2022] Qinsheng Zhang and Yongxin Chen. Fast sampling of diffusion models with exponential integrator. In NeurIPS 2022 Workshop on Score-Based Methods, 2022.
  • Zhang et al. [2018] Richard Zhang, Phillip Isola, Alexei A Efros, Eli Shechtman, and Oliver Wang. The unreasonable effectiveness of deep features as a perceptual metric. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 586–595, 2018.
  • Zhao et al. [2024] Wenliang Zhao, Lujia Bai, Yongming Rao, Jie Zhou, and Jiwen Lu. Unipc: A unified predictor-corrector framework for fast sampling of diffusion models. Advances in Neural Information Processing Systems, 36, 2024.
  • Zheng et al. [2023] Hongkai Zheng, Weili Nie, Arash Vahdat, Kamyar Azizzadenesheli, and Anima Anandkumar. Fast sampling of diffusion models via operator learning. In International Conference on Machine Learning, pages 42390–42402. PMLR, 2023.

Appendix A Experimental Details

A.1 Experimental Setup for Text-to-Image

To compute the FID, we rely on the clean-fid library [48] while we use an OpenCLIP-G backbone [21] to compute the CLIP scores. The models are trained on the LAION dataset [63] where we select samples with aesthetic scores above 6 and re-caption the samples using CogVLM [71].
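
For intuition about the second metric, a CLIP score is a cosine similarity between an image embedding and the embedding of its caption, averaged over the evaluation set. The sketch below assumes the embeddings have already been produced by some CLIP-like encoder and uses hypothetical function names; scaling and clipping conventions vary between implementations and none is applied here.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two non-zero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def mean_clip_score(image_embeddings, text_embeddings):
    """Average image/text cosine similarity over paired embeddings."""
    scores = [
        cosine_similarity(img, txt)
        for img, txt in zip(image_embeddings, text_embeddings)
    ]
    return sum(scores) / len(scores)
```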

Flash SD1.5

In this section, we provide the detailed experimental setup used to perform the quantitative evaluation of the model described in Sec. 5.1. For this experiment, we use the SD1.5 model as teacher and initialize the student with SD1.5’s weights. The student model is trained for 20k iterations on 2 H100-80GB GPUs (amounting to 26 H100 hours of training) with a batch size of 4 and a learning rate of $10^{-5}$ for both the student and the discriminator. We use the timestep distribution $\pi(t)$ detailed in Sec. 4.2 with $K=32$ and shift modes every 5000 iterations. We also start with both $\lambda_{\mathrm{adv}}=0$ and $\lambda_{\mathrm{DMD}}=0$ and progressively increase them each time we change the timestep distribution so that they reach final values of 0.3 and 0.7 respectively. The schedule is $[0, 0.1, 0.2, 0.3]$ for $\lambda_{\mathrm{adv}}$ and $[0, 0.3, 0.5, 0.7]$ for $\lambda_{\mathrm{DMD}}$. The guidance scale $\omega$ used to denoise using the teacher model is uniformly sampled from $[3, 13]$. The distillation loss is set to the MSE loss and the GAN loss is set to the LSGAN loss.
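
The phase-wise schedule above can be sketched as a lookup from the training iteration to the pair $(\lambda_{\mathrm{adv}}, \lambda_{\mathrm{DMD}})$. The schedule values and the 5000-iteration shift are those quoted in the text; the helper names, the step-function behaviour at each shift, and the weighted-sum form of the total loss are assumptions made for illustration.

```python
# Values quoted in the text for Flash SD1.5.
LAMBDA_ADV_SCHEDULE = [0.0, 0.1, 0.2, 0.3]
LAMBDA_DMD_SCHEDULE = [0.0, 0.3, 0.5, 0.7]
SHIFT_EVERY = 5000  # iterations between two changes of the timestep distribution


def phase_index(iteration, shift_every=SHIFT_EVERY, num_phases=4):
    """Index of the current phase; the last phase is kept once it is reached."""
    return min(iteration // shift_every, num_phases - 1)


def loss_weights(iteration):
    """Return (lambda_adv, lambda_dmd) for a given training iteration."""
    phase = phase_index(iteration)
    return LAMBDA_ADV_SCHEDULE[phase], LAMBDA_DMD_SCHEDULE[phase]


def total_loss(distil, adv, dmd, iteration):
    """Assumed objective: distil + lambda_adv * adv + lambda_dmd * dmd."""
    lam_adv, lam_dmd = loss_weights(iteration)
    return distil + lam_adv * adv + lam_dmd * dmd
```

With 20k iterations and a shift every 5000 iterations, the four schedule entries cover the whole run, so the adversarial and distribution matching terms are inactive during the first 5000 iterations.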

When ablating the timestep distribution, we use the following distributions: $\pi^{\mathrm{uniform}}(t)$, $\pi^{\mathrm{gaussian}}(t)$, $\pi^{\mathrm{sharp}}(t)$ and $\pi^{\mathrm{ours}}(t)$ that are represented in Figure 12.
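
To make the idea of a multi-modal timestep distribution whose modes shift between phases more concrete, here is a rough sketch of a discretized Gaussian mixture over normalized timesteps. The exact form of $\pi(t)$ is given in Sec. 4.2 and is not reproduced here: the function name, the reading of $K$ as a number of timestep bins, the mode locations, the mode weights and the value of `sigma` are all hypothetical.

```python
import math


def mixture_timestep_probs(num_bins, mode_weights, sigma):
    """Discrete probabilities over num_bins timesteps in (0, 1].

    mode_weights maps a mode location (normalized timestep) to its weight.
    """
    centers = [(i + 1) / num_bins for i in range(num_bins)]
    density = [
        sum(
            weight * math.exp(-((t - mode) ** 2) / (2.0 * sigma ** 2))
            for mode, weight in mode_weights.items()
        )
        for t in centers
    ]
    total = sum(density)
    return [d / total for d in density]


# Purely illustrative phase: most of the mass on the noisiest timestep.
example_probs = mixture_timestep_probs(
    num_bins=32,
    mode_weights={1.0: 0.7, 0.75: 0.1, 0.5: 0.1, 0.25: 0.1},
    sigma=0.02,
)
```

Shifting modes between phases then amounts to calling the function with a different `mode_weights` dictionary.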

Figure 12: Illustration of the timestep distributions used in the ablation study in Sec. 5.2: (a) $\pi^{\mathrm{uniform}}$, (b) $\pi^{\mathrm{gaussian}}$, (c) $\pi^{\mathrm{sharp}}$ (warm-up, phase 1, phase 2, phase 3), (d) $\pi^{\mathrm{ours}}$ (warm-up, phase 1, phase 2, phase 3).

Flash SDXL

In this section, we train a student model sharing the same UNet architecture as SDXL. The model is trained for 20k iterations on 4 H100-80GB GPUs with a batch size of 2 and a learning rate of $10^{-5}$ for both the student and the discriminator. The student weights are initialized with the teacher’s. The timestep distribution $\pi(t)$ is detailed in Sec. 4.2 and chosen such that $K=32$. We also shift modes every 5000 iterations. As for SD1.5, we set $\lambda_{\mathrm{adv}}=0$ and $\lambda_{\mathrm{DMD}}=0$ and progressively increase them each time we change the timestep distribution so that they reach final values of 0.3 and 0.7 respectively. The schedule is $[0, 0.1, 0.2, 0.3]$ for $\lambda_{\mathrm{adv}}$ and $[0, 0.3, 0.5, 0.7]$ for $\lambda_{\mathrm{DMD}}$. We use a guidance scale $\omega$ uniformly sampled from $[3, 13]$; the distillation loss is chosen as LPIPS and the GAN loss is set to the LSGAN loss.

Flash Pixart (DiT)

We train a student model sharing the same architecture as the teacher for 40k iterations on 4 H100-80GB GPUs with a batch size of 2 and a learning rate of $10^{-5}$ together with the Adam optimizer [26] for both the student and the discriminator. The weights of the student model are initialized using the teacher’s. We use the timestep distribution $\pi(t)$ detailed in Sec. 4.2 with $K=16$ and shift modes every 10000 iterations. We also start with both $\lambda_{\mathrm{adv}}=0$ and $\lambda_{\mathrm{DMD}}=0$ and progressively increase them each time we change the timestep distribution so that they reach final values of 0.3 and 0.7 respectively. The schedule is $[0, 0.05, 0.1, 0.2]$ for $\lambda_{\mathrm{adv}}$ and $[0, 0.3, 0.5, 0.7]$ for $\lambda_{\mathrm{DMD}}$. The guidance scale $\omega$ used to denoise using the teacher model is uniformly sampled from $[2, 9]$. The distillation loss is the LPIPS loss and the GAN loss is set to the LSGAN loss.

A.2 Experimental Setup for Inpainting

For the inpainting experiment, we use an in-house diffusion-based model whose backbone architecture is similar to that of SDXL [50]; the student weights are initialized using the teacher’s. The student model is trained on 512x512 input image resolution for 20k iterations on 2 H100-80GB GPUs with a batch size of 4 and a learning rate of $10^{-5}$ for both the student and the discriminator. The timestep distribution $\pi(t)$ is detailed in Sec. 4.2 and chosen with $K=16$. Modes are shifted every 5000 iterations. We again start with both $\lambda_{\mathrm{adv}}=0$ and $\lambda_{\mathrm{DMD}}=0$ and progressively increase them each time we change the timestep distribution so that they reach final values of 0.3 and 0.7 respectively. The schedule is $[0, 0.1, 0.2, 0.3]$ for $\lambda_{\mathrm{adv}}$ and $[0, 0.3, 0.5, 0.7]$ for $\lambda_{\mathrm{DMD}}$. The guidance scale $\omega$ is uniformly sampled from $[3, 13]$. The distillation loss is set to the MSE loss and the GAN loss is set to the LSGAN loss.
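
For reference, the two loss choices named above can be written down in a few lines. This is a minimal sketch on plain lists of floats, using the standard least-squares GAN targets (1 for real, 0 for fake) from Mao et al.; the function names are hypothetical and the inputs to the discriminator (for instance, which features it sees) are not modeled here.

```python
def mse_loss(prediction, target):
    """Mean squared error between two equally sized lists of floats."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(prediction)


def lsgan_discriminator_loss(d_real, d_fake):
    """LSGAN discriminator loss: push D(real) towards 1 and D(fake) towards 0."""
    real_term = sum((r - 1.0) ** 2 for r in d_real) / len(d_real)
    fake_term = sum(f ** 2 for f in d_fake) / len(d_fake)
    return 0.5 * real_term + 0.5 * fake_term


def lsgan_generator_loss(d_fake):
    """LSGAN generator (student) loss: push D(fake) towards 1."""
    return 0.5 * sum((f - 1.0) ** 2 for f in d_fake) / len(d_fake)
```

Both GAN terms vanish when the discriminator is perfectly right (for its own loss) or perfectly fooled (for the generator loss), which is the usual sanity check for this formulation.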

A.3 Experimental Setup for Super-Resolution

For the super-resolution experiment, we use an in-house diffusion-based model whose backbone architecture is similar to that of SDXL [50]. The student model is trained with 256x256 low-resolution images used as conditioning and outputs 1024x1024 images. The student model is initialized using the teacher’s weights and is trained for 20k iterations on 2 H100-80GB GPUs with a batch size of 4 and a learning rate of $10^{-5}$ for both the student and the discriminator. We set $K=16$ for $\pi(t)$ and shift modes every 5000 iterations. We start with $\lambda_{\mathrm{adv}}=0$ and $\lambda_{\mathrm{DMD}}=0$ and progressively increase them each time we change the timestep distribution so that they reach final values of 0.3 and 0.7 respectively. The schedule is $[0, 0.1, 0.2, 0.3]$ for $\lambda_{\mathrm{adv}}$ and $[0, 0.3, 0.5, 0.7]$ for $\lambda_{\mathrm{DMD}}$. The guidance scale $\omega$ used to denoise using the teacher model is uniformly sampled from $[1.2, 1.8]$. The distillation loss is set to the MSE loss and the GAN loss is set to the LSGAN loss.

A.4 Experimental Setup for Face-Swapping

For the face-swapping experiment, we use an in-house diffusion-based model whose backbone architecture is similar to that of SD2.2 [57]. The student model is trained on 512x512 input images and target images. We use a face detector to extract the face from the source image and use it as conditioning. The student model is then initialized using the teacher’s weights and is trained for 15k iterations on 2 H100-80GB GPUs with a batch size of 8 and a learning rate of $10^{-5}$ for both the student and the discriminator. We use the timestep distribution $\pi(t)$ detailed in Sec. 4.2 with $K=16$ and shift modes every 5000 iterations. We also start with both $\lambda_{\mathrm{adv}}=0$ and $\lambda_{\mathrm{DMD}}=0$ and progressively increase them each time we change the timestep distribution so that they reach final values of 0.3 and 0.7 respectively. The schedule is $[0, 0.1, 0.2, 0.3]$ for $\lambda_{\mathrm{adv}}$ and $[0, 0.3, 0.5, 0.7]$ for $\lambda_{\mathrm{DMD}}$. The guidance scale $\omega$ used to denoise using the teacher model is uniformly sampled from $[2.0, 7.0]$. The distillation loss is set to the MSE loss and the GAN loss is set to the LSGAN loss.

Appendix B Additional Sampling Results

In this section, we provide additional samples for each task considered in the main paper.

B.1 Flash SDXL

In Figure 13, we provide additional samples enriching the qualitative comparison performed in the main manuscript. Again, to be fair to the competitors, we use some prompts from [33] to generate the samples. As mentioned in the paper, the proposed approach appears to be able to generate samples that are visually closer to the learned teacher distribution.

[Figure 13 image grid. Columns: (a) Teacher (40 NFEs), (b) LCM (4 NFEs), (c) Lightning (4 NFEs), (d) HyperSD (4 NFEs), (e) Ours (4 NFEs). Rows, one prompt each: "A pickup truck going up a mountain switchback"; "A giant wave breaking on a majestic lighthouse"; "An Asian firefighter with a rugged jawline rushes through the billowing smoke of an autumn blaze"; "Cute cartoon small cat sitting in a movie theater eating popcorn, watching a movie"; "A very realistic close up of an old elderly man with green eyes looking straight at the camera, vivid colors"; "A delicate porcelain teacup sits on a saucer, its surface adorned with intricate blue patterns".]
Figure 13: Application of Flash Diffusion to an SDXL teacher model. Generations from the proposed method using 4 NFEs are compared to teacher generations using 40 NFEs, as well as to LoRA approaches proposed in the literature (LCM [42], SDXL-lightning [33] and Hyper-SD [56]). Teacher samples are generated with a guidance scale of 5. Best viewed zoomed in.

B.2 Flash Pixart (DiT)

In this section, we provide additional samples from the trained student model based on a DiT architecture. In Figure 14, we provide a more complete qualitative comparison with LCM and the teacher model, while in Figures 15 and 16 we show additional samples using the proposed method. In Figures 17 and 18, we also show the generation variation for two different prompts: "A yellow orchid trapped inside an empty bottle of wine" and "An oil painting portrait of an elegant blond woman with a bowtie and hat". The model appears to generate varied samples even with a fixed prompt.

[Figure 14 image grid. Columns: (a) Teacher (8 NFEs), (b) Teacher (40 NFEs), (c) LCM (4 NFEs), (d) Ours (4 NFEs). Rows, one prompt each: "A cute cheetah looking amazed and surprised"; "A giant wave shoring on big red lighthouse"; "A raccoon reading a book in a lush forest"; "A classic turquoise car is parked outside a modern building with curved balconies"; "A beautiful sunflower in rainy day"; "A woman in a red traditional outfit wields a sword, poised in an intense stance against a dark background".]
Figure 14: Application of Flash Diffusion to a DiT-based diffusion model, namely Pixart-α. Generations from the proposed method using 4 NFEs are compared to teacher generations using 8 NFEs and 40 NFEs, as well as to Pixart-LCM [43] with 4 steps. Teacher samples are generated with a guidance scale of 3.
(a) A famous professor giraffe in a classroom standing in front of the blackboard teaching
(b) A close up of an old elderly man with green eyes looking straight at the camera
(c) A cute fluffy rabbit pilot walking on a military aircraft carrier, 8k, cinematic
(d) Pirate ship sailing on a sea with the milky way galaxy in the sky and purple glow lights
Figure 15: Application of Flash Diffusion to a DiT-based diffusion model, Pixart-α.
(a) A photograph of a woman with headphone coding on a computer, photograph, cinematic, high details, 4k
(b) A super realistic kungfu master panda Japanese style
(c) The scene represents a desert composed of red rock resembling planet Mars, there is a cute robot with big eyes feeling alone, It looks straight to the camera looking for friends
(d) A serving of creamy pasta, adorned with herbs and red pepper flakes, is placed on a white surface, with a striped cloth nearby
Figure 16: Application of Flash Diffusion to a DiT-based diffusion model, Pixart-α.
[Figure 17 image grid: four samples, panels (a) to (d).]
Figure 17: Generation variation for Flash Pixart with the prompt "A yellow orchid trapped inside an empty bottle of wine".
[Figure 18 image grid: four samples, panels (a) to (d).]
Figure 18: Generation variation for Flash Pixart with the prompt "An oil painting portrait of an elegant blond woman with a bowtie and hat".

B.3 Flash Inpainting

In Figure 19, we provide additional samples using the trained inpainting student model. We compare the samples generated by the student model using 4 NFEs to the teacher generations using 4 steps (i.e. 8 NFEs) and 20 steps (i.e. 40 NFEs).
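
The step-to-NFE conversion above is consistent with a teacher sampled with classifier-free guidance, where each sampling step evaluates the denoiser twice (once conditionally, once unconditionally), and a student that needs a single evaluation per step. A minimal sketch of that arithmetic, assuming this reading; `nfe_count` is a hypothetical helper introduced only for illustration.

```python
def nfe_count(num_steps: int, uses_cfg: bool) -> int:
    """Number of denoiser forward passes (NFEs) for one sampling run.

    Assumption: classifier-free guidance costs two evaluations per step
    (conditional and unconditional); without it, one evaluation per step.
    """
    if num_steps < 0:
        raise ValueError("num_steps must be non-negative")
    evaluations_per_step = 2 if uses_cfg else 1
    return num_steps * evaluations_per_step


if __name__ == "__main__":
    print(nfe_count(4, uses_cfg=True))    # teacher, 4 steps
    print(nfe_count(20, uses_cfg=True))   # teacher, 20 steps
    print(nfe_count(4, uses_cfg=False))   # student, 4 steps
```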

[Figure 19 image grid. Columns: (a) Original image, (b) Masked image, (c) Teacher (8 NFEs), (d) Teacher (40 NFEs), (e) Ours (4 NFEs). The remaining panels, (f) to (ai), are additional rows in the same column order.]
Figure 19: Application of Flash Diffusion to an in-house diffusion-based inpainting model. Best viewed zoomed in.

B.4 Flash Upscaler

In Figure 20, we provide additional samples using the trained super-resolution student model. As in the main paper, the student model is trained to output 1024x1024 images using 256x256 low-resolution images as conditioning. It is compared to the teacher generations using 4 steps (i.e. 8 NFEs) and 20 steps (i.e. 40 NFEs).

[Figure 20 image grid. Columns: (a) LR image, (b) Teacher (8 NFEs), (c) Teacher (40 NFEs), (d) Ours (4 NFEs). The remaining panels, (e) to (t), are additional rows in the same column order.]
Figure 20: Application of Flash Diffusion to an in-house diffusion-based super-resolution model. Best viewed zoomed in.