Skip to content

Commit

Permalink
Update and Add Diffusion Components
Browse files Browse the repository at this point in the history
Summary: This diff brings in a number of updates as well as new features built since the last release to diffusion_labs

Reviewed By: abhinavarora

Differential Revision: D50285167
  • Loading branch information
pbontrager authored and facebook-github-bot committed Oct 17, 2023
1 parent 9d4c8e7 commit 1bbb31a
Show file tree
Hide file tree
Showing 24 changed files with 1,270 additions and 108 deletions.
201 changes: 159 additions & 42 deletions tests/diffusion_labs/test_adapter_cfguidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Sequence, Tuple, Union

import pytest

import torch
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torch import nn, Tensor
from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput

Expand All @@ -19,13 +21,24 @@ def set_seed():


@pytest.fixture
def params():
def params() -> Tuple[int, int, int]:
in_channels = 3
s = 4
embed_dim = 6
return in_channels, s, embed_dim


@pytest.fixture
def x(params) -> Tensor:
    """All-ones input of shape (1, embed_dim, s, s) derived from `params`."""
    _, s, embed_dim = params
    return torch.ones(1, embed_dim, s, s)


@pytest.fixture
def t() -> Tensor:
    """Single dummy timestep of shape (1, 1), value 1."""
    return torch.full((1, 1), 1.0)


class DummyADM(nn.Module):
def __init__(self, variance_flag=False):
super().__init__()
Expand All @@ -48,14 +61,19 @@ def forward(self, x, t, c):

class TestCFGuidance:
@pytest.fixture
def cond(self, params):
def dim_cond(self, params) -> Dict[str, Union[int, Sequence[int]]]:
embed_dim = params[-1]
c = torch.ones(1, embed_dim)
return {"test": c}
return {"test": embed_dim, "test_list": [embed_dim]}

@pytest.fixture
def dim_cond(self, cond):
return {k: v.shape[-1] for k, v in cond.items()}
def cond(self, dim_cond) -> Dict[str, Tensor]:
val = {}
for k, v in dim_cond.items():
if isinstance(v, int):
val[k] = torch.ones(1, v)
else:
val[k] = torch.ones(1, *v)
return val

@pytest.fixture
def models(self):
Expand All @@ -68,56 +86,155 @@ def cfguidance_models(self, models, dim_cond):
for k, model in models.items()
}

def test_training_forward(self, cfguidance_models, params, cond):
embed_dim = params[-1]
x = torch.ones(1, embed_dim, params[1], params[1])
t = torch.ones(1, 1)
@pytest.fixture
def cfguidance_models_p_1(self, models, dim_cond):
    """CFGuidance wrappers with p=1 and non-learned null embeddings."""
    wrapped = {}
    for name, model in models.items():
        wrapped[name] = CFGuidance(
            model, dim_cond, p=1, guidance=2, learn_null_emb=False
        )
    return wrapped

@pytest.fixture
def cfguidance_models_train_embeddings(self, models, dim_cond, cond):
    """CFGuidance wrappers with explicit train-time unconditional embeddings (5x cond)."""
    scaled = {}
    for key, value in cond.items():
        scaled[key] = value * 5
    wrapped = {}
    for name, model in models.items():
        wrapped[name] = CFGuidance(
            model,
            dim_cond,
            p=1,
            guidance=2,
            learn_null_emb=False,
            train_unconditional_embeddings=scaled,
        )
    return wrapped

@pytest.fixture
def cfguidance_models_eval_embeddings(self, models, dim_cond, cond):
    """CFGuidance wrappers with explicit eval-time unconditional embeddings (3x cond)."""
    scaled = {}
    for key, value in cond.items():
        scaled[key] = value * 3
    wrapped = {}
    for name, model in models.items():
        wrapped[name] = CFGuidance(
            model,
            dim_cond,
            p=1,
            guidance=2,
            learn_null_emb=False,
            eval_unconditional_embeddings=scaled,
        )
    return wrapped

