Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deprecate DDPM step which is unused for now #150

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 3 additions & 52 deletions src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch import Generator, Tensor, arange, device as Device, randn, tensor
from torch import Tensor, arange, device as Device

from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

Expand Down Expand Up @@ -30,54 +30,5 @@ def _generate_timesteps(self) -> Tensor:
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)

def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Generate the next step in the diffusion process.

This method adjusts the input data using added noise and an estimate of the denoised data, based on the current
step in the diffusion process. This adjusted data forms the next step in the diffusion process.

1. It uses current and previous timesteps to calculate the current factor dictating the contribution of original
data and noise to the new step.
2. An estimate of the denoised data (`estimated_denoised_data`) is generated.
3. It calculates coefficients for the estimated denoised data and current data (`original_data_coeff` and
`current_data_coeff`) that balance their contribution to the denoised data for the next step.
4. It calculates the denoised data for the next step (`denoised_x`), which is a combination of the estimated
denoised data and current data, adjusted by their respective coefficients.
5. Noise is then added to `denoised_x`. The magnitude of noise is controlled by a calculated variance based on
the cumulative scaling factor and the current factor.

The output is the new data step for the next stage in the diffusion process.
"""
timestep, previous_timestep = (
self.timesteps[step],
(
self.timesteps[step + 1]
if step < len(self.timesteps) - 1
else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
),
)
current_cumulative_factor, previous_cumulative_scale_factor = (
(self.scale_factors.cumprod(0))[timestep],
(
(self.scale_factors.cumprod(0))[previous_timestep]
if step < len(self.timesteps) - 1
else tensor(1, device=self.device)
),
)
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
1 - current_cumulative_factor
)
current_data_coeff = (
current_factor**0.5 * (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor)
)
denoised_x = original_data_coeff * estimated_denoised_data + current_data_coeff * x
if step < len(self.timesteps) - 1:
variance = (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor) * (1 - current_factor)
denoised_x = denoised_x + (variance.clamp(min=1e-20) ** 0.5) * randn(
x.shape, device=x.device, dtype=x.dtype, generator=generator
)
return denoised_x
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
raise NotImplementedError
14 changes: 12 additions & 2 deletions tests/foundationals/latent_diffusion/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@
from warnings import warn

import pytest
from torch import Tensor, allclose, device as Device, randn
from torch import Tensor, allclose, device as Device, equal, randn

from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver


def test_ddpm_diffusers():
from diffusers import DDPMScheduler # type: ignore

diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
diffusers_scheduler.set_timesteps(1000)
refiners_scheduler = DDPM(num_inference_steps=1000)

assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)


def test_dpm_solver_diffusers():
Expand Down