From b8226b91948aadd8adfb4c7781798349bd01e7e3 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Tue, 10 Oct 2023 17:20:04 -0700 Subject: [PATCH] Reorg OSS Diffusion Components to diffusion_labs folder (#480) Summary: Based on this [proposal](https://docs.google.com/document/d/1GtN2urD8PiRr1X4COzvbVbNE8LWRrcoYkAO4v2aogO8/edit) to reorganize diffusion components and models under a new `diffusion_labs`. This is the first in a stack of diffs. This one only reorganizes what's already been moved to OSS. This is primarily moving files with a couple of changes based on the proposal: - predictors.py is split into a separate file per predictor - adm is moved out of dalle2 to be it's own model adm_unet - Dalle2ImageTransform is moved to dalle2 out of transforms - schedule.py is renamed to discrete_guassian_schedule.py and an abstract DIffusionSchedule class was added - An abstract adapter class was added to be a generic type and enforce the `forward` signature - An abstract sampler class was added to be a generic type and enforce the `forward` and `generator` signature - A new dalle2_model unit test was added Differential Revision: D49790849 Pulled By: pbontrager fbshipit-source-id: 98fe40c2418dc542cced940dc761a9cd602b0398 --- .../diffusion_labs}/__init__.py | 0 .../test_adapter_cfguidance.py} | 4 +- .../dalle2 => diffusion_labs}/test_adm.py | 10 +- .../test_adm_blocks.py | 8 +- .../test_adm_crossattention.py | 4 +- .../test_common_util.py} | 2 +- tests/diffusion_labs/test_dalle2.py | 52 ++++++++ .../test_diffusion_losses.py} | 15 ++- .../test_diffusion_transform.py} | 18 +-- .../test_discrete_schedule.py} | 8 +- .../test_predictors.py | 13 +- .../test_sampler_ddim.py} | 12 +- .../test_sampler_ddpm.py} | 14 +-- .../dalle2/adm => diffusion_labs}/__init__.py | 0 .../models}/__init__.py | 0 .../models/adm_unet/__init__.py | 5 + .../models/adm_unet}/adm.py | 17 +-- .../models/adm_unet}/attention_block.py | 2 - .../models/adm_unet}/res_block.py | 6 +- .../diffusion_labs/models/dalle2/__init__.py | 5 + .../models/dalle2/dalle2_decoder.py | 25 ++-- .../models/dalle2/transforms.py} | 40 +------ .../diffusion_labs/modules/__init__.py | 5 + .../modules/adapters/__init__.py | 5 + .../modules/adapters/adapter.py | 43 +++++++ .../modules/adapters}/cfguidance.py | 5 +- .../diffusion_labs/modules/losses/__init__.py | 5 + .../modules/losses/diffusion_hybrid_loss.py | 63 ++++++++++ .../modules/losses/vlb_loss.py} | 58 +-------- .../diffusion_labs/predictors/__init__.py | 5 + .../predictors/noise_predictor.py | 55 +++++++++ .../diffusion_labs/predictors/predictor.py | 44 +++++++ .../predictors/target_predictor.py | 42 +++++++ .../diffusion_labs/samplers/__init__.py | 5 + .../samplers}/ddim.py | 34 +++--- .../samplers}/ddpm.py | 48 ++++---- .../diffusion_labs/samplers/sampler.py | 78 ++++++++++++ .../diffusion_labs/schedules/__init__.py | 5 + .../schedules/discrete_gaussian_schedule.py} | 26 ++-- .../diffusion_labs/schedules/schedule.py | 48 ++++++++ .../diffusion_labs/transforms/__init__.py | 5 + .../transforms/diffusion_transform.py | 40 +++++++ .../diffusion_labs/utils/__init__.py | 5 + .../utils/common.py} | 5 + .../modules/diffusion/predictors.py | 113 ------------------ 45 files changed, 654 insertions(+), 348 deletions(-) rename {torchmultimodal/models/dalle2 => tests/diffusion_labs}/__init__.py (100%) rename tests/{modules/diffusion/test_cfguidance.py => diffusion_labs/test_adapter_cfguidance.py} (96%) rename tests/{models/dalle2 => diffusion_labs}/test_adm.py (94%) rename tests/{models/dalle2 => diffusion_labs}/test_adm_blocks.py (94%) rename tests/{models/dalle2 => diffusion_labs}/test_adm_crossattention.py (92%) rename tests/{utils/test_diffusion_utils.py => diffusion_labs/test_common_util.py} (87%) create mode 100644 tests/diffusion_labs/test_dalle2.py rename tests/{modules/losses/test_diffusion_loss.py => diffusion_labs/test_diffusion_losses.py} (85%) rename tests/{transforms/test_diffusion_transforms.py => diffusion_labs/test_diffusion_transform.py} (52%) rename tests/{modules/diffusion/test_schedule.py => diffusion_labs/test_discrete_schedule.py} (93%) rename tests/{modules/diffusion => diffusion_labs}/test_predictors.py (77%) rename tests/{modules/diffusion/test_ddim.py => diffusion_labs/test_sampler_ddim.py} (83%) rename tests/{modules/diffusion/test_ddpm.py => diffusion_labs/test_sampler_ddpm.py} (88%) rename torchmultimodal/{models/dalle2/adm => diffusion_labs}/__init__.py (100%) rename torchmultimodal/{modules/diffusion => diffusion_labs/models}/__init__.py (100%) create mode 100644 torchmultimodal/diffusion_labs/models/adm_unet/__init__.py rename torchmultimodal/{models/dalle2/adm => diffusion_labs/models/adm_unet}/adm.py (98%) rename torchmultimodal/{models/dalle2/adm => diffusion_labs/models/adm_unet}/attention_block.py (99%) rename torchmultimodal/{models/dalle2/adm => diffusion_labs/models/adm_unet}/res_block.py (98%) create mode 100644 torchmultimodal/diffusion_labs/models/dalle2/__init__.py rename torchmultimodal/{ => diffusion_labs}/models/dalle2/dalle2_decoder.py (87%) rename torchmultimodal/{transforms/diffusion_transforms.py => diffusion_labs/models/dalle2/transforms.py} (51%) create mode 100644 torchmultimodal/diffusion_labs/modules/__init__.py create mode 100644 torchmultimodal/diffusion_labs/modules/adapters/__init__.py create mode 100644 torchmultimodal/diffusion_labs/modules/adapters/adapter.py rename torchmultimodal/{modules/diffusion => diffusion_labs/modules/adapters}/cfguidance.py (97%) create mode 100644 torchmultimodal/diffusion_labs/modules/losses/__init__.py create mode 100644 torchmultimodal/diffusion_labs/modules/losses/diffusion_hybrid_loss.py rename torchmultimodal/{modules/losses/diffusion.py => diffusion_labs/modules/losses/vlb_loss.py} (64%) create mode 100644 torchmultimodal/diffusion_labs/predictors/__init__.py create mode 100644 torchmultimodal/diffusion_labs/predictors/noise_predictor.py create mode 100644 torchmultimodal/diffusion_labs/predictors/predictor.py create mode 100644 torchmultimodal/diffusion_labs/predictors/target_predictor.py create mode 100644 torchmultimodal/diffusion_labs/samplers/__init__.py rename torchmultimodal/{modules/diffusion => diffusion_labs/samplers}/ddim.py (83%) rename torchmultimodal/{modules/diffusion => diffusion_labs/samplers}/ddpm.py (81%) create mode 100644 torchmultimodal/diffusion_labs/samplers/sampler.py create mode 100644 torchmultimodal/diffusion_labs/schedules/__init__.py rename torchmultimodal/{modules/diffusion/schedules.py => diffusion_labs/schedules/discrete_gaussian_schedule.py} (92%) create mode 100644 torchmultimodal/diffusion_labs/schedules/schedule.py create mode 100644 torchmultimodal/diffusion_labs/transforms/__init__.py create mode 100644 torchmultimodal/diffusion_labs/transforms/diffusion_transform.py create mode 100644 torchmultimodal/diffusion_labs/utils/__init__.py rename torchmultimodal/{utils/diffusion_utils.py => diffusion_labs/utils/common.py} (89%) delete mode 100644 torchmultimodal/modules/diffusion/predictors.py diff --git a/torchmultimodal/models/dalle2/__init__.py b/tests/diffusion_labs/__init__.py similarity index 100% rename from torchmultimodal/models/dalle2/__init__.py rename to tests/diffusion_labs/__init__.py diff --git a/tests/modules/diffusion/test_cfguidance.py b/tests/diffusion_labs/test_adapter_cfguidance.py similarity index 96% rename from tests/modules/diffusion/test_cfguidance.py rename to tests/diffusion_labs/test_adapter_cfguidance.py index dcb402b9..d0dace3d 100644 --- a/tests/modules/diffusion/test_cfguidance.py +++ b/tests/diffusion_labs/test_adapter_cfguidance.py @@ -9,8 +9,8 @@ import torch from tests.test_utils import assert_expected, set_rng_seed from torch import nn -from torchmultimodal.modules.diffusion.cfguidance import CFGuidance -from torchmultimodal.utils.diffusion_utils import DiffusionOutput +from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput @pytest.fixture(autouse=True) diff --git a/tests/models/dalle2/test_adm.py b/tests/diffusion_labs/test_adm.py similarity index 94% rename from tests/models/dalle2/test_adm.py rename to tests/diffusion_labs/test_adm.py index 0429e5c0..4a063c24 100644 --- a/tests/models/dalle2/test_adm.py +++ b/tests/diffusion_labs/test_adm.py @@ -10,9 +10,11 @@ import torch from tests.test_utils import assert_expected, set_rng_seed from torch import nn -from torchmultimodal.models.dalle2.adm.adm import ADM, ADMStack, ADMUNet -from torchmultimodal.models.dalle2.adm.attention_block import ADMAttentionBlock -from torchmultimodal.models.dalle2.adm.res_block import ADMResBlock +from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADM, ADMStack, ADMUNet +from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( + ADMAttentionBlock, +) +from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock @pytest.fixture(autouse=True) @@ -116,7 +118,7 @@ def test_predict_variance_value_incorrect_channel_dim_error( # All expected values come after first testing the ADMUNet has the exact output -# as the corresponding UNet class in d2go, then simply forward passing +# as the corresponding author UNet implementation, then simply forward passing # ADMUNet with params, random seed, and initialization order in this file. class TestADMUNet: @pytest.fixture diff --git a/tests/models/dalle2/test_adm_blocks.py b/tests/diffusion_labs/test_adm_blocks.py similarity index 94% rename from tests/models/dalle2/test_adm_blocks.py rename to tests/diffusion_labs/test_adm_blocks.py index 9bca0680..2b878955 100644 --- a/tests/models/dalle2/test_adm_blocks.py +++ b/tests/diffusion_labs/test_adm_blocks.py @@ -10,8 +10,10 @@ import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.models.dalle2.adm.attention_block import ADMAttentionBlock -from torchmultimodal.models.dalle2.adm.res_block import ( +from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( + ADMAttentionBlock, +) +from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ( adm_res_block, adm_res_downsample_block, adm_res_upsample_block, @@ -47,7 +49,7 @@ def t(params): # All expected values come after first testing the ADMResBlock has the exact output -# as the corresponding residual block class in d2go, then simply forward passing +# as the corresponding residual block class from ADM authors, then simply forward passing # ADMResBlock with params, random seed, and initialization order in this file. class TestADMResBlock: @pytest.fixture diff --git a/tests/models/dalle2/test_adm_crossattention.py b/tests/diffusion_labs/test_adm_crossattention.py similarity index 92% rename from tests/models/dalle2/test_adm_crossattention.py rename to tests/diffusion_labs/test_adm_crossattention.py index 5328b881..fd61c778 100644 --- a/tests/models/dalle2/test_adm_crossattention.py +++ b/tests/diffusion_labs/test_adm_crossattention.py @@ -8,7 +8,7 @@ import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.models.dalle2.adm.attention_block import ( +from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( adm_attention, ADMCrossAttention, ) @@ -44,7 +44,7 @@ def c(params): # All expected values come after first testing that ADMCrossAttention has -# the exact output as the corresponding QKVAttention class in d2go, then simply forward passing +# the exact output as the corresponding QKVAttention class in ADM, then simply forward passing # ADMCrossAttention with params, random seed, and initialization order in this file. class TestADMCrossAttention: @pytest.fixture diff --git a/tests/utils/test_diffusion_utils.py b/tests/diffusion_labs/test_common_util.py similarity index 87% rename from tests/utils/test_diffusion_utils.py rename to tests/diffusion_labs/test_common_util.py index f0466d74..406196df 100644 --- a/tests/utils/test_diffusion_utils.py +++ b/tests/diffusion_labs/test_common_util.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. from PIL import Image -from torchmultimodal.utils.diffusion_utils import cascaded_resize +from torchmultimodal.diffusion_labs.utils.common import cascaded_resize def test_cascaded_resize(): diff --git a/tests/diffusion_labs/test_dalle2.py b/tests/diffusion_labs/test_dalle2.py new file mode 100644 index 00000000..75cba0cc --- /dev/null +++ b/tests/diffusion_labs/test_dalle2.py @@ -0,0 +1,52 @@ +#!/usr/bin/env fbpython +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from PIL import Image +from tests.test_utils import assert_expected, set_rng_seed +from torchmultimodal.diffusion_labs.models.dalle2.dalle2_decoder import dalle2_decoder +from torchmultimodal.diffusion_labs.models.dalle2.transforms import Dalle2ImageTransform + + +def test_dalle2_model(): + set_rng_seed(4) + model = dalle2_decoder( + timesteps=1, + time_embed_dim=1, + cond_embed_dim=1, + clip_embed_dim=1, + clip_embed_name="clip_image", + predict_variance_value=True, + image_channels=1, + depth=32, + num_resize=1, + num_res_per_layer=1, + use_cf_guidance=True, + clip_image_guidance_dropout=0.1, + guidance_strength=7.0, + learn_null_emb=True, + ) + model.eval() + x = torch.randn(1, 1, 4, 4) + c = torch.ones((1, 1)) + with torch.no_grad(): + actual = model(x, conditional_inputs={"clip_image": c}).mean() + expected = torch.as_tensor(0.12768) + assert_expected(actual, expected, rtol=0, atol=1e-4) + + +def test_dalle2_image_transform(): + img_size = 5 + transform = Dalle2ImageTransform(image_size=img_size, image_min=-1, image_max=1) + image = Image.new("RGB", size=(20, 20), color=(128, 0, 0)) + actual = transform(image).sum() + normalized128 = 128 / 255 * 2 - 1 + normalized0 = -1 + expected = torch.tensor( + normalized128 * img_size**2 + 2 * normalized0 * img_size**2 + ) + assert_expected(actual, expected, rtol=0, atol=1e-4) diff --git a/tests/modules/losses/test_diffusion_loss.py b/tests/diffusion_labs/test_diffusion_losses.py similarity index 85% rename from tests/modules/losses/test_diffusion_loss.py rename to tests/diffusion_labs/test_diffusion_losses.py index 980968b7..a0e1e12f 100644 --- a/tests/modules/losses/test_diffusion_loss.py +++ b/tests/diffusion_labs/test_diffusion_losses.py @@ -9,12 +9,15 @@ import torch import torch.nn as nn from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.modules.diffusion.schedules import ( - DiffusionSchedule, +from torchmultimodal.diffusion_labs.modules.losses.diffusion_hybrid_loss import ( + DiffusionHybridLoss, +) +from torchmultimodal.diffusion_labs.modules.losses.vlb_loss import VLBLoss +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, linear_beta_schedule, ) -from torchmultimodal.modules.losses.diffusion import DiffusionHybridLoss, VLBLoss -from torchmultimodal.utils.diffusion_utils import DiffusionOutput +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput @pytest.fixture(autouse=True) @@ -24,7 +27,7 @@ def set_seed(): @pytest.fixture def schedule(): - return DiffusionSchedule(linear_beta_schedule(1000)) + return DiscreteGaussianSchedule(linear_beta_schedule(1000)) @pytest.fixture @@ -44,7 +47,7 @@ def target(): # All expected values come after first testing the HybridLoss has the exact output -# as the corresponding p_losses in D2Go Guassian Diffusion +# as the corresponding p_losses in Guassian Diffusion class TestDiffusionHybridLoss: @pytest.fixture def loss(self, schedule): diff --git a/tests/transforms/test_diffusion_transforms.py b/tests/diffusion_labs/test_diffusion_transform.py similarity index 52% rename from tests/transforms/test_diffusion_transforms.py rename to tests/diffusion_labs/test_diffusion_transform.py index 43a42a2f..55bbbb66 100644 --- a/tests/transforms/test_diffusion_transforms.py +++ b/tests/diffusion_labs/test_diffusion_transform.py @@ -6,10 +6,7 @@ # LICENSE file in the root directory of this source tree. import torch -from PIL import Image -from tests.test_utils import assert_expected -from torchmultimodal.transforms.diffusion_transforms import ( - Dalle2ImageTransform, +from torchmultimodal.diffusion_labs.transforms.diffusion_transform import ( RandomDiffusionSteps, ) @@ -30,16 +27,3 @@ def test_random_diffusion_steps(): actual = len(transform(torch.ones(1))) expected = 4 assert actual == expected, "Transform not returning correct keys" - - -def test_dalle_image_transform(): - img_size = 5 - transform = Dalle2ImageTransform(image_size=img_size, image_min=-1, image_max=1) - image = Image.new("RGB", size=(20, 20), color=(128, 0, 0)) - actual = transform(image).sum() - normalized128 = 128 / 255 * 2 - 1 - normalized0 = -1 - expected = torch.tensor( - normalized128 * img_size**2 + 2 * normalized0 * img_size**2 - ) - assert_expected(actual, expected, rtol=0, atol=1e-4) diff --git a/tests/modules/diffusion/test_schedule.py b/tests/diffusion_labs/test_discrete_schedule.py similarity index 93% rename from tests/modules/diffusion/test_schedule.py rename to tests/diffusion_labs/test_discrete_schedule.py index 6f5f49b0..3a4312d3 100644 --- a/tests/modules/diffusion/test_schedule.py +++ b/tests/diffusion_labs/test_discrete_schedule.py @@ -8,9 +8,9 @@ import pytest import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.modules.diffusion.schedules import ( +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( cosine_beta_schedule, - DiffusionSchedule, + DiscreteGaussianSchedule, linear_beta_schedule, quadratic_beta_schedule, sigmoid_beta_schedule, @@ -23,11 +23,11 @@ def set_seed(): # All expected values come after first testing the Schedule has the exact output -# as the corresponding q methods from GaussianDiffusion in D2Go +# as the corresponding q methods from GaussianDiffusion class TestDiffusionSchedule: @pytest.fixture def module(self): - schedule = DiffusionSchedule(linear_beta_schedule(1000)) + schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000)) return schedule @pytest.fixture diff --git a/tests/modules/diffusion/test_predictors.py b/tests/diffusion_labs/test_predictors.py similarity index 77% rename from tests/modules/diffusion/test_predictors.py rename to tests/diffusion_labs/test_predictors.py index 4ffc5129..001078f0 100644 --- a/tests/modules/diffusion/test_predictors.py +++ b/tests/diffusion_labs/test_predictors.py @@ -8,9 +8,10 @@ import pytest import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.modules.diffusion.predictors import NoisePredictor, TargetPredictor -from torchmultimodal.modules.diffusion.schedules import ( - DiffusionSchedule, +from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor +from torchmultimodal.diffusion_labs.predictors.target_predictor import TargetPredictor +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, linear_beta_schedule, ) @@ -30,11 +31,11 @@ def input(): # All expected values come after first testing the Schedule has the exact output -# as the corresponding q methods from GaussianDiffusion in D2Go +# as the corresponding q methods from GaussianDiffusion class TestNoisePredictor: @pytest.fixture def module(self): - schedule = DiffusionSchedule(linear_beta_schedule(1000)) + schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000)) predictor = NoisePredictor(schedule, None) return predictor @@ -52,7 +53,7 @@ def test_predict_noise(self, module, input): class TestTargetPredictor: @pytest.fixture def module(self): - schedule = DiffusionSchedule(linear_beta_schedule(1000)) + schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000)) predictor = TargetPredictor(schedule, None) return predictor diff --git a/tests/modules/diffusion/test_ddim.py b/tests/diffusion_labs/test_sampler_ddim.py similarity index 83% rename from tests/modules/diffusion/test_ddim.py rename to tests/diffusion_labs/test_sampler_ddim.py index 097298a5..c0bda474 100644 --- a/tests/modules/diffusion/test_ddim.py +++ b/tests/diffusion_labs/test_sampler_ddim.py @@ -9,13 +9,13 @@ import torch import torch.nn as nn from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.modules.diffusion.ddim import DDIModule -from torchmultimodal.modules.diffusion.predictors import NoisePredictor -from torchmultimodal.modules.diffusion.schedules import ( - DiffusionSchedule, +from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor +from torchmultimodal.diffusion_labs.samplers.ddim import DDIModule +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, linear_beta_schedule, ) -from torchmultimodal.utils.diffusion_utils import DiffusionOutput +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput class DummyUNet(nn.Module): @@ -38,7 +38,7 @@ class TestDDIModule: @pytest.fixture def module(self): model = DummyUNet(True) - schedule = DiffusionSchedule(linear_beta_schedule(1000)) + schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000)) predictor = NoisePredictor(schedule) model = DDIModule(model, schedule, predictor) return model diff --git a/tests/modules/diffusion/test_ddpm.py b/tests/diffusion_labs/test_sampler_ddpm.py similarity index 88% rename from tests/modules/diffusion/test_ddpm.py rename to tests/diffusion_labs/test_sampler_ddpm.py index c94e2c88..f4a2098a 100644 --- a/tests/modules/diffusion/test_ddpm.py +++ b/tests/diffusion_labs/test_sampler_ddpm.py @@ -10,13 +10,13 @@ import torch.nn as nn import torch.nn.functional as F from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.modules.diffusion.ddpm import DDPModule -from torchmultimodal.modules.diffusion.predictors import NoisePredictor -from torchmultimodal.modules.diffusion.schedules import ( - DiffusionSchedule, +from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor +from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, linear_beta_schedule, ) -from torchmultimodal.utils.diffusion_utils import DiffusionOutput +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput class DummyUNet(nn.Module): @@ -36,13 +36,13 @@ def forward(self, x, t, c): # All expected values come after first testing the Schedule has the exact output -# as the corresponding p methods from GaussianDiffusion in D2Go +# as the corresponding p methods from GaussianDiffusion class TestDDPModule: @pytest.fixture def module(self): set_rng_seed(4) model = DummyUNet(True) - schedule = DiffusionSchedule(linear_beta_schedule(1000)) + schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000)) predictor = NoisePredictor(schedule) eval_steps = torch.arange(0, 1000, 50) model = DDPModule(model, schedule, predictor, eval_steps) diff --git a/torchmultimodal/models/dalle2/adm/__init__.py b/torchmultimodal/diffusion_labs/__init__.py similarity index 100% rename from torchmultimodal/models/dalle2/adm/__init__.py rename to torchmultimodal/diffusion_labs/__init__.py diff --git a/torchmultimodal/modules/diffusion/__init__.py b/torchmultimodal/diffusion_labs/models/__init__.py similarity index 100% rename from torchmultimodal/modules/diffusion/__init__.py rename to torchmultimodal/diffusion_labs/models/__init__.py diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/__init__.py b/torchmultimodal/diffusion_labs/models/adm_unet/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/adm_unet/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/models/dalle2/adm/adm.py b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py similarity index 98% rename from torchmultimodal/models/dalle2/adm/adm.py rename to torchmultimodal/diffusion_labs/models/adm_unet/adm.py index dbcfb6aa..b681c10e 100644 --- a/torchmultimodal/models/dalle2/adm/adm.py +++ b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py @@ -8,21 +8,21 @@ import torch from torch import nn, Tensor -from torchmultimodal.models.dalle2.adm.attention_block import ( +from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( adm_attn_block, ADMAttentionBlock, ) -from torchmultimodal.models.dalle2.adm.res_block import ( +from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ( adm_res_block, adm_res_downsample_block, adm_res_upsample_block, ADMResBlock, ) +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm from torchmultimodal.modules.layers.position_embedding import ( SinusoidalPositionEmbeddings, ) -from torchmultimodal.utils.diffusion_utils import DiffusionOutput class ADM(nn.Module): @@ -35,7 +35,8 @@ class ADM(nn.Module): The UNet can be swapped with any module. The UNet used in the paper is constructed in the ADMUNet class. - Code ref: https://fburl.com/code/b6lk39ym + Code ref: + https://github.com/lucidrains/DALLE2-pytorch/blob/c6c3882dc165914413ca97176b3a0103af1d7048/dalle2_pytorch/dalle2_pytorch.py#L1856 Attributes: unet (nn.Module): model that will process three inputs: the input noised image, embedded timestep, and @@ -383,8 +384,6 @@ def forward( class ADMStack(nn.ModuleList): """A container that acts as a ModuleList of ADM blocks and handles passing conditional inputs correctly to the children ADMResBlocks and ADMAttentionBlocks. - - Code ref: https://fburl.com/code/x9nuqaov """ def forward( @@ -448,7 +447,7 @@ def adm_stack_res_down(num_channels: int, dim_cond: int) -> nn.ModuleList: ) -def dalle2_adm( +def adm_unet( *, # ADM args time_embed_dim: int = 512, @@ -483,9 +482,7 @@ def dalle2_adm( num_resize (int): number of times resolution will be scaled num_res_per_layer (int): number of residual blocks per resolution """ - # # Construct UNet - # in_channels = image_channels # If predicting variance, double the channel dim of UNet output and use those values as variance @@ -506,9 +503,7 @@ def dalle2_adm( out_channels=out_channels, ) - # # Construct ADM (UNet + timestep/conditional projections) - # time_embed = nn.Sequential( SinusoidalPositionEmbeddings(embed_dim=time_embed_dim), nn.Linear(time_embed_dim, cond_embed_dim), diff --git a/torchmultimodal/models/dalle2/adm/attention_block.py b/torchmultimodal/diffusion_labs/models/adm_unet/attention_block.py similarity index 99% rename from torchmultimodal/models/dalle2/adm/attention_block.py rename to torchmultimodal/diffusion_labs/models/adm_unet/attention_block.py index 0c2c2fa9..1a7ea4b9 100644 --- a/torchmultimodal/models/dalle2/adm/attention_block.py +++ b/torchmultimodal/diffusion_labs/models/adm_unet/attention_block.py @@ -22,8 +22,6 @@ class ADMAttentionBlock(nn.Module): Code ref: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/unet.py#L259 - Defaults taken from https://fburl.com/code/tiryi2f9. - Attributes: num_channels (int): channel dim expected in input, determines embedding dim of q, k, v in attention module. Needs to be divisible by norm_groups. diff --git a/torchmultimodal/models/dalle2/adm/res_block.py b/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py similarity index 98% rename from torchmultimodal/models/dalle2/adm/res_block.py rename to torchmultimodal/diffusion_labs/models/adm_unet/res_block.py index dc611dc2..f97426ac 100644 --- a/torchmultimodal/models/dalle2/adm/res_block.py +++ b/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py @@ -21,8 +21,6 @@ class ADMResBlock(nn.Module): Code ref: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/unet.py#L143 - Defaults taken from https://fburl.com/code/tiryi2f9. - Attributes: in_channels (int): num channels expected in input. Needs to be divisible by norm_groups. out_channels (int): num channels desired in output. Needs to be divisible by norm_groups. @@ -43,8 +41,8 @@ class ADMResBlock(nn.Module): norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32. Args: - x (Tensor): input Tensor of shape [B x C x H x W] - conditional_embedding (Tensor): conditioning embedding vector of shape [B x C]. + x (Tensor): input Tensor of shape [b, c, h, w] + conditional_embedding (Tensor): conditioning embedding vector of shape [b, c]. Raises: TypeError: When skip_conv is not defined and in_channels != out_channels. diff --git a/torchmultimodal/diffusion_labs/models/dalle2/__init__.py b/torchmultimodal/diffusion_labs/models/dalle2/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/dalle2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/models/dalle2/dalle2_decoder.py b/torchmultimodal/diffusion_labs/models/dalle2/dalle2_decoder.py similarity index 87% rename from torchmultimodal/models/dalle2/dalle2_decoder.py rename to torchmultimodal/diffusion_labs/models/dalle2/dalle2_decoder.py index 899b47da..6f4bc073 100644 --- a/torchmultimodal/models/dalle2/dalle2_decoder.py +++ b/torchmultimodal/diffusion_labs/models/dalle2/dalle2_decoder.py @@ -5,13 +5,14 @@ # LICENSE file in the root directory of this source tree. import torch -from torchmultimodal.models.dalle2.adm.adm import dalle2_adm -from torchmultimodal.modules.diffusion.cfguidance import CFGuidance -from torchmultimodal.modules.diffusion.ddpm import DDPModule -from torchmultimodal.modules.diffusion.predictors import NoisePredictor -from torchmultimodal.modules.diffusion.schedules import ( +from torchmultimodal.diffusion_labs.models.adm_unet.adm import adm_unet +from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance +from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor +from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule +from torchmultimodal.diffusion_labs.samplers.sampler import Sampler +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( cosine_beta_schedule, - DiffusionSchedule, + DiscreteGaussianSchedule, ) @@ -34,7 +35,7 @@ def dalle2_decoder( clip_image_guidance_dropout: float = 0.1, guidance_strength: float = 7.0, learn_null_emb: bool = True, -) -> DDPModule: +) -> Sampler: """Constructs primary DALLE-2 diffusion decoder without upsampling. Consists of an ADM UNet diffusion model conditioned on CLIP image embeddings. Uses DDPM to generate @@ -71,10 +72,8 @@ def dalle2_decoder( learn_null_emb (bool): If False, then unconditional embeddings are set to zero and are not trainable If True, then unconditional embeddings are set to random and are trainable. Defaults to True. """ - # # Construct UNet - # - diffusion_model = dalle2_adm( + diffusion_model = adm_unet( time_embed_dim=time_embed_dim, cond_embed_dim=cond_embed_dim, clip_embed_dim=clip_embed_dim, @@ -86,9 +85,7 @@ def dalle2_decoder( num_res_per_layer=num_res_per_layer, ) - # # Construct CFGuidance wrapper around ADM model - # if use_cf_guidance: diffusion_model = CFGuidance( model=diffusion_model, @@ -98,11 +95,9 @@ def dalle2_decoder( learn_null_emb=learn_null_emb, ) - # # Construct DDPM decoder - # eval_steps = torch.linspace(0, timesteps - 1, timesteps // 4, dtype=torch.int) - schedule = DiffusionSchedule(cosine_beta_schedule(timesteps)) + schedule = DiscreteGaussianSchedule(cosine_beta_schedule(timesteps)) predictor = NoisePredictor(schedule, lambda x: x.clamp(-1, 1)) model = DDPModule( model=diffusion_model, diff --git a/torchmultimodal/transforms/diffusion_transforms.py b/torchmultimodal/diffusion_labs/models/dalle2/transforms.py similarity index 51% rename from torchmultimodal/transforms/diffusion_transforms.py rename to torchmultimodal/diffusion_labs/models/dalle2/transforms.py index adac5025..1f731b4c 100644 --- a/torchmultimodal/transforms/diffusion_transforms.py +++ b/torchmultimodal/diffusion_labs/models/dalle2/transforms.py @@ -5,20 +5,14 @@ # LICENSE file in the root directory of this source tree. from functools import partial -from typing import List, Tuple, Union +from typing import List, Union import torch import torchvision.transforms as tv from PIL.Image import Image from torch import nn, Tensor -from torchmultimodal.modules.diffusion.schedules import DiffusionSchedule -from torchmultimodal.utils.diffusion_utils import cascaded_resize - - -def normalize(x: Tensor, image_min: int, image_max: int) -> Tensor: - # Normalize image values between min and max - return (image_max - image_min) * x + image_min +from torchmultimodal.diffusion_labs.utils.common import cascaded_resize, normalize class Dalle2ImageTransform(nn.Module): @@ -57,33 +51,3 @@ def forward(self, image: Union[List[Image], Image]) -> Tensor: # pyre-ignore image_result = torch.stack([self.image_transform(x) for x in image]) return image_result - - -class RandomDiffusionSteps(nn.Module): - """Data Transform to randomly sample noised data from the diffusion schedule. - During diffusion training, random diffusion steps are sampled per model update. - This transform samples steps and returns the steps (t), seed noise, and transformed - data at time t (xt). - - Attributes: - schedule (DiffusionSchedule): defines diffusion of noise through time - batched (bool): if True, transform expects a batched input - - Args: - x (Tensor): data representing x0, artifact being learned. The 0 represents zero diffusion steps. - """ - - def __init__(self, schedule: DiffusionSchedule, batched: bool = True): - super().__init__() - self.schedule = schedule - self.batched = batched - - def __call__(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - if not self.batched: - t = self.schedule.sample_steps(x.unsqueeze(0)) - t = t.squeeze(0) - else: - t = self.schedule.sample_steps(x) - noise = self.schedule.sample_noise(x) - xt = self.schedule.q_sample(x, noise, t) - return x, xt, noise, t diff --git a/torchmultimodal/diffusion_labs/modules/__init__.py b/torchmultimodal/diffusion_labs/modules/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/diffusion_labs/modules/adapters/__init__.py b/torchmultimodal/diffusion_labs/modules/adapters/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/modules/adapters/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/diffusion_labs/modules/adapters/adapter.py b/torchmultimodal/diffusion_labs/modules/adapters/adapter.py new file mode 100644 index 00000000..3635cd30 --- /dev/null +++ b/torchmultimodal/diffusion_labs/modules/adapters/adapter.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod +from typing import Dict, Optional, Protocol, runtime_checkable + +from torch import Tensor + +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput + + +@runtime_checkable +class Adapter(Protocol): + """Adapter modules act as wrappers on the underlying denoising model. These are flexible + and allow the base model to be augmented to perform common diffusion tasks. Since Adapters + share the same signature as the underlying model and the Sampler class, multiple adapters + can be stacked together. + + Example: + denoising_model = Unet(...) + augmented_model = Adapter2(Adapter1(denoising_model)) + model = DDIM(augmented_model, ...) + + """ + + @abstractmethod + def forward( + self, + x: Tensor, + timestep: Tensor, + conditional_inputs: Optional[Dict[str, Tensor]] = None, + ) -> DiffusionOutput: + """Model forward pass + + Args: + x (Tensor): input Tensor of shape [b, in_channels, ...] + timestep (Tensor): diffusion step + conditional_inputs (Dict[str, Tensor]): conditional embedding as a dictionary. + Conditional embeddings must have at least 2 dimensions. + """ diff --git a/torchmultimodal/modules/diffusion/cfguidance.py b/torchmultimodal/diffusion_labs/modules/adapters/cfguidance.py similarity index 97% rename from torchmultimodal/modules/diffusion/cfguidance.py rename to torchmultimodal/diffusion_labs/modules/adapters/cfguidance.py index 837d4d0c..f1f46401 100644 --- a/torchmultimodal/modules/diffusion/cfguidance.py +++ b/torchmultimodal/diffusion_labs/modules/adapters/cfguidance.py @@ -9,10 +9,11 @@ import torch from torch import nn, Tensor -from torchmultimodal.utils.diffusion_utils import DiffusionOutput +from torchmultimodal.diffusion_labs.modules.adapters.adapter import Adapter +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput -class CFGuidance(nn.Module): +class CFGuidance(nn.Module, Adapter): """ Classifier free guidance gives diffusion models the ability to sample from a conditional distribution, while maintaining a healthy ratio between exploitation (i.e. correlation diff --git a/torchmultimodal/diffusion_labs/modules/losses/__init__.py b/torchmultimodal/diffusion_labs/modules/losses/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/modules/losses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/diffusion_labs/modules/losses/diffusion_hybrid_loss.py b/torchmultimodal/diffusion_labs/modules/losses/diffusion_hybrid_loss.py new file mode 100644 index 00000000..56e28ec3 --- /dev/null +++ b/torchmultimodal/diffusion_labs/modules/losses/diffusion_hybrid_loss.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn, Tensor +from torchmultimodal.diffusion_labs.modules.losses.vlb_loss import VLBLoss +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, +) + + +class DiffusionHybridLoss(nn.Module): + """ + Combines both simple loss (typically MSE) and VLB loss weighted by lambda, as described in Eq. 16 of + "Improved Denoising Diffusion Probabilistic Models" (https://arxiv.org/abs/2102.09672). + VLB loss is only used to train the model learned variance. + + Attributes: + schedule (DiffusionSchedule): defines diffusion of noise through time + simple_loss (nn.Module): loss function computed on prediction of diffusion model and desired target + (typically noise). Default is nn.MSELoss. + lmbda (float): lambda weight for vlb loss. Default is 0.001. + + Args: + input (Tensor): prediction of diffusion model of shape [b, c, ...] + target (Tensor): desired target of shape [b, c, ...] + mean (Tensor): predicted mean of posterior/xt of shape [b, c, ...] + log_variance (Tensor): predicted log variance of posterior/xt of shape [b, c, ...] + x0 (Tensor): data sample of shape [b, c,...] + xt (Tensor): noised data sample from diffusion process of shape [b, c, ...] + t (Tensor): diffusion timesteps of shape [b, ] + + """ + + def __init__( + self, + schedule: DiscreteGaussianSchedule, + simple_loss: nn.Module = nn.MSELoss(), + lmbda: float = 0.001, + ): + super().__init__() + self.simple_loss = simple_loss + self.vlb_loss = VLBLoss(schedule) + self.lmbda = lmbda + + def forward( + self, + input: Tensor, + target: Tensor, + mean: Tensor, + log_variance: Tensor, + x0: Tensor, + xt: Tensor, + t: Tensor, + ) -> Tensor: + # Detach mean as stop gradient for vlb loss + # Weight the vlb loss smaller, for stability, when training in a hybrid setting using + # another criterion to train the predictor as in the paper (recommended 0.001) + return self.simple_loss(input, target) + self.lmbda * self.vlb_loss( + mean.detach(), log_variance, x0, xt, t + ) diff --git a/torchmultimodal/modules/losses/diffusion.py b/torchmultimodal/diffusion_labs/modules/losses/vlb_loss.py similarity index 64% rename from torchmultimodal/modules/losses/diffusion.py rename to torchmultimodal/diffusion_labs/modules/losses/vlb_loss.py index 6ea93fb6..c411d45f 100644 --- a/torchmultimodal/modules/losses/diffusion.py +++ b/torchmultimodal/diffusion_labs/modules/losses/vlb_loss.py @@ -8,59 +8,9 @@ import torch from torch import nn, Tensor -from torchmultimodal.modules.diffusion.schedules import DiffusionSchedule - - -class DiffusionHybridLoss(nn.Module): - """ - Combines both simple loss (typically MSE) and VLB loss weighted by lambda, as described in Eq. 16 of - "Improved Denoising Diffusion Probabilistic Models" (https://arxiv.org/abs/2102.09672). - VLB loss is only used to train the model learned variance. - - Attributes: - schedule (DiffusionSchedule): defines diffusion of noise through time - simple_loss (nn.Module): loss function computed on prediction of diffusion model and desired target - (typically noise). Default is nn.MSELoss. - lmbda (float): lambda weight for vlb loss. Default is 0.001. - - Args: - input (Tensor): prediction of diffusion model of shape [b, c, ...] - target (Tensor): desired target of shape [b, c, ...] - mean (Tensor): predicted mean of posterior/xt of shape [b, c, ...] - log_variance (Tensor): predicted log variance of posterior/xt of shape [b, c, ...] - x0 (Tensor): data sample of shape [b, c,...] - xt (Tensor): noised data sample from diffusion process of shape [b, c, ...] - t (Tensor): diffusion timesteps of shape [b, ] - - """ - - def __init__( - self, - schedule: DiffusionSchedule, - simple_loss: nn.Module = nn.MSELoss(), - lmbda: float = 0.001, - ): - super().__init__() - self.simple_loss = simple_loss - self.vlb_loss = VLBLoss(schedule) - self.lmbda = lmbda - - def forward( - self, - input: Tensor, - target: Tensor, - mean: Tensor, - log_variance: Tensor, - x0: Tensor, - xt: Tensor, - t: Tensor, - ) -> Tensor: - # Detach mean as stop gradient for vlb loss - # Weight the vlb loss smaller, for stability, when training in a hybrid setting using - # another criterion to train the predictor as in the paper (recommended 0.001) - return self.simple_loss(input, target) + self.lmbda * self.vlb_loss( - mean.detach(), log_variance, x0, xt, t - ) +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, +) class VLBLoss(nn.Module): @@ -88,7 +38,7 @@ class VLBLoss(nn.Module): """ - def __init__(self, schedule: DiffusionSchedule): + def __init__(self, schedule: DiscreteGaussianSchedule): super().__init__() self.schedule = schedule diff --git a/torchmultimodal/diffusion_labs/predictors/__init__.py b/torchmultimodal/diffusion_labs/predictors/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/predictors/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/diffusion_labs/predictors/noise_predictor.py b/torchmultimodal/diffusion_labs/predictors/noise_predictor.py new file mode 100644 index 00000000..1be3494f --- /dev/null +++ b/torchmultimodal/diffusion_labs/predictors/noise_predictor.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor +from torchmultimodal.diffusion_labs.predictors.predictor import Predictor +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, +) + + +class NoisePredictor(Predictor): + """Given a model that's trained to predict diffusion noise and corresponding schedule, + this class computes the predicted noise and x0 at step t. + + Attributes: + schedule (DiffusionSchedule): defines diffusion of noise through time + clamp_func (Callable): function to clamp prediction values + """ + + def __init__( + self, schedule: DiscreteGaussianSchedule, clamp_func: Optional[Callable] = None + ): + self.clamp_func = clamp_func + schedule.add_property("sqrt_recip_alphas_cumprod", _sqrt_recip_alphas_cumprod) + schedule.add_property( + "sqrt_recipm1_alphas_cumprod", _sqrt_recipm1_alphas_cumprod + ) + self.schedule = schedule + + def predict_x0(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: + shape, dtype = xt.shape, xt.dtype + x_coef = self.schedule("sqrt_recip_alphas_cumprod", t, shape) + e_coef = self.schedule("sqrt_recipm1_alphas_cumprod", t, shape) + x0 = x_coef * xt - e_coef * prediction + if self.clamp_func is not None: + x0 = self.clamp_func(x0) + return x0.to(dtype) + + def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: + return prediction + + +def _sqrt_recip_alphas_cumprod(schedule: DiscreteGaussianSchedule) -> Tensor: + # pyre-ignore + return (1.0 / schedule.alphas_cumprod).sqrt() + + +def _sqrt_recipm1_alphas_cumprod(schedule: DiscreteGaussianSchedule) -> Tensor: + # pyre-ignore + return (1.0 / schedule.alphas_cumprod - 1).sqrt() diff --git a/torchmultimodal/diffusion_labs/predictors/predictor.py b/torchmultimodal/diffusion_labs/predictors/predictor.py new file mode 100644 index 00000000..6f95d6d2 --- /dev/null +++ b/torchmultimodal/diffusion_labs/predictors/predictor.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod +from typing import Callable, Optional, Protocol, runtime_checkable + +from torch import Tensor +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, +) + + +@runtime_checkable +class Predictor(Protocol): + """Helper class to help predict various parts of the diffusion process. Different + implementations of each method are needed depending on what the model itself was + trained to predict. + """ + + schedule: DiscreteGaussianSchedule + clamp_func: Optional[Callable] + + @abstractmethod + def predict_x0(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: + """Predict x0 + + Args: + prediction (Tensor): model prediction + xt (Tensor): noised data to step t + t (Tensor): int diffusion step for xt + """ + + @abstractmethod + def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: + """Predict noise + + Args: + prediction (Tensor): model prediction + xt (Tensor): noised data to step t + t (Tensor): int diffusion step for xt + """ diff --git a/torchmultimodal/diffusion_labs/predictors/target_predictor.py b/torchmultimodal/diffusion_labs/predictors/target_predictor.py new file mode 100644 index 00000000..40c96c8d --- /dev/null +++ b/torchmultimodal/diffusion_labs/predictors/target_predictor.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor +from torchmultimodal.diffusion_labs.predictors.predictor import Predictor +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, +) + + +class TargetPredictor(Predictor): + """Given a model that's trained to predict x0 and corresponding schedule, + this class computes the predicted noise and x0 at step t. + + Attributes: + schedule (DiffusionSchedule): defines diffusion of noise through time + clamp_func (Callable): function to clamp prediction values + """ + + def __init__( + self, schedule: DiscreteGaussianSchedule, clamp_func: Optional[Callable] = None + ): + self.clamp_func = clamp_func + self.schedule = schedule + + def predict_x0(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: + if self.clamp_func is not None: + prediction = self.clamp_func(prediction) + return prediction + + def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: + shape, dtype = xt.shape, xt.dtype + x_coef = self.schedule("sqrt_recip_alphas_cumprod", t, shape) + e_coef = self.schedule("sqrt_recip_alphas_cumprod_minus_one", t, shape) + x0 = self.predict_x0(prediction, xt, t) + e = (x_coef * xt - x0) / e_coef + return e.to(dtype) diff --git a/torchmultimodal/diffusion_labs/samplers/__init__.py b/torchmultimodal/diffusion_labs/samplers/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/samplers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/modules/diffusion/ddim.py b/torchmultimodal/diffusion_labs/samplers/ddim.py similarity index 83% rename from torchmultimodal/modules/diffusion/ddim.py rename to torchmultimodal/diffusion_labs/samplers/ddim.py index 94e024d1..ff1d2da7 100644 --- a/torchmultimodal/modules/diffusion/ddim.py +++ b/torchmultimodal/diffusion_labs/samplers/ddim.py @@ -8,13 +8,15 @@ import torch from torch import nn, Tensor -from torchmultimodal.modules.diffusion.predictors import Predictor +from torchmultimodal.diffusion_labs.predictors.predictor import Predictor +from torchmultimodal.diffusion_labs.samplers.sampler import Sampler +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, +) +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput -from torchmultimodal.modules.diffusion.schedules import DiffusionSchedule -from torchmultimodal.utils.diffusion_utils import DiffusionOutput - -class DDIModule(nn.Module): +class DDIModule(nn.Module, Sampler): """ DDIModule implements "Denoising Diffusion Implicit Models" presented by Song et. al (https://arxiv.org/abs/2010.02502). @@ -38,13 +40,13 @@ class DDIModule(nn.Module): Attributes: model (nn.Module): - schedule (DiffusionSchedule): defines noise diffusion throughout time + schedule (DiscreteGaussianSchedule): defines noise diffusion throughout time predictor (Predictor): used to help predict x0 eval_steps (Tensor): a subset of steps to sample at inference time eta (float): scaling factor used in Equation 12 of Song et. al (https://arxiv.org/abs/2010.02502) Args: - xt (Tensor): corrupted data at time t (when t = schedule.steps, xt is fully noise) + x (Tensor): corrupted data at time t (when t = schedule.steps, x is fully noise) of shape [b, c, ...] timestep (Tensor): diffusion step conditional_inputs (Dict): dictionary of context embeddings @@ -54,7 +56,7 @@ class DDIModule(nn.Module): def __init__( self, model: nn.Module, - schedule: DiffusionSchedule, + schedule: DiscreteGaussianSchedule, predictor: Predictor, eval_steps: Optional[Tensor] = None, progress_bar: bool = True, @@ -131,31 +133,31 @@ def remove_noise( def generator( self, - xt: Tensor, + x: Tensor, c: Optional[Dict[str, Tensor]] = None, ) -> Generator[Tensor, None, None]: """Generate xt for each t in self.eval_steps""" steps = self.eval_steps.flip(0) for step, next_step in zip(steps[:-1], steps[1:]): # Convert steps to batched tensors - t = step * torch.ones(xt.size(0), device=xt.device, dtype=torch.long) - t1 = next_step * torch.ones(xt.size(0), device=xt.device, dtype=torch.long) + t = step * torch.ones(x.size(0), device=x.device, dtype=torch.long) + t1 = next_step * torch.ones(x.size(0), device=x.device, dtype=torch.long) # Remove noise between step t and t+1 - xt = self.remove_noise(xt, c, t, t1) - yield xt + x = self.remove_noise(x, c, t, t1) + yield x def forward( self, - xt: Tensor, + x: Tensor, timestep: Optional[Tensor] = None, conditional_inputs: Optional[Dict[str, Tensor]] = None, ) -> Union[DiffusionOutput, Tensor]: if self.training: if timestep is None: raise ValueError("Must provide a timestep value during training") - return self.model(xt, timestep, conditional_inputs) + return self.model(x, timestep, conditional_inputs) else: - gen: Iterable = self.generator(xt, conditional_inputs) + gen: Iterable = self.generator(x, conditional_inputs) if self.progress_bar: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm diff --git a/torchmultimodal/modules/diffusion/ddpm.py b/torchmultimodal/diffusion_labs/samplers/ddpm.py similarity index 81% rename from torchmultimodal/modules/diffusion/ddpm.py rename to torchmultimodal/diffusion_labs/samplers/ddpm.py index 62794595..cd2cd246 100644 --- a/torchmultimodal/modules/diffusion/ddpm.py +++ b/torchmultimodal/diffusion_labs/samplers/ddpm.py @@ -10,13 +10,15 @@ import torch import torch.nn.functional as F from torch import nn, Tensor -from torchmultimodal.modules.diffusion.predictors import Predictor +from torchmultimodal.diffusion_labs.predictors.predictor import Predictor +from torchmultimodal.diffusion_labs.samplers.sampler import Sampler +from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import ( + DiscreteGaussianSchedule, +) +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput -from torchmultimodal.modules.diffusion.schedules import DiffusionSchedule -from torchmultimodal.utils.diffusion_utils import DiffusionOutput - -class DDPModule(nn.Module): +class DDPModule(nn.Module, Sampler): """DDPModule acts as a wrapper module around an inner neural network. During training it uses the inner neural network to predict a single denoising step. When set to eval, calling forward will sample the entire diffusion schedule. This module follows the denoising diffusion process as @@ -36,12 +38,12 @@ class DDPModule(nn.Module): Attributes: model (nn.Module): prediction neural network - schedule (DiffusionSchedule): defines diffusion of noise through time + schedule (DiscreteGaussianSchedule): defines diffusion of noise through time predictor (Predictor): predictor class to handle predictions depending on the model input eval_steps (Tensor): subset of steps to sample at inference Args: - xt (Tensor): corrupted data at time t (when t = schedule.steps, xt is equivalent to noise) + x (Tensor): corrupted data at time t (when t = schedule.steps, x is equivalent to noise) timestep (Tensor): diffusion step conditional_inputs (Dict): dictionary of context embeddings @@ -50,7 +52,7 @@ class DDPModule(nn.Module): def __init__( self, model: nn.Module, - schedule: DiffusionSchedule, + schedule: DiscreteGaussianSchedule, predictor: Predictor, eval_steps: Optional[Tensor] = None, progress_bar: bool = True, @@ -58,12 +60,12 @@ def __init__( super().__init__() self.model = model - self.train_schedule = schedule - self.train_predictor = predictor + self.schedule = schedule + self.predictor = predictor self.progress_bar = progress_bar if eval_steps is None: - eval_steps = torch.arange(self.train_schedule.steps) + eval_steps = torch.arange(self.schedule.steps) eval_steps_map = eval_steps self.eval_schedule = schedule self.eval_predictor = predictor @@ -73,7 +75,7 @@ def __init__( eval_steps, _ = eval_steps.sort() # eval_map maps from timestep in full schedule, to timestep in truncated eval scheule # e.g. if train has 1000 steps, and eval has 3, then t = 500 would map to eval timestep = 1 - eval_steps_map = torch.zeros(self.train_schedule.steps, dtype=torch.long) + eval_steps_map = torch.zeros(self.schedule.steps, dtype=torch.long) eval_steps_map[eval_steps] = torch.arange(len(eval_steps)) # Compute cumulative product of only the alphas in the eval steps @@ -103,8 +105,8 @@ def predict_parameters( t (Tensor): int diffusion steps """ pred, value = input.prediction, input.variance_value - schedule = self.train_schedule if self.training else self.eval_schedule - predictor = self.train_predictor if self.training else self.eval_predictor + schedule = self.schedule if self.training else self.eval_schedule + predictor = self.predictor if self.training else self.eval_predictor timestep = t if self.training else self.eval_steps_map[t] x0 = predictor.predict_x0(pred, xt, timestep) @@ -128,38 +130,38 @@ def remove_noise( # Predict x_{t-1} dtype = xt.dtype - noise = self.train_schedule.sample_noise(xt) + noise = self.schedule.sample_noise(xt) # Mask noise when t = 0; shape (b, 1, ..., 1) with same dims as xt nonzero_mask = (t != 0).to(dtype).view(-1, *([1] * (xt.dim() - 1))) # pyre-ignore return mean + nonzero_mask * (0.5 * log_variance).exp() * noise def generator( - self, xt: Tensor, c: Optional[Dict[str, Tensor]] = None + self, x: Tensor, c: Optional[Dict[str, Tensor]] = None ) -> Generator[Tensor, None, None]: """Generate xt for each t in sample_steps""" for step in self.eval_steps.flip(0): - t = step * torch.ones(xt.size(0), device=xt.device, dtype=torch.long) - xt = self.remove_noise(xt, t, c) - yield xt + t = step * torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = self.remove_noise(x, t, c) + yield x def forward( self, - xt: Tensor, + x: Tensor, timestep: Optional[Tensor] = None, conditional_inputs: Optional[Dict[str, Tensor]] = None, ) -> Union[DiffusionOutput, Tensor]: if self.training: if timestep is None: raise ValueError("Must provide a t value during training") - out = self.model(xt, timestep, conditional_inputs) + out = self.model(x, timestep, conditional_inputs) if not isinstance(out, DiffusionOutput): raise TypeError("Model is expected to output a DiffusionOutput class") if out.variance_value is not None: - out.mean, out.log_variance = self.predict_parameters(out, xt, timestep) + out.mean, out.log_variance = self.predict_parameters(out, x, timestep) return out else: - gen: Iterable = self.generator(xt, conditional_inputs) + gen: Iterable = self.generator(x, conditional_inputs) if self.progress_bar: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm diff --git a/torchmultimodal/diffusion_labs/samplers/sampler.py b/torchmultimodal/diffusion_labs/samplers/sampler.py new file mode 100644 index 00000000..e5fd9c4d --- /dev/null +++ b/torchmultimodal/diffusion_labs/samplers/sampler.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod +from typing import Dict, Generator, Optional, Protocol, runtime_checkable, Union + +from torch import nn, Tensor +from torchmultimodal.diffusion_labs.predictors.predictor import Predictor + +from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput + + +@runtime_checkable +class Sampler(Protocol): + """Sampler class for applying the learned denoising function given the diffusion schedule. + During training this class passes through the model outputs but during eval, the model loops + the model for all the eval steps. This class implements the same forward signature as the + Adapter class. To access individual generative steps at eval, the Sampler.generator() method + will return a python generator that steps through each denoising step. + + Example: + model = Sampler(...) + x = torch.randn(...) + + # Generate with forward + model.eval() + with torch.no_grad(): + img = model(x) + + # Generator with generator + model.eval() + gen = model.generator(x) + images = [] + with torch.no_grad(): + for i in gen: + images.append(i) + + img == images[-1] + """ + + model: nn.Module + schedule: DiffusionSchedule + predictor: Predictor + eval_steps: Tensor + + @abstractmethod + def generator( + self, + x: Tensor, + c: Optional[Dict[str, Tensor]] = None, + ) -> Generator[Tensor, None, None]: + """Generator for each t in self.eval_steps + + Args: + x (Tensor): corrupted data at time t (when t = schedule.steps, x is fully noise) + of shape [b, c, ...] + c (Dict): dictionary of model conditional inputs + """ + + @abstractmethod + def forward( + self, + x: Tensor, + timestep: Optional[Tensor] = None, + conditional_inputs: Optional[Dict[str, Tensor]] = None, + ) -> Union[DiffusionOutput, Tensor]: + """nn Module forward method + + Args: + x (Tensor): corrupted data at time t (when t = schedule.steps, x is fully noise) + of shape [b, c, ...] + timestep (Optional[Tensor]): diffusion step + conditional_inputs (Dict): dictionary of model conditional inputs + """ diff --git a/torchmultimodal/diffusion_labs/schedules/__init__.py b/torchmultimodal/diffusion_labs/schedules/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/schedules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/modules/diffusion/schedules.py b/torchmultimodal/diffusion_labs/schedules/discrete_gaussian_schedule.py similarity index 92% rename from torchmultimodal/modules/diffusion/schedules.py rename to torchmultimodal/diffusion_labs/schedules/discrete_gaussian_schedule.py index 27ea6bc4..7f835149 100644 --- a/torchmultimodal/modules/diffusion/schedules.py +++ b/torchmultimodal/diffusion_labs/schedules/discrete_gaussian_schedule.py @@ -11,15 +11,17 @@ import torch.nn.functional as F from torch import Tensor +from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule -class DiffusionSchedule: + +class DiscreteGaussianSchedule(DiffusionSchedule): """Diffusion is a thermondynamic process of two substances intermingling, likewise diffusion probabilistic models represent the transformation of one distribution into another as a gradual stochastic process. Specifically, in Denoising Diffusion, we model the transformation from the Gaussian distribution to some data distribution over a number of time steps. DiffusionSchedule is a parameterized helper class to model the changing distribution from Gaussian (noise) to data (x). This is an implementation of a diffusion schedule with discrete time steps. - DiffusionSchedule manages all timestep properties that are a function of the variance schedule (betas). For example, to + DiscreteGaussianSchedule manages all timestep properties that are a function of the variance schedule (betas). For example, to calculate the compliment of betas (called alphas in the paper), you define alphas def alphas(schedule): @@ -186,33 +188,33 @@ def __getattr__(self, name: str) -> Any: setattr(self, name, value) return value raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" + f"{self.__class__.__name__!r} object has no attribute {name!r}" ) # **************************************** Scheduler Functions **************************************** -def _alphas(schedule: DiffusionSchedule) -> Tensor: +def _alphas(schedule: DiscreteGaussianSchedule) -> Tensor: return 1.0 - schedule.betas -def _alphas_cumprod(schedule: DiffusionSchedule) -> Tensor: +def _alphas_cumprod(schedule: DiscreteGaussianSchedule) -> Tensor: return schedule.alphas.cumprod(axis=0) -def _alphas_cumprod_prev(schedule: DiffusionSchedule) -> Tensor: +def _alphas_cumprod_prev(schedule: DiscreteGaussianSchedule) -> Tensor: return F.pad(schedule.alphas_cumprod[:-1], (1, 0), value=1.0) -def _sqrt_alphas_cumprod(schedule: DiffusionSchedule) -> Tensor: +def _sqrt_alphas_cumprod(schedule: DiscreteGaussianSchedule) -> Tensor: return schedule.alphas_cumprod.sqrt() -def _sqrt_compliment_alphas_cumprod(schedule: DiffusionSchedule) -> Tensor: +def _sqrt_compliment_alphas_cumprod(schedule: DiscreteGaussianSchedule) -> Tensor: # pyre-ignore return (1.0 - schedule.alphas_cumprod).sqrt() -def _lower_posterior_log_variance(schedule: DiffusionSchedule) -> Tensor: +def _lower_posterior_log_variance(schedule: DiscreteGaussianSchedule) -> Tensor: # First element is 0 which has an infinite log (EQ 15 from Improving DDPMs) compliment_alphas_bar = 1.0 - schedule.alphas_cumprod compliment_alphas_bar_prev = 1.0 - schedule.alphas_cumprod_prev @@ -222,17 +224,17 @@ def _lower_posterior_log_variance(schedule: DiffusionSchedule) -> Tensor: return lpv.log() -def _upper_posterior_log_variance(schedule: DiffusionSchedule) -> Tensor: +def _upper_posterior_log_variance(schedule: DiscreteGaussianSchedule) -> Tensor: return schedule.betas.log() -def _posterior_mean_x0_coef(schedule: DiffusionSchedule) -> Tensor: +def _posterior_mean_x0_coef(schedule: DiscreteGaussianSchedule) -> Tensor: alphas_cumprod_prev_sqrt = schedule.alphas_cumprod_prev.sqrt() compliment_alphas_cumprod = 1.0 - schedule.alphas_cumprod return schedule.betas * alphas_cumprod_prev_sqrt / compliment_alphas_cumprod -def _posterior_mean_xt_coef(schedule: DiffusionSchedule) -> Tensor: +def _posterior_mean_xt_coef(schedule: DiscreteGaussianSchedule) -> Tensor: compliment_alphas_cumprod_prev = 1.0 - schedule.alphas_cumprod_prev alphas_sqrt = schedule.alphas.sqrt() compliment_alphas_cumprod = 1.0 - schedule.alphas_cumprod diff --git a/torchmultimodal/diffusion_labs/schedules/schedule.py b/torchmultimodal/diffusion_labs/schedules/schedule.py new file mode 100644 index 00000000..0713512b --- /dev/null +++ b/torchmultimodal/diffusion_labs/schedules/schedule.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod, abstractproperty +from typing import Protocol, runtime_checkable + +from torch import Tensor + + +@runtime_checkable +class DiffusionSchedule(Protocol): + """Class that defines the entire diffusion process and provides helper functions + for computing various transformations given the diffusion process + """ + + @abstractmethod + def sample_noise(self, x_like: Tensor) -> Tensor: + """Sample from diffusion distribution + + Args: + x_like (Tensor): example tensor to get meta properties for noise tensor + """ + + @abstractmethod + def sample_steps(self, x_like: Tensor) -> Tensor: + """Sample diffusion steps + + Args: + x_like (Tensor): example tensor to get meta properties for noise tensor + """ + + @abstractmethod + def q_sample(self, x0: Tensor, noise: Tensor, t: Tensor) -> Tensor: + """Given data (x at step 0) and noise, compute xt for the given + diffusion t. + + Args: + x0 (Tensor): uncorrupted data at step 0 + noise (Tensor): sample noise, same size as x0 + t (Tensor): int diffusion steps + """ + + @abstractproperty + def steps(self) -> int: + """Number of diffusion steps""" diff --git a/torchmultimodal/diffusion_labs/transforms/__init__.py b/torchmultimodal/diffusion_labs/transforms/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/transforms/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/diffusion_labs/transforms/diffusion_transform.py b/torchmultimodal/diffusion_labs/transforms/diffusion_transform.py new file mode 100644 index 00000000..568f6113 --- /dev/null +++ b/torchmultimodal/diffusion_labs/transforms/diffusion_transform.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +from torch import nn, Tensor +from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule + + +class RandomDiffusionSteps(nn.Module): + """Data Transform to randomly sample noised data from the diffusion schedule. + During diffusion training, random diffusion steps are sampled per model update. + This transform samples steps and returns the steps (t), seed noise, and transformed + data at time t (xt). + + Attributes: + schedule (DiffusionSchedule): defines diffusion of noise through time + batched (bool): if True, transform expects a batched input + + Args: + x (Tensor): data representing x0, artifact being learned. The 0 represents zero diffusion steps. + """ + + def __init__(self, schedule: DiffusionSchedule, batched: bool = True): + super().__init__() + self.schedule = schedule + self.batched = batched + + def __call__(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + if not self.batched: + t = self.schedule.sample_steps(x.unsqueeze(0)) + t = t.squeeze(0) + else: + t = self.schedule.sample_steps(x) + noise = self.schedule.sample_noise(x) + xt = self.schedule.q_sample(x, noise, t) + return x, xt, noise, t diff --git a/torchmultimodal/diffusion_labs/utils/__init__.py b/torchmultimodal/diffusion_labs/utils/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/utils/diffusion_utils.py b/torchmultimodal/diffusion_labs/utils/common.py similarity index 89% rename from torchmultimodal/utils/diffusion_utils.py rename to torchmultimodal/diffusion_labs/utils/common.py index 11e74a57..22e330c1 100644 --- a/torchmultimodal/utils/diffusion_utils.py +++ b/torchmultimodal/diffusion_labs/utils/common.py @@ -43,6 +43,11 @@ def cascaded_resize(pil_image: Image, resolution: int) -> Image: return pil_image +def normalize(x: Tensor, image_min: int, image_max: int) -> Tensor: + # Normalize image values between min and max + return (image_max - image_min) * x + image_min + + def denormalize_to_0_1(images: Tensor) -> Tensor: """Denormalize tensors from range [-1, 1] to [0, 1]""" denormed_images = torch.clamp((images + 1) / 2, 0, 1) diff --git a/torchmultimodal/modules/diffusion/predictors.py b/torchmultimodal/modules/diffusion/predictors.py deleted file mode 100644 index 8f561eeb..00000000 --- a/torchmultimodal/modules/diffusion/predictors.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import Callable, Optional - -from torch import Tensor -from torchmultimodal.modules.diffusion.schedules import DiffusionSchedule - - -class Predictor(ABC): - """Helper class to help predict various parts of the diffusion process. Different - implementations of each method are needed depending on what the model itself was - trained to predict. - """ - - schedule: DiffusionSchedule - clamp_func: Optional[Callable] - - @abstractmethod - def predict_x0(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: - """Predict x0 - - Args: - prediction (Tensor): model prediction - xt (Tensor): noised data to step t - t (Tensor): int diffusion step for xt - """ - pass - - @abstractmethod - def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: - """Predict noise - - Args: - prediction (Tensor): model prediction - xt (Tensor): noised data to step t - t (Tensor): int diffusion step for xt - """ - pass - - -class NoisePredictor(Predictor): - """Given a model that's trained to predict diffusion noise and corresponding schedule, - this class computes the predicted noise and x0 at step t. - - Attributes: - schedule (DiffusionSchedule): defines diffusion of noise through time - clamp_func (Callable): function to clamp prediction values - """ - - def __init__( - self, schedule: DiffusionSchedule, clamp_func: Optional[Callable] = None - ): - self.clamp_func = clamp_func - schedule.add_property("sqrt_recip_alphas_cumprod", _sqrt_recip_alphas_cumprod) - schedule.add_property( - "sqrt_recipm1_alphas_cumprod", _sqrt_recipm1_alphas_cumprod - ) - self.schedule = schedule - - def predict_x0(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: - shape, dtype = xt.shape, xt.dtype - x_coef = self.schedule("sqrt_recip_alphas_cumprod", t, shape) - e_coef = self.schedule("sqrt_recipm1_alphas_cumprod", t, shape) - x0 = x_coef * xt - e_coef * prediction - if self.clamp_func is not None: - x0 = self.clamp_func(x0) - return x0.to(dtype) - - def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: - return prediction - - -class TargetPredictor(Predictor): - """Given a model that's trained to predict x0 and corresponding schedule, - this class computes the predicted noise and x0 at step t. - - Attributes: - schedule (DiffusionSchedule): defines diffusion of noise through time - clamp_func (Callable): function to clamp prediction values - """ - - def __init__( - self, schedule: DiffusionSchedule, clamp_func: Optional[Callable] = None - ): - self.clamp_func = clamp_func - self.schedule = schedule - - def predict_x0(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: - if self.clamp_func is not None: - prediction = self.clamp_func(prediction) - return prediction - - def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor: - # TODO: For DDIM add predict_noise - pass - - -# TODO: Add VPredictor from https://arxiv.org/abs/2202.00512 - - -def _sqrt_recip_alphas_cumprod(schedule: DiffusionSchedule) -> Tensor: - # pyre-ignore - return (1.0 / schedule.alphas_cumprod).sqrt() - - -def _sqrt_recipm1_alphas_cumprod(schedule: DiffusionSchedule) -> Tensor: - # pyre-ignore - return (1.0 / schedule.alphas_cumprod - 1).sqrt()