@pytest.mark.parametrize(
"cfg_models,expected_multiplier",
[
("cfguidance_models", 6),
("cfguidance_models_p_1", 4),
("cfguidance_models_train_embeddings", 14),
("cfguidance_models_eval_embeddings", 4),
],
)
def test_training_forward(
self, cfg_models, x, t, cond, expected_multiplier, request
):
cfguidance_models = request.getfixturevalue(cfg_models)
for k, cfguidance_model in cfguidance_models.items():
actual = cfguidance_model(x, t, cond)
if k == "mean": # if adm model returns only prediction
expected = 4 * torch.ones(1, embed_dim, params[1], params[1])
assert_expected(actual.prediction, expected)
assert_expected(actual.variance_value, None)
elif (
k == "mean_variance"
): # if adm model returns both prediction and variance
expected_prediction = 4 * torch.ones(1, embed_dim, params[1], params[1])
expected_variance = torch.ones(1, embed_dim, params[1], params[1])
assert_expected(actual.prediction, expected_prediction)
assert_expected(actual.variance_value, expected_variance)

def test_inference_forward(self, cfguidance_models, params, cond):
embed_dim = params[-1]
x = torch.ones(1, embed_dim, params[1], params[1])
t = torch.ones(1, 1)

expected_mean = expected_multiplier * x
expected_variance = x if k == "mean_variance" else None
assert_expected(actual.prediction, expected_mean)
assert_expected(actual.variance_value, expected_variance)

@pytest.mark.parametrize(
    "cfg_models,expected_multiplier",
    [
        ("cfguidance_models", 10),
        ("cfguidance_models_p_1", 10),
        ("cfguidance_models_train_embeddings", -10),
        ("cfguidance_models_eval_embeddings", -2),
    ],
)
def test_inference_forward(
    self, cfg_models, x, t, cond, expected_multiplier, request
):
    """Eval-mode guided forward with conditioning scales x by expected_multiplier."""
    for name, guided_model in request.getfixturevalue(cfg_models).items():
        guided_model.eval()
        out = guided_model(x, t, cond)
        want_mean = expected_multiplier * x
        # Only the "mean_variance" dummy model emits a variance tensor.
        want_var = x if name == "mean_variance" else None
        assert_expected(out.prediction, want_mean)
        assert_expected(out.variance_value, want_var)

@pytest.mark.parametrize(
"cfg_models,expected_multiplier",
[
("cfguidance_models", 6),
("cfguidance_models_p_1", 6),
("cfguidance_models_train_embeddings", 6),
("cfguidance_models_eval_embeddings", 6),
],
)
def test_inference_0_guidance_forward(
self, cfg_models, x, t, cond, expected_multiplier, request
):
cfguidance_models = request.getfixturevalue(cfg_models)
for k, cfguidance_model in cfguidance_models.items():
cfguidance_model.guidance = 0
cfguidance_model.eval()
actual = cfguidance_model(x, t, cond)
if k == "mean": # if adm model returns only prediction
expected_prediction = 6 * torch.ones(1, embed_dim, params[1], params[1])
assert_expected(actual.prediction, expected_prediction)
assert_expected(actual.variance_value, None)
elif (
k == "mean_variance"
): # if adm model returns both prediction and variance
expected_prediction = 6 * torch.ones(1, embed_dim, params[1], params[1])
expected_variance = torch.ones(1, embed_dim, params[1], params[1])
assert_expected(actual.prediction, expected_prediction)
assert_expected(actual.variance_value, expected_variance)
expected_mean = expected_multiplier * x
expected_variance = x if k == "mean_variance" else None
assert_expected(actual.prediction, expected_mean)
assert_expected(actual.variance_value, expected_variance)

