Skip to content

Commit

Permalink
Reorg OSS Diffusion Components to diffusion_labs folder (facebookrese…
Browse files Browse the repository at this point in the history
…arch#480)

Summary:
Based on this [proposal](https://docs.google.com/document/d/1GtN2urD8PiRr1X4COzvbVbNE8LWRrcoYkAO4v2aogO8/edit) to reorganize diffusion components and models under a new `diffusion_labs`. This is the first in a stack of diffs. This one only reorganizes what's already been moved to OSS.

This is primarily moving files with a couple of changes based on the proposal:
- predictors.py is split into a separate file per predictor
- adm is moved out of dalle2 to be it's own model adm_unet
- Dalle2ImageTransform is moved to dalle2 out of transforms
- schedule.py is renamed to discrete_guassian_schedule.py and an abstract DIffusionSchedule class was added
- An abstract adapter class was added to be a generic type and enforce the `forward` signature
- An abstract sampler class was added to be a generic type and enforce the `forward` and `generator` signature
- A new dalle2_model unit test was added

Differential Revision: D49790849

Pulled By: pbontrager

fbshipit-source-id: 98fe40c2418dc542cced940dc761a9cd602b0398
  • Loading branch information
Philip Bontrager authored and facebook-github-bot committed Oct 11, 2023
1 parent f2cfe1a commit b8226b9
Show file tree
Hide file tree
Showing 45 changed files with 654 additions and 348 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.modules.diffusion.cfguidance import CFGuidance
from torchmultimodal.utils.diffusion_utils import DiffusionOutput
from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput


@pytest.fixture(autouse=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import torch
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.dalle2.adm.adm import ADM, ADMStack, ADMUNet
from torchmultimodal.models.dalle2.adm.attention_block import ADMAttentionBlock
from torchmultimodal.models.dalle2.adm.res_block import ADMResBlock
from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADM, ADMStack, ADMUNet
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
ADMAttentionBlock,
)
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -116,7 +118,7 @@ def test_predict_variance_value_incorrect_channel_dim_error(


# All expected values come after first testing the ADMUNet has the exact output
# as the corresponding UNet class in d2go, then simply forward passing
# as the corresponding author UNet implementation, then simply forward passing
# ADMUNet with params, random seed, and initialization order in this file.
class TestADMUNet:
@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.models.dalle2.adm.attention_block import ADMAttentionBlock
from torchmultimodal.models.dalle2.adm.res_block import (
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
ADMAttentionBlock,
)
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import (
adm_res_block,
adm_res_downsample_block,
adm_res_upsample_block,
Expand Down Expand Up @@ -47,7 +49,7 @@ def t(params):


# All expected values come after first testing the ADMResBlock has the exact output
# as the corresponding residual block class in d2go, then simply forward passing
# as the corresponding residual block class from ADM authors, then simply forward passing
# ADMResBlock with params, random seed, and initialization order in this file.
class TestADMResBlock:
@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.models.dalle2.adm.attention_block import (
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
adm_attention,
ADMCrossAttention,
)
Expand Down Expand Up @@ -44,7 +44,7 @@ def c(params):


# All expected values come after first testing that ADMCrossAttention has
# the exact output as the corresponding QKVAttention class in d2go, then simply forward passing
# the exact output as the corresponding QKVAttention class in ADM, then simply forward passing
# ADMCrossAttention with params, random seed, and initialization order in this file.
class TestADMCrossAttention:
@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

from PIL import Image
from torchmultimodal.utils.diffusion_utils import cascaded_resize
from torchmultimodal.diffusion_labs.utils.common import cascaded_resize


def test_cascaded_resize():
Expand Down
52 changes: 52 additions & 0 deletions tests/diffusion_labs/test_dalle2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from PIL import Image
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.diffusion_labs.models.dalle2.dalle2_decoder import dalle2_decoder
from torchmultimodal.diffusion_labs.models.dalle2.transforms import Dalle2ImageTransform


def test_dalle2_model():
set_rng_seed(4)
model = dalle2_decoder(
timesteps=1,
time_embed_dim=1,
cond_embed_dim=1,
clip_embed_dim=1,
clip_embed_name="clip_image",
predict_variance_value=True,
image_channels=1,
depth=32,
num_resize=1,
num_res_per_layer=1,
use_cf_guidance=True,
clip_image_guidance_dropout=0.1,
guidance_strength=7.0,
learn_null_emb=True,
)
model.eval()
x = torch.randn(1, 1, 4, 4)
c = torch.ones((1, 1))
with torch.no_grad():
actual = model(x, conditional_inputs={"clip_image": c}).mean()
expected = torch.as_tensor(0.12768)
assert_expected(actual, expected, rtol=0, atol=1e-4)


def test_dalle2_image_transform():
img_size = 5
transform = Dalle2ImageTransform(image_size=img_size, image_min=-1, image_max=1)
image = Image.new("RGB", size=(20, 20), color=(128, 0, 0))
actual = transform(image).sum()
normalized128 = 128 / 255 * 2 - 1
normalized0 = -1
expected = torch.tensor(
normalized128 * img_size**2 + 2 * normalized0 * img_size**2
)
assert_expected(actual, expected, rtol=0, atol=1e-4)
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
import torch
import torch.nn as nn
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.modules.diffusion.schedules import (
DiffusionSchedule,
from torchmultimodal.diffusion_labs.modules.losses.diffusion_hybrid_loss import (
DiffusionHybridLoss,
)
from torchmultimodal.diffusion_labs.modules.losses.vlb_loss import VLBLoss
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
linear_beta_schedule,
)
from torchmultimodal.modules.losses.diffusion import DiffusionHybridLoss, VLBLoss
from torchmultimodal.utils.diffusion_utils import DiffusionOutput
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput


@pytest.fixture(autouse=True)
Expand All @@ -24,7 +27,7 @@ def set_seed():

@pytest.fixture
def schedule():
return DiffusionSchedule(linear_beta_schedule(1000))
return DiscreteGaussianSchedule(linear_beta_schedule(1000))


@pytest.fixture
Expand All @@ -44,7 +47,7 @@ def target():


# All expected values come after first testing the HybridLoss has the exact output
# as the corresponding p_losses in D2Go Guassian Diffusion
# as the corresponding p_losses in Guassian Diffusion
class TestDiffusionHybridLoss:
@pytest.fixture
def loss(self, schedule):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from PIL import Image
from tests.test_utils import assert_expected
from torchmultimodal.transforms.diffusion_transforms import (
Dalle2ImageTransform,
from torchmultimodal.diffusion_labs.transforms.diffusion_transform import (
RandomDiffusionSteps,
)

Expand All @@ -30,16 +27,3 @@ def test_random_diffusion_steps():
actual = len(transform(torch.ones(1)))
expected = 4
assert actual == expected, "Transform not returning correct keys"


def test_dalle_image_transform():
img_size = 5
transform = Dalle2ImageTransform(image_size=img_size, image_min=-1, image_max=1)
image = Image.new("RGB", size=(20, 20), color=(128, 0, 0))
actual = transform(image).sum()
normalized128 = 128 / 255 * 2 - 1
normalized0 = -1
expected = torch.tensor(
normalized128 * img_size**2 + 2 * normalized0 * img_size**2
)
assert_expected(actual, expected, rtol=0, atol=1e-4)
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import pytest
import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.modules.diffusion.schedules import (
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
cosine_beta_schedule,
DiffusionSchedule,
DiscreteGaussianSchedule,
linear_beta_schedule,
quadratic_beta_schedule,
sigmoid_beta_schedule,
Expand All @@ -23,11 +23,11 @@ def set_seed():


# All expected values come after first testing the Schedule has the exact output
# as the corresponding q methods from GaussianDiffusion in D2Go
# as the corresponding q methods from GaussianDiffusion
class TestDiffusionSchedule:
@pytest.fixture
def module(self):
schedule = DiffusionSchedule(linear_beta_schedule(1000))
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
return schedule

@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import pytest
import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.modules.diffusion.predictors import NoisePredictor, TargetPredictor
from torchmultimodal.modules.diffusion.schedules import (
DiffusionSchedule,
from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor
from torchmultimodal.diffusion_labs.predictors.target_predictor import TargetPredictor
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
linear_beta_schedule,
)

Expand All @@ -30,11 +31,11 @@ def input():


# All expected values come after first testing the Schedule has the exact output
# as the corresponding q methods from GaussianDiffusion in D2Go
# as the corresponding q methods from GaussianDiffusion
class TestNoisePredictor:
@pytest.fixture
def module(self):
schedule = DiffusionSchedule(linear_beta_schedule(1000))
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
predictor = NoisePredictor(schedule, None)
return predictor

Expand All @@ -52,7 +53,7 @@ def test_predict_noise(self, module, input):
class TestTargetPredictor:
@pytest.fixture
def module(self):
schedule = DiffusionSchedule(linear_beta_schedule(1000))
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
predictor = TargetPredictor(schedule, None)
return predictor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import torch
import torch.nn as nn
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.modules.diffusion.ddim import DDIModule
from torchmultimodal.modules.diffusion.predictors import NoisePredictor
from torchmultimodal.modules.diffusion.schedules import (
DiffusionSchedule,
from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor
from torchmultimodal.diffusion_labs.samplers.ddim import DDIModule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
linear_beta_schedule,
)
from torchmultimodal.utils.diffusion_utils import DiffusionOutput
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput


class DummyUNet(nn.Module):
Expand All @@ -38,7 +38,7 @@ class TestDDIModule:
@pytest.fixture
def module(self):
model = DummyUNet(True)
schedule = DiffusionSchedule(linear_beta_schedule(1000))
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
predictor = NoisePredictor(schedule)
model = DDIModule(model, schedule, predictor)
return model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import torch.nn as nn
import torch.nn.functional as F
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.modules.diffusion.ddpm import DDPModule
from torchmultimodal.modules.diffusion.predictors import NoisePredictor
from torchmultimodal.modules.diffusion.schedules import (
DiffusionSchedule,
from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor
from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
linear_beta_schedule,
)
from torchmultimodal.utils.diffusion_utils import DiffusionOutput
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput


class DummyUNet(nn.Module):
Expand All @@ -36,13 +36,13 @@ def forward(self, x, t, c):


# All expected values come after first testing the Schedule has the exact output
# as the corresponding p methods from GaussianDiffusion in D2Go
# as the corresponding p methods from GaussianDiffusion
class TestDDPModule:
@pytest.fixture
def module(self):
set_rng_seed(4)
model = DummyUNet(True)
schedule = DiffusionSchedule(linear_beta_schedule(1000))
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
predictor = NoisePredictor(schedule)
eval_steps = torch.arange(0, 1000, 50)
model = DDPModule(model, schedule, predictor, eval_steps)
Expand Down
File renamed without changes.
5 changes: 5 additions & 0 deletions torchmultimodal/diffusion_labs/models/adm_unet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Loading

0 comments on commit b8226b9

Please sign in to comment.