Add LDM VAE and its components to diffusion_labs (facebookresearch#493)
Summary:
Pull Request resolved: facebookresearch#493

This diff moves the LDM VAE and its components to diffusion_labs. Summary of changes:
1. Copy `ldm/autoencoder` files to `diffusion_labs/models/vae`.
2. Add `sampling_layers`.
3. Add `residual_block`.
4. Add tests.
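
For orientation, here is a minimal usage sketch of the new builder. Everything in it (the import path, the builder signature, and the shape arithmetic) is inferred from the tests added in this diff rather than from separate documentation:

```python
import torch

from torchmultimodal.diffusion_labs.models.vae.vae import ldm_variational_autoencoder

# Builder arguments mirror the fixtures in tests/diffusion_labs/test_vae.py.
vae = ldm_variational_autoencoder(
    embedding_channels=6,
    in_channels=2,
    out_channels=5,
    z_channels=3,
    channels=4,
    num_res_blocks=2,
    channel_multipliers=(1, 2, 4),
    norm_groups=2,
    norm_eps=1e-5,
)

x = torch.randn(2, 2, 16, 16)  # (batch, in_channels, height, width)

# encode() returns a posterior distribution. Height and width shrink by
# 2 ** (len(channel_multipliers) - 1) == 4, so the latent here is 4x4.
posterior = vae.encode(x)
z = posterior.rsample()  # (2, 6, 4, 4)

# decode() maps latents back to image space: (2, 5, 16, 16).
x_hat = vae.decode(z)

# The full forward pass chains encode and decode; with
# sample_posterior=False the posterior is not sampled.
out = vae(x, sample_posterior=False).decoder_output  # (2, 5, 16, 16)
```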

Reviewed By: pbontrager

Differential Revision: D50353769

fbshipit-source-id: b6b283aaa8e799ce04ce4d7981cc1e233b3eb02c
Abhinav Arora authored and facebook-github-bot committed Oct 18, 2023
1 parent 2ddb8cd commit cb7eb4c
Showing 10 changed files with 1,451 additions and 0 deletions.
167 changes: 167 additions & 0 deletions tests/diffusion_labs/test_vae.py
@@ -0,0 +1,167 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
import torch.distributions as tdist
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.diffusion_labs.models.vae.vae import ldm_variational_autoencoder


@pytest.fixture(autouse=True)
def set_seed():
    set_rng_seed(98765)


@pytest.fixture
def embedding_channels():
    return 6


@pytest.fixture
def in_channels():
    return 2


@pytest.fixture
def out_channels():
    return 5


@pytest.fixture
def z_channels():
    return 3


@pytest.fixture
def channels():
    return 4


@pytest.fixture
def num_res_blocks():
    return 2


@pytest.fixture
def channel_multipliers():
    return (1, 2, 4)


@pytest.fixture
def norm_groups():
    return 2


@pytest.fixture
def norm_eps():
    return 1e-05


@pytest.fixture
def x(in_channels):
    bsize = 2
    height = 16
    width = 16
    return torch.randn(bsize, in_channels, height, width)


@pytest.fixture
def z(embedding_channels):
    bsize = 2
    height = 4
    width = 4
    return torch.randn(bsize, embedding_channels, height, width)


class TestVariationalAutoencoder:
    @pytest.fixture
    def vae(
        self,
        in_channels,
        out_channels,
        embedding_channels,
        z_channels,
        channels,
        norm_groups,
        norm_eps,
        channel_multipliers,
        num_res_blocks,
    ):
        return ldm_variational_autoencoder(
            embedding_channels=embedding_channels,
            in_channels=in_channels,
            out_channels=out_channels,
            z_channels=z_channels,
            channels=channels,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            norm_groups=norm_groups,
            norm_eps=norm_eps,
        )

    def test_encode(self, vae, x, embedding_channels, channel_multipliers):
        posterior = vae.encode(x)
        # Encoder downsamples height and width by 2 ** (len(channel_multipliers) - 1).
        expected_shape = torch.Size(
            [
                x.size(0),
                embedding_channels,
                x.size(2) // 2 ** (len(channel_multipliers) - 1),
                x.size(3) // 2 ** (len(channel_multipliers) - 1),
            ]
        )
        expected_mean = torch.tensor(-3.4872)
        assert_expected(posterior.mean.size(), expected_shape)
        assert_expected(posterior.mean.sum(), expected_mean, rtol=0, atol=1e-4)

        expected_stddev = torch.tensor(193.3726)
        assert_expected(posterior.stddev.size(), expected_shape)
        assert_expected(posterior.stddev.sum(), expected_stddev, rtol=0, atol=1e-4)

        # Compute KL divergence with a standard Gaussian.
        expected_kl = torch.tensor(9.8025)
        torch_kl_divergence = tdist.kl_divergence(
            posterior,
            tdist.Normal(
                torch.zeros_like(posterior.mean), torch.ones_like(posterior.stddev)
            ),
        ).sum()
        assert_expected(torch_kl_divergence, expected_kl, rtol=0, atol=1e-4)

        # Compare sample shape.
        assert_expected(posterior.rsample().size(), expected_shape)

    def test_decode(self, vae, z, out_channels, channel_multipliers):
        actual = vae.decode(z)
        expected = torch.tensor(-156.1534)
        # Decoder upsamples height and width by the same factor the encoder shrinks them.
        expected_shape = torch.Size(
            [
                z.size(0),
                out_channels,
                z.size(2) * 2 ** (len(channel_multipliers) - 1),
                z.size(3) * 2 ** (len(channel_multipliers) - 1),
            ]
        )
        assert_expected(actual.size(), expected_shape)
        assert_expected(actual.sum(), expected, rtol=0, atol=1e-4)

    @pytest.mark.parametrize(
        "sample_posterior,expected_value", [(True, -153.6517), (False, -178.8593)]
    )
    def test_forward(self, vae, x, out_channels, sample_posterior, expected_value):
        actual = vae(x, sample_posterior=sample_posterior).decoder_output
        expected = torch.tensor(expected_value)
        # Output shape matches the input spatially whether or not the posterior is sampled.
        expected_shape = torch.Size(
            [
                x.size(0),
                out_channels,
                x.size(2),
                x.size(3),
            ]
        )
        assert_expected(actual.size(), expected_shape)
        assert_expected(actual.sum(), expected, rtol=0, atol=1e-4)
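
Aside: the KL assertion in `test_encode` compares the encoder's posterior against a standard normal via `torch.distributions.kl_divergence`. For a diagonal Gaussian that call reduces to the familiar closed form, which the following sketch (illustrative, not part of this commit) verifies:

```python
import torch
import torch.distributions as tdist

mean = torch.randn(2, 6, 4, 4)
std = torch.rand(2, 6, 4, 4) + 0.1  # keep stddev strictly positive

posterior = tdist.Normal(mean, std)
standard = tdist.Normal(torch.zeros_like(mean), torch.ones_like(std))

# KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2),
# summed over every latent dimension.
closed_form = 0.5 * (mean**2 + std**2 - 1 - (std**2).log()).sum()

torch.testing.assert_close(
    tdist.kl_divergence(posterior, standard).sum(), closed_form
)
```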
85 changes: 85 additions & 0 deletions tests/diffusion_labs/test_vae_attention.py
@@ -0,0 +1,85 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.diffusion_labs.models.vae.attention import (
attention_res_block,
AttentionResBlock,
VanillaAttention,
)


@pytest.fixture(autouse=True)
def set_seed():
    set_rng_seed(1234)


@pytest.fixture
def channels():
    return 64


@pytest.fixture
def norm_groups():
    return 16


@pytest.fixture
def norm_eps():
    return 1e-05


@pytest.fixture
def x(channels):
    bsize = 2
    height = 16
    width = 16
    return torch.randn(bsize, channels, height, width)


class TestVanillaAttention:
    @pytest.fixture
    def attn(self, channels):
        return VanillaAttention(channels)

    def test_forward(self, x, attn):
        actual = attn(x)
        expected = torch.tensor(32.0883)
        assert_expected(actual.sum(), expected, rtol=0, atol=1e-4)
        assert_expected(actual.shape, x.shape)


class TestAttentionResBlock:
    @pytest.fixture
    def attn(self, channels, norm_groups, norm_eps):
        return AttentionResBlock(
            channels,
            attn_module=nn.Identity(),
            norm_groups=norm_groups,
            norm_eps=norm_eps,
        )

    def test_forward(self, x, attn):
        actual = attn(x)
        expected = torch.tensor(295.1067)
        assert_expected(actual.sum(), expected, rtol=0, atol=1e-4)
        assert_expected(actual.shape, x.shape)

    def test_channel_indivisible_norm_group_error(self):
        # 64 channels cannot be split evenly into 30 norm groups, so construction should fail.
        with pytest.raises(ValueError):
            _ = AttentionResBlock(64, nn.Identity(), norm_groups=30)


def test_attention_res_block(channels, x):
    attn = attention_res_block(channels)
    expected = torch.tensor(69.692)
    actual = attn(x)
    assert_expected(actual.sum(), expected, rtol=0, atol=1e-4)
    assert_expected(actual.shape, x.shape)
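
For context, `VanillaAttention` in LDM-style autoencoders is typically single-head self-attention over spatial positions. The following is a rough sketch of that pattern under stated assumptions; the actual module may differ in projection layout and normalization:

```python
import torch
import torch.nn.functional as F
from torch import nn


class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the H*W positions of a feature map.
    A sketch of the usual LDM-style layer, not the library's implementation."""

    def __init__(self, channels: int):
        super().__init__()
        # 1x1 convs act as per-position linear projections.
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Flatten the spatial grid into a sequence of length h * w.
        q = self.q(x).flatten(2).transpose(1, 2)  # (b, h*w, c)
        k = self.k(x).flatten(2).transpose(1, 2)
        v = self.v(x).flatten(2).transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)  # (b, h*w, c)
        attn = attn.transpose(1, 2).reshape(b, c, h, w)
        return self.out(attn)


attn = SpatialSelfAttention(64)
y = attn(torch.randn(2, 64, 16, 16))
assert y.shape == (2, 64, 16, 16)  # shape is preserved, as the tests expect
```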