@pytest.mark.parametrize(
    "cfg_models,expected_multiplier",
    [
        ("cfguidance_models", 4),
        ("cfguidance_models_p_1", 4),
        ("cfguidance_models_train_embeddings", 14),
        ("cfguidance_models_eval_embeddings", 10),
    ],
)
def test_inference_no_cond_forward(
    self, cfg_models, x, t, expected_multiplier, request
):
    """Eval-mode forward with no conditioning (cond=None) scales x by expected_multiplier."""
    for name, guided_model in request.getfixturevalue(cfg_models).items():
        guided_model.eval()
        out = guided_model(x, t, None)
        # Only the "mean_variance" dummy model emits a variance tensor.
        want_var = x if name == "mean_variance" else None
        assert_expected(out.prediction, expected_multiplier * x)
        assert_expected(out.variance_value, want_var)

def test_get_prob_dict(self, cfguidance_models):
cfguidance_model = cfguidance_models["mean"]
actual = cfguidance_model._get_prob_dict(0.1)
expected = {"test": 0.1}
expected = {"test": 0.1, "test_list": 0.1}
assert_expected(actual, expected)

actual = cfguidance_model._get_prob_dict({"test": 0.1})
expected = {"test": 0.1}
actual = cfguidance_model._get_prob_dict({"test": 0.1, "test_list": 0.2})
expected = {"test": 0.1, "test_list": 0.2}
assert_expected(actual, expected)

with pytest.raises(ValueError):
actual = cfguidance_model._get_prob_dict({"test_2": 0.1, "test": 0.1})

with pytest.raises(TypeError):
actual = cfguidance_model._get_prob_dict("test")

def test_gen_unconditional_embeddings(self, cfguidance_models, params, cond):
    """Generated unconditional embeddings mirror cond's keys; values come from
    the init function (zeros) when no embeddings are passed, or from the passed
    embeddings (ones) otherwise."""
    model = cfguidance_models["mean"]

    def check(embeddings, expected_mean):
        # Keys must match the conditioning dict exactly.
        assert set(embeddings.keys()) == set(cond.keys())
        for emb in embeddings.values():
            assert_expected(emb.mean().item(), expected_mean)

    check(model._gen_unconditional_embeddings(None, torch.zeros, False), 0.0)
    check(model._gen_unconditional_embeddings(cond, torch.zeros, False), 1.0)
