Flash Diffusion:

Accelerating Any Conditional Diffusion Model for Few Steps Image Generation

Jasper Research

Abstract

Fig: Showcase grid of generated results

In this paper, we propose an efficient, fast, and versatile distillation method to accelerate the generation of pre-trained diffusion models: Flash Diffusion. The method reaches state-of-the-art performance in terms of FID and CLIP-Score for few-step image generation on the COCO2014 and COCO2017 datasets, while requiring only a few GPU hours of training and fewer trainable parameters than existing methods. In addition to its efficiency, the versatility of the method is also demonstrated across several tasks such as text-to-image generation, inpainting, face-swapping, and super-resolution, and across different backbones such as UNet-based denoisers (SD1.5, SDXL) or DiT (Pixart-α), as well as adapters. In all cases, the method drastically reduces the number of sampling steps while maintaining very high-quality image generation. The official implementation is available at https://github.com/gojasper/flash-diffusion

Method

Fig 1: Flash Diffusion training method diagram

Our method is designed to be fast, reliable, and adaptable across use cases. It trains a student model to predict, in a single step, the denoised sample that the teacher obtains through multiple denoising steps from the same corrupted input. Additionally, we sample timesteps from an adaptive distribution that shifts during training, helping the student model focus on specific timesteps.
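
To make this concrete, below is a minimal PyTorch sketch of one training step. It assumes epsilon-prediction denoisers and a standard DDPM cumulative schedule alphas_cumprod; the function names, the DDIM sub-stepping, and the number of teacher steps k are illustrative assumptions, not the official implementation.

import torch
import torch.nn.functional as F

def pred_x0(eps_model, x_t, t, cond, alphas_cumprod):
    # Recover the clean-sample estimate from an epsilon prediction.
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return (x_t - (1 - a).sqrt() * eps_model(x_t, t, cond)) / a.sqrt()

def ddim_step(eps_model, x_t, t, t_prev, cond, alphas_cumprod):
    # One deterministic DDIM step from t down to t_prev.
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    a_prev = alphas_cumprod[t_prev].view(-1, 1, 1, 1)
    eps = eps_model(x_t, t, cond)
    x0 = (x_t - (1 - a).sqrt() * eps) / a.sqrt()
    return a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps

def flash_distill_loss(student, teacher, x0, cond, t, alphas_cumprod, k=3):
    # Corrupt the clean latent up to timestep t.
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a.sqrt() * x0 + (1 - a).sqrt() * torch.randn_like(x0)

    # Student: single-step estimate of the clean latent.
    x0_student = pred_x0(student, x_t, t, cond, alphas_cumprod)

    # Teacher: k DDIM steps from t down to 0, without gradients.
    with torch.no_grad():
        x, t_cur = x_t, t
        for i in range(k):
            t_prev = (t * (k - 1 - i)) // k  # evenly spaced sub-timesteps
            x = ddim_step(teacher, x, t_cur, t_prev, cond, alphas_cumprod)
            t_cur = t_prev
        x0_teacher = x

    return F.mse_loss(x0_student, x0_teacher)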

Fig 2: Timestep sampling distribution across the training phases (Warm-up, Phase 1, Phase 2, Phase 3)
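
The phase schedule can be sketched as follows. The iteration thresholds, Gaussian widths, and emphasized timesteps below are purely illustrative assumptions; the idea is that later phases concentrate sampling on the few timesteps the student will actually visit at inference.

import torch

# Hypothetical phase schedule: (iteration threshold, emphasized timesteps).
# None means uniform sampling over all timesteps (warm-up).
PHASES = [
    (2_000, None),                          # Warm-up: uniform over [0, T)
    (10_000, [999, 749]),                   # Phase 1
    (20_000, [999, 749, 499]),              # Phase 2
    (float("inf"), [999, 749, 499, 249]),   # Phase 3
]

def sample_timesteps(batch_size, iteration, T=1000, sigma=20.0):
    centres = next(c for thr, c in PHASES if iteration < thr)
    if centres is None:
        return torch.randint(0, T, (batch_size,))
    # Mixture of Gaussians centred on the emphasized timesteps.
    c = torch.tensor(centres, dtype=torch.float)
    pick = c[torch.randint(0, len(centres), (batch_size,))]
    t = pick + sigma * torch.randn(batch_size)
    return t.round().long().clamp(0, T - 1)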

To further enhance sample quality, and since it has proven very effective in several prior works, we also incorporate an adversarial objective that trains the student model to generate samples indistinguishable from the true data distribution. We therefore train a discriminator to distinguish generated samples from real ones, applying it in the latent space to lower the computational requirements.
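
As a sketch, a standard hinge formulation could implement this objective; discriminator is a hypothetical network operating directly on latents, with real latents from the encoded dataset and fake latents from the student.

import torch
import torch.nn.functional as F

def discriminator_loss(discriminator, real_latents, fake_latents):
    # Hinge loss: push real latents above +1 and student latents below -1.
    real = discriminator(real_latents)
    fake = discriminator(fake_latents.detach())  # no gradient to the student
    return F.relu(1.0 - real).mean() + F.relu(1.0 + fake).mean()

def generator_loss(discriminator, fake_latents):
    # Student side: make its latents look real to the discriminator.
    return -discriminator(fake_latents).mean()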

Finally, we also use a Distribution Matching Distillation (DMD) loss to ensure that the generated samples closely mirror the data distribution learned by the teacher model.
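
Below is a hedged sketch of this term in the spirit of Distribution Matching Distillation (Yin et al., 2024); fake_denoiser is an auxiliary network trained concurrently on the student's generations, pred_x0 is reused from the sketch above, and all names are assumptions rather than the official code.

import torch
import torch.nn.functional as F

def dmd_loss(teacher, fake_denoiser, x_student, cond, alphas_cumprod, T=1000):
    # Re-noise the student sample at a random intermediate timestep.
    t = torch.randint(20, T - 20, (x_student.shape[0],))
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a.sqrt() * x_student + (1 - a).sqrt() * torch.randn_like(x_student)

    with torch.no_grad():
        x0_real = pred_x0(teacher, x_t, t, cond, alphas_cumprod)        # teacher distribution
        x0_fake = pred_x0(fake_denoiser, x_t, t, cond, alphas_cumprod)  # student distribution
        grad = x0_fake - x0_real  # approximate KL gradient direction

    # Surrogate loss whose gradient w.r.t. x_student is proportional to grad.
    return 0.5 * F.mse_loss(x_student, (x_student - grad).detach())

In practice, the fake denoiser is updated alongside the student with a standard denoising loss on the student's outputs, so that it keeps tracking the student's distribution.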

Results

Backbones

We also illustrate the ability of the proposed method to adapt to different backbones, training students on both UNet-based architectures and DiT models. Here are the results for FlashSDXL and FlashPixart:

FlashSDXL results
FlashPixart results

Conditionings

Our method can be used to train models for various use cases such as inpainting, upscaling, face-swapping, and many more. Here are a few examples:

Inpainting results
Upscaling results
Face-swapping results

BibTeX

@misc{chadebec2024flash,
  title={Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation},
  author={Clement Chadebec and Onur Tasar and Eyal Benaroche and Benjamin Aubin},
  year={2024},
  eprint={2406.02347},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}