ADMStack and ADMResBlock changes (facebookresearch#489)
Summary:
Pull Request resolved: facebookresearch#489

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: pbontrager

Differential Revision: D50288363

Pulled By: ebsmothers

fbshipit-source-id: f63ec4fe157987946f03d302845fff2829deb5d1
ebsmothers authored and facebook-github-bot committed Oct 17, 2023
1 parent 367130e commit 9d4c8e7
Showing 3 changed files with 113 additions and 68 deletions.
11 changes: 7 additions & 4 deletions tests/diffusion_labs/test_adm.py
@@ -220,10 +220,13 @@ class TestADMStack:
def model(self, params):
in_dim, _, time_dim = params
stack = ADMStack()
stack.append(ADMResBlock(in_dim, in_dim, time_dim, norm_groups=in_dim))
stack.append(ADMAttentionBlock(in_dim, time_dim, norm_groups=in_dim))
# To use the else statement in ADMStack
stack.append(nn.Identity())
stack.append_residual_block(
ADMResBlock(in_dim, in_dim, time_dim, norm_groups=in_dim)
)
stack.append_attention_block(
ADMAttentionBlock(in_dim, time_dim, norm_groups=in_dim)
)
stack.append_simple_block(nn.Identity())
return stack

def test_forward(self, model, x, t, c):
156 changes: 97 additions & 59 deletions torchmultimodal/diffusion_labs/models/adm_unet/adm.py
@@ -4,19 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, List, Optional, Tuple
from enum import Enum
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import nn, Tensor
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
adm_attn_block,
ADMAttentionBlock,
)
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import (
adm_res_block,
adm_res_downsample_block,
adm_res_upsample_block,
ADMResBlock,
)
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput
from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm
@@ -117,6 +116,7 @@ def __init__(
variance_value_transform: Optional[Callable] = None,
):
super().__init__()
torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
if timestep_encoder is None:
assert (
time_embed_dim is not None
@@ -198,16 +198,15 @@ def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
# Keep track of output channels of every block for thru connections to decoder
down_channels = []
# Use ADMStack for conv layer so we can pass in conditional inputs and ignore them
init_conv: nn.ModuleList = ADMStack(
[
nn.Conv2d(
self.in_channels,
self.channels_per_layer[0],
kernel_size=3,
stride=1,
padding=1,
)
]
init_conv = ADMStack()
init_conv.append_simple_block(
nn.Conv2d(
self.in_channels,
self.channels_per_layer[0],
kernel_size=3,
stride=1,
padding=1,
)
)
down_channels.append(self.channels_per_layer[0])

@@ -251,7 +250,7 @@ def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
net = nn.ModuleList([init_conv] + stacks)
return net, down_channels

def _create_bottleneck(self, num_channels: int) -> nn.ModuleList:
def _create_bottleneck(self, num_channels: int) -> nn.Module:
in_resblock = adm_res_block(
in_channels=num_channels,
out_channels=num_channels,
@@ -265,7 +264,11 @@ def _create_bottleneck(self, num_channels: int) -> nn.ModuleList:
out_channels=num_channels,
dim_cond=self.dim_res_cond,
)
return ADMStack([in_resblock, mid_attention, out_resblock])
adm_stack = ADMStack()
adm_stack.append_residual_block(in_resblock)
adm_stack.append_attention_block(mid_attention)
adm_stack.append_residual_block(out_resblock)
return adm_stack

def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
# reverse so it's easier to iterate when going up the decoder
@@ -303,15 +306,16 @@ def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
layer_in_channels = layer_out_channels
# Now create the down/upsampling res block
if layer_num < self.num_resize:
stacks[-1].append(
stacks[-1].append_residual_block(
adm_res_upsample_block(
num_channels=layer_out_channels,
dim_cond=self.dim_res_cond,
)
)

out_conv = ADMStack(
[
out_conv = ADMStack()
out_conv.append_simple_block(
nn.Sequential(
Fp32GroupNorm(32, up_channels_per_layer[-1]),
nn.SiLU(),
nn.Conv2d(
@@ -321,7 +325,7 @@ def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
stride=1,
padding=1,
),
]
)
)

net = nn.ModuleList(stacks + [out_conv])
@@ -415,70 +419,104 @@ def forward(
)


class ADMStack(nn.ModuleList):
"""A container that acts as a ModuleList of ADM blocks and handles passing conditional inputs
correctly to the children ADMResBlocks and ADMAttentionBlocks.
class ADMStackModuleType(Enum):
ResidualBlock = 0
AttentionBlock = 1
SimpleBlock = 2


class ADMStack(nn.Module):
"""A container that acts like a ModuleList of ADM blocks and handles passing timestep and
context embeddings correctly to its children. Usually, blocks such as residual blocks consume
timestep embeddings, while attention blocks consume optional contextual embeddings in addition
to the input x. This container allows us to wrap the modules so that they can be stacked in a
`nn.Sequential`, in order to simplify the code for the `forward` method.
We have to implement the stack in this way rather than inheriting from `nn.ModuleList` to
avoid FSDP/Activation Checkpointing/PT2 incompatibility issues.
Code ref: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py#L35
"""

def __init__(self) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
self._module_list = nn.ModuleList()
self._module_types: List[ADMStackModuleType] = []

def append_attention_block(self, module: nn.Module) -> None:
self._module_list.append(module)
self._module_types.append(ADMStackModuleType.AttentionBlock)

def append_residual_block(self, module: nn.Module) -> None:
self._module_list.append(module)
self._module_types.append(ADMStackModuleType.ResidualBlock)

def append_simple_block(self, module: nn.Module) -> None:
self._module_list.append(module)
self._module_types.append(ADMStackModuleType.SimpleBlock)

def forward(
self,
x: Tensor,
res_conditional_embedding: Tensor,
attn_conditional_embedding: Tensor,
residual_conditional_embedding: Tensor,
attention_conditional_embedding: Optional[Union[Tensor, Sequence[Tensor]]],
) -> Tensor:
h = x
for block in self:
if isinstance(block, ADMResBlock):
h = block(h, res_conditional_embedding)
elif isinstance(block, ADMAttentionBlock):
h = block(h, attn_conditional_embedding)
for name, block in zip(self._module_types, self._module_list): # noqa: B905
if name == ADMStackModuleType.ResidualBlock:
h = block(h, residual_conditional_embedding)
elif name == ADMStackModuleType.AttentionBlock:
h = block(h, attention_conditional_embedding)
else:
h = block(h)
return h
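
To illustrate the reworked ADMStack API described in the docstring above, here is a minimal usage sketch. It is an illustration only: the constructor signatures mirror the updated test fixture, while the tensor shapes and the choice to pass None for the attention conditioning are assumptions.

import torch
from torch import nn
from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADMStack
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ADMAttentionBlock
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock

in_dim, time_dim = 32, 64
stack = ADMStack()
stack.append_residual_block(ADMResBlock(in_dim, in_dim, time_dim, norm_groups=in_dim))
stack.append_attention_block(ADMAttentionBlock(in_dim, time_dim, norm_groups=in_dim))
stack.append_simple_block(nn.Identity())

x = torch.randn(2, in_dim, 8, 8)   # feature map of shape [b, c, h, w] (illustrative)
t_emb = torch.randn(2, time_dim)   # timestep embedding routed to the residual block
out = stack(x, t_emb, None)        # attention conditioning is Optional, so None is passed here

Note that dispatch happens on the recorded ADMStackModuleType rather than on isinstance checks, as shown in the forward method above.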


def adm_stack_res(in_channels: int, out_channels: int, dim_cond: int) -> nn.ModuleList:
return ADMStack(
[
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_cond,
)
]
def adm_stack_res(in_channels: int, out_channels: int, dim_cond: int) -> nn.Module:
adm_stack = ADMStack()
adm_stack.append_residual_block(
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_cond,
)
)
return adm_stack


def adm_stack_res_attn(
in_channels: int,
out_channels: int,
dim_res_cond: int,
dim_attn_cond: Optional[int] = None,
) -> nn.ModuleList:
return ADMStack(
[
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_res_cond,
),
adm_attn_block(
num_channels=out_channels,
dim_cond=dim_attn_cond,
),
]
) -> nn.Module:
adm_stack = ADMStack()
adm_stack.append_residual_block(
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_res_cond,
)
)
adm_stack.append_attention_block(
adm_attn_block(
num_channels=out_channels,
dim_cond=dim_attn_cond,
)
)
return adm_stack


def adm_stack_res_down(num_channels: int, dim_cond: int) -> nn.ModuleList:
return ADMStack(
[
adm_res_downsample_block(
num_channels=num_channels,
dim_cond=dim_cond,
)
]
def adm_stack_res_down(num_channels: int, dim_cond: int) -> nn.Module:
adm_stack = ADMStack()
adm_stack.append_residual_block(
adm_res_downsample_block(
num_channels=num_channels,
dim_cond=dim_cond,
)
)
return adm_stack


def adm_unet(
14 changes: 9 additions & 5 deletions torchmultimodal/diffusion_labs/models/adm_unet/res_block.py
@@ -39,6 +39,7 @@ class ADMResBlock(nn.Module):
Defaults to True.
pre_outconv_dropout (float): dropout probability before the second conv. Defaults to 0.1.
norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32.
norm_eps (float): Epsilon used in the GroupNorm layer. Defaults to 1e-5.
Args:
x (Tensor): input Tensor of shape [b, c, h, w]
@@ -62,6 +63,7 @@ def __init__(
scale_shift_conditional: bool = True,
pre_outconv_dropout: float = 0.1,
norm_groups: int = 32,
norm_eps: float = 1e-5,
):
super().__init__()
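
As a hedged illustration of the new norm_eps argument added above: the positional arguments follow the updated test fixture (in_channels, out_channels, then the conditioning dimension), and the concrete values are arbitrary rather than taken from this PR.

from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock

# norm_groups and norm_eps keep their documented defaults (32 and 1e-5) unless overridden.
block = ADMResBlock(64, 64, 512, norm_groups=32, norm_eps=1e-5)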

@@ -93,13 +95,15 @@ def __init__(
nn.Linear(dim_cond, cond_channels),
)
self.in_block = nn.Sequential(
Fp32GroupNorm(norm_groups, in_channels), # groups = 32 from code ref
Fp32GroupNorm(
norm_groups, in_channels, eps=norm_eps
), # groups = 32 from code ref
activation,
hidden_updownsample_layer,
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
)
self.out_group_norm = Fp32GroupNorm(norm_groups, out_channels, eps=norm_eps)
self.out_block = nn.Sequential(
Fp32GroupNorm(norm_groups, out_channels),
activation,
nn.Dropout(pre_outconv_dropout),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
@@ -126,12 +130,12 @@ def forward(
# Use half to multiply with hidden state and half to add.
# This is typically done after normalization.
if self.scale_shift_conditional:
h = self.out_block[0](h)
h = self.out_group_norm(h)
scale, shift = torch.chunk(t, 2, dim=1)
h = h * (1 + scale) + shift
h = self.out_block[1:](h)
h = self.out_block(h)
else:
h = self.out_block(h + t)
h = self.out_block(self.out_group_norm(h + t))
if self.rescale_skip_connection:
h = (skip + h) / 1.414
else:
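
A minimal sketch of the scale-shift modulation that the refactored forward applies between out_group_norm and the rest of out_block. The shapes are hypothetical; in the real block the conditioning tensor comes from the cond_block projection.

import torch

b, c = 2, 8
h = torch.randn(b, c, 4, 4)              # group-normalized hidden state, i.e. out_group_norm(h)
t = torch.randn(b, 2 * c, 1, 1)          # conditioning projected to twice the channel count
scale, shift = torch.chunk(t, 2, dim=1)  # one half multiplies, the other half adds
h = h * (1 + scale) + shift              # matches the scale_shift_conditional branch above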