2 changes: 1 addition & 1 deletion tests/diffusion_labs/test_dalle2.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_dalle2_image_transform():
img_size = 5
transform = Dalle2ImageTransform(image_size=img_size, image_min=-1, image_max=1)
image = Image.new("RGB", size=(20, 20), color=(128, 0, 0))
actual = transform(image).sum()
actual = transform({"x": image})["x"].sum()
normalized128 = 128 / 255 * 2 - 1
normalized0 = -1
expected = torch.tensor(
Expand Down
5 changes: 4 additions & 1 deletion tests/diffusion_labs/test_diffusion_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ def sample_noise(self, x):
def q_sample(self, x, noise, t):
return x

def __call__(self, var_name, t, shape):
    # Dummy schedule hook: ignores var_name and t, always returns an
    # all-ones tensor of the requested shape.
    return torch.ones(shape)


def test_random_diffusion_steps():
transform = RandomDiffusionSteps(DummySchedule())
actual = len(transform(torch.ones(1)))
actual = len(transform({"x": torch.ones(1)}))
expected = 4
assert actual == expected, "Transform not returning correct keys"
109 changes: 109 additions & 0 deletions tests/diffusion_labs/test_inpainting_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np

import torch
from PIL import Image
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.diffusion_labs.transforms.inpainting_transform import (
brush_stroke_mask_image,
draw_strokes,
generate_vertexes,
mask_full_image,
random_inpaint_mask_image,
random_outpaint_mask_image,
RandomInpaintingMask,
)

BATCH = 4
CHANNELS = 3
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256


def set_seed(seed: int):
    """Seed all RNG sources used by these tests and force deterministic cuDNN."""
    # set_rng_seed comes from tests.test_utils — presumably seeds torch/python
    # RNGs; verify against that helper.
    set_rng_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Make cuDNN kernel selection reproducible across runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class TestImageMasks(unittest.TestCase):
    """Tests for the inpainting mask generators over a fixed, seeded image batch.

    Each mask test re-seeds the RNG so the expected mask sums below are
    deterministic for the generator under test.
    """

    def setUp(self):
        set_seed(1)
        self.batch_images = torch.randn(BATCH, CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH)
        self.image = self.batch_images[0, :, :, :]

    def test_random_inpaint_mask_image(self):
        set_seed(1)
        mask = random_inpaint_mask_image(self.image)
        self.assertIsInstance(mask, torch.Tensor)
        self.assertEqual(mask.shape, (1, self.image.shape[-2], self.image.shape[-1]))
        assert_expected(mask.sum(), torch.tensor(11524.0), rtol=0, atol=1e-4)

    def test_random_outpaint_mask_image(self):
        set_seed(1)
        mask = random_outpaint_mask_image(self.image)
        self.assertIsInstance(mask, torch.Tensor)
        self.assertEqual(mask.shape, (1, self.image.shape[-2], self.image.shape[-1]))
        assert_expected(mask.sum(), torch.tensor(27392.0), rtol=0, atol=1e-4)

    def test_brush_stroke_mask_image(self):
        set_seed(1)
        mask = brush_stroke_mask_image(self.image)
        self.assertIsInstance(mask, torch.Tensor)
        self.assertEqual(mask.shape, (1, self.image.shape[-2], self.image.shape[-1]))
        # NOTE: removed a leftover debug print of the mask sum; the assertion
        # below already pins the expected value.
        assert_expected(mask.sum(), torch.tensor(26860.0), rtol=0, atol=1e-4)

    def test_mask_full_image(self):
        set_seed(1)
        mask = mask_full_image(self.image)
        self.assertIsInstance(mask, torch.Tensor)
        self.assertEqual(mask.shape, (1, self.image.shape[-2], self.image.shape[-1]))
        # Full-image mask covers every pixel: all ones, sum == H * W.
        self.assertTrue(torch.allclose(mask, torch.ones_like(mask)))
        assert_expected(mask.sum(), torch.tensor(65536.0), rtol=0, atol=1e-4)

    def test_generate_vertexes(self):
        mask = Image.new("1", (IMAGE_WIDTH, IMAGE_HEIGHT), 0)
        # num_vertexes=3 yields 4 points (presumably includes a start point —
        # verify against generate_vertexes' contract).
        vertexes = generate_vertexes(
            mask, num_vertexes=3, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT
        )
        self.assertIsInstance(vertexes, list)
        self.assertEqual(len(vertexes), 4)
        for vertex in vertexes:
            self.assertIsInstance(vertex, tuple)
            self.assertEqual(len(vertex), 2)
            # Every vertex must lie inside the image bounds.
            self.assertTrue(0 <= vertex[0] < IMAGE_WIDTH)
            self.assertTrue(0 <= vertex[1] < IMAGE_HEIGHT)

    def test_draw_strokes(self):
        mask = Image.new("1", (IMAGE_WIDTH, IMAGE_HEIGHT), 0)
        vertexes = [(10, 10), (20, 20), (30, 30)]
        # draw_strokes mutates the PIL image in place.
        draw_strokes(mask, vertexes, width=2)
        self.assertIsInstance(mask, Image.Image)

    def test_generate_vertexes_and_draw_strokes(self):
        mask = Image.new("1", (IMAGE_WIDTH, IMAGE_HEIGHT), 0)

        vertexes = generate_vertexes(
            mask, num_vertexes=3, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT
        )
        draw_strokes(mask, vertexes, width=2)
        self.assertIsInstance(mask, Image.Image)

    def test_random_mask(self):
        random_mask = RandomInpaintingMask()
        inpainting_mask = random_mask({"x": self.batch_images})["mask"]
        assert inpainting_mask.shape == (BATCH, 1, IMAGE_HEIGHT, IMAGE_WIDTH)
        # Masks must be strictly binary.
        assert torch.all(
            torch.logical_or(inpainting_mask == 0.0, inpainting_mask == 1.0)
        )
Loading

0 comments on commit 1bbb31a

Please sign in to comment.