| date | tags |
|---|---|
| 2023-06-04 | paper, diffusion, consistency, generative |
Yang Song, Prafulla Dhariwal, Mark Chen, Ilya Sutskever
arXiv Preprint
Year: 2023
Diffusion models, compared with single-step generative models such as GANs, VAEs, or normalizing flows, require iterative sampling, which limits real-time applications but enables zero-shot capabilities (inpainting, colorization, etc.). By choosing more or fewer steps in the reverse diffusion process, one can trade off quality against compute. The objective of consistency models is to allow single-step generation without sacrificing the mentioned capabilities.
The main idea of consistency models is to learn a function $f_\theta(x_t, t)$ that maps any point on a probability-flow ODE trajectory back to the trajectory's origin, so that all points on the same trajectory yield the same output (the self-consistency property).
The authors propose two methods to train consistency models: one based on distillation of a pretrained diffusion model (consistency distillation, CD), and another based on training the consistency model in isolation (consistency training, CT).
For distillation, the CD loss is defined as:
$$\mathcal{L}_{CD}^N(\theta, \theta^-; \phi) := \mathbb{E}[\lambda(t_n)\, d(f_\theta(x_{t_{n+1}}, t_{n+1}), f_{\theta^-}(x_{t_n}^\phi, t_n))]$$
where:

- $\mathbb{E}$ is taken over $n\sim\mathcal{U}(1, N-1)$, $x\sim p_\mathrm{data}$, and $x_{t_{n+1}}\sim\mathcal{N}(x, t_{n+1}^2 I)$.
- $\lambda(t_n)$ is a positive weighting function; the authors end up setting it to 1.0, as it seems to work well in all cases.
- $x_{t_n}^\phi$ is the estimate of $x_{t_n}$ obtained with one ODE-solver step of the pretrained diffusion model $\phi$.
- $d(\cdot, \cdot)$ is a distance metric such as $\ell_2$ or $\ell_1$; the authors also use [LPIPS].
- $\theta$ denotes the parameters of the consistency model.
- $\theta^-$ is an exponential moving average of $\theta$, similar to a target network in RL, which has been shown empirically to stabilize training: $\theta^- \leftarrow \mathrm{sg}(\mu\theta^- + (1-\mu)\theta)$, where sg denotes stop-gradient. In this comparison, $\theta$ can be seen as the parameters of the online network.
The training algorithm is described below.
The authors recommend initializing this model with the weights of the pretrained diffusion model.
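A minimal sketch of one CD training step under the definitions above, assuming a toy scalar-parameter consistency model `f` and a toy score function for the pretrained model (both hypothetical stand-ins, not the paper's networks), with $d$ as squared $\ell_2$, $\lambda(t_n) = 1$, and the stop-gradient EMA update for $\theta^-$:

```python
import numpy as np

rng = np.random.default_rng(0)

def f(params, x, t):
    # Toy stand-in for the consistency model f_theta (NOT the paper's network).
    return params * x / (1.0 + t)

def score(phi, x, t):
    # Toy stand-in for the pretrained diffusion model's score estimate.
    return -(x - phi) / t**2

def cd_step(theta, theta_minus, phi, x, timesteps, mu=0.95, lr=1e-3):
    """One consistency-distillation step: loss, gradient, and EMA update."""
    n = rng.integers(0, len(timesteps) - 1)      # n ~ U(1, N-1) (0-indexed here)
    t_n, t_np1 = timesteps[n], timesteps[n + 1]
    z = rng.standard_normal(x.shape)
    x_tnp1 = x + t_np1 * z                       # x_{t_{n+1}} ~ N(x, t_{n+1}^2 I)
    # One Euler step of the PF ODE dx/dt = -t * score(x, t), from t_{n+1} to t_n:
    x_tn_phi = x_tnp1 - (t_n - t_np1) * t_np1 * score(phi, x_tnp1, t_np1)
    target = f(theta_minus, x_tn_phi, t_n)       # constant target (stop-gradient)

    def loss(th):                                # d = squared L2, lambda(t_n) = 1
        return np.mean((f(th, x_tnp1, t_np1) - target) ** 2)

    eps = 1e-5                                   # finite-difference gradient
    grad = (loss(theta + eps) - loss(theta - eps)) / (2 * eps)
    theta = theta - lr * grad                    # online update of theta
    theta_minus = mu * theta_minus + (1 - mu) * theta  # EMA update of theta^-
    return theta, theta_minus, loss(theta)
```

A real implementation would use an autodiff framework and the paper's network architecture; the finite-difference gradient here just keeps the sketch self-contained.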
The paradigm is very similar to the distillation training, but in this case no pretrained model is required. Instead, the score term of the ODE is approximated with a Monte Carlo estimate based on the forward diffusion process, as follows.
$$\mathcal{L}_{CT}^N(\theta, \theta^-) := \mathbb{E}[\lambda(t_n)\, d(f_\theta(x + t_{n+1}\cdot z, t_{n+1}), f_{\theta^-}(x + t_{n}\cdot z, t_{n}))]$$

where $z\sim\mathcal{N}(0, I)$.
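One Monte Carlo sample of this loss can be sketched as below, again with a toy consistency model `f` as a hypothetical stand-in; note that both evaluations perturb the same data point with the same noise $z$, so no pretrained diffusion model appears anywhere:

```python
import numpy as np

rng = np.random.default_rng(0)

def f(params, x, t):
    # Toy stand-in for the consistency model f_theta (NOT the paper's network).
    return params * x / (1.0 + t)

def ct_loss(theta, theta_minus, x, timesteps):
    """One Monte Carlo sample of the CT loss (no pretrained model needed)."""
    n = rng.integers(0, len(timesteps) - 1)      # n ~ U(1, N-1) (0-indexed here)
    t_n, t_np1 = timesteps[n], timesteps[n + 1]
    z = rng.standard_normal(x.shape)             # z ~ N(0, I), shared by both terms
    pred = f(theta, x + t_np1 * z, t_np1)
    target = f(theta_minus, x + t_n * z, t_n)    # treated as constant (stop-grad)
    return np.mean((pred - target) ** 2)         # d = squared L2, lambda(t_n) = 1
```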
Tricks of the trade:

- Progressively increase $N$ during training.
- Progressively decrease $\mu$ during training.
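These adaptive schedules could be sketched as below; the linear ramps and all constants here are illustrative placeholders, not the specific schedule functions $N(\cdot)$ and $\mu(\cdot)$ used in the paper:

```python
# Illustrative schedules only (hypothetical linear ramps and constants; the
# paper defines specific schedule functions that are not reproduced here).

def n_schedule(k, total_steps, n_min=2, n_max=150):
    """Ramp the number of discretization steps N up over training."""
    frac = k / max(total_steps - 1, 1)
    return round(n_min + frac * (n_max - n_min))

def mu_schedule(k, total_steps, mu_max=0.99, mu_min=0.9):
    """Ramp the EMA decay rate mu down over training."""
    frac = k / max(total_steps - 1, 1)
    return mu_max - frac * (mu_max - mu_min)
```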
The training algorithm is described below.
CT is independent of the ODE solver, as it is trained using forward-diffusion targets rather than solver-generated distillation targets.