ADMStack and ADMResBlock changes #489

Closed · wants to merge 5 commits
11 changes: 7 additions & 4 deletions tests/diffusion_labs/test_adm.py
@@ -220,10 +220,13 @@ class TestADMStack:
def model(self, params):
in_dim, _, time_dim = params
stack = ADMStack()
stack.append(ADMResBlock(in_dim, in_dim, time_dim, norm_groups=in_dim))
stack.append(ADMAttentionBlock(in_dim, time_dim, norm_groups=in_dim))
# To use the else statement in ADMStack
stack.append(nn.Identity())
stack.append_residual_block(
ADMResBlock(in_dim, in_dim, time_dim, norm_groups=in_dim)
)
stack.append_attention_block(
ADMAttentionBlock(in_dim, time_dim, norm_groups=in_dim)
)
stack.append_simple_block(nn.Identity())
return stack

def test_forward(self, model, x, t, c):
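To run just this fixture's tests locally, a standard pytest selection along these lines should work (the exact invocation may differ depending on the repo's test tooling):

pytest tests/diffusion_labs/test_adm.py -k TestADMStack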
156 changes: 97 additions & 59 deletions torchmultimodal/diffusion_labs/models/adm_unet/adm.py
@@ -4,19 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, List, Optional, Tuple
from enum import Enum
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import nn, Tensor
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
adm_attn_block,
ADMAttentionBlock,
)
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import (
adm_res_block,
adm_res_downsample_block,
adm_res_upsample_block,
ADMResBlock,
)
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput
from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm
@@ -117,6 +116,7 @@ def __init__(
variance_value_transform: Optional[Callable] = None,
):
super().__init__()
torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
if timestep_encoder is None:
assert (
time_embed_dim is not None
@@ -198,16 +198,15 @@ def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
# Keep track of output channels of every block for thru connections to decoder
down_channels = []
# Use ADMStack for conv layer so we can pass in conditional inputs and ignore them
init_conv: nn.ModuleList = ADMStack(
[
nn.Conv2d(
self.in_channels,
self.channels_per_layer[0],
kernel_size=3,
stride=1,
padding=1,
)
]
init_conv = ADMStack()
init_conv.append_simple_block(
nn.Conv2d(
self.in_channels,
self.channels_per_layer[0],
kernel_size=3,
stride=1,
padding=1,
)
)
down_channels.append(self.channels_per_layer[0])

@@ -251,7 +250,7 @@ def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
net = nn.ModuleList([init_conv] + stacks)
return net, down_channels

def _create_bottleneck(self, num_channels: int) -> nn.ModuleList:
def _create_bottleneck(self, num_channels: int) -> nn.Module:
in_resblock = adm_res_block(
in_channels=num_channels,
out_channels=num_channels,
@@ -265,7 +264,11 @@ def _create_bottleneck(self, num_channels: int) -> nn.ModuleList:
out_channels=num_channels,
dim_cond=self.dim_res_cond,
)
return ADMStack([in_resblock, mid_attention, out_resblock])
adm_stack = ADMStack()
adm_stack.append_residual_block(in_resblock)
adm_stack.append_attention_block(mid_attention)
adm_stack.append_residual_block(out_resblock)
return adm_stack

def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
# reverse so it's easier to iterate when going up the decoder
@@ -303,15 +306,16 @@ def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
layer_in_channels = layer_out_channels
# Now create the down/upsampling res block
if layer_num < self.num_resize:
stacks[-1].append(
stacks[-1].append_residual_block(
adm_res_upsample_block(
num_channels=layer_out_channels,
dim_cond=self.dim_res_cond,
)
)

out_conv = ADMStack(
[
out_conv = ADMStack()
out_conv.append_simple_block(
nn.Sequential(
Fp32GroupNorm(32, up_channels_per_layer[-1]),
nn.SiLU(),
nn.Conv2d(
@@ -321,7 +325,7 @@ def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
stride=1,
padding=1,
),
]
)
)

net = nn.ModuleList(stacks + [out_conv])
@@ -415,70 +419,104 @@ def forward(
)


class ADMStack(nn.ModuleList):
"""A container that acts as a ModuleList of ADM blocks and handles passing conditional inputs
correctly to the children ADMResBlocks and ADMAttentionBlocks.
class ADMStackModuleType(Enum):
ResidualBlock = 0
AttentionBlock = 1
SimpleBlock = 2


class ADMStack(nn.Module):
"""A container that acts like a ModuleList of ADM blocks and handles passing timestep and
context embeddings correctly to its children. Usually, blocks such as residual blocks consume
timestep embeddings, while attention blocks consume optional contextual embeddings in addition
to the input x. This container allows us to wrap the modules so that they can be stacked in a
`nn.Sequential`, in order to simplify the code for the `forward` method.

We have to implement the stack in this way rather than inheriting from `nn.ModuleList` to
avoid FSDP/Activation Checkpointing/PT2 incompatibility issues.

Code ref: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py#L35
"""

def __init__(self) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
self._module_list = nn.ModuleList()
self._module_types: List[ADMStackModuleType] = []

def append_attention_block(self, module: nn.Module) -> None:
self._module_list.append(module)
self._module_types.append(ADMStackModuleType.AttentionBlock)

def append_residual_block(self, module: nn.Module) -> None:
self._module_list.append(module)
self._module_types.append(ADMStackModuleType.ResidualBlock)

def append_simple_block(self, module: nn.Module) -> None:
self._module_list.append(module)
self._module_types.append(ADMStackModuleType.SimpleBlock)

def forward(
self,
x: Tensor,
res_conditional_embedding: Tensor,
attn_conditional_embedding: Tensor,
residual_conditional_embedding: Tensor,
attention_conditional_embedding: Optional[Union[Tensor, Sequence[Tensor]]],
) -> Tensor:
h = x
for block in self:
if isinstance(block, ADMResBlock):
h = block(h, res_conditional_embedding)
elif isinstance(block, ADMAttentionBlock):
h = block(h, attn_conditional_embedding)
for name, block in zip(self._module_types, self._module_list): # noqa: B905
if name == ADMStackModuleType.ResidualBlock:
h = block(h, residual_conditional_embedding)
elif name == ADMStackModuleType.AttentionBlock:
h = block(h, attention_conditional_embedding)
else:
h = block(h)
return h


def adm_stack_res(in_channels: int, out_channels: int, dim_cond: int) -> nn.ModuleList:
return ADMStack(
[
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_cond,
)
]
def adm_stack_res(in_channels: int, out_channels: int, dim_cond: int) -> nn.Module:
adm_stack = ADMStack()
adm_stack.append_residual_block(
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_cond,
)
)
return adm_stack


def adm_stack_res_attn(
in_channels: int,
out_channels: int,
dim_res_cond: int,
dim_attn_cond: Optional[int] = None,
) -> nn.ModuleList:
return ADMStack(
[
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_res_cond,
),
adm_attn_block(
num_channels=out_channels,
dim_cond=dim_attn_cond,
),
]
) -> nn.Module:
adm_stack = ADMStack()
adm_stack.append_residual_block(
adm_res_block(
in_channels=in_channels,
out_channels=out_channels,
dim_cond=dim_res_cond,
)
)
adm_stack.append_attention_block(
adm_attn_block(
num_channels=out_channels,
dim_cond=dim_attn_cond,
)
)
return adm_stack


def adm_stack_res_down(num_channels: int, dim_cond: int) -> nn.ModuleList:
return ADMStack(
[
adm_res_downsample_block(
num_channels=num_channels,
dim_cond=dim_cond,
)
]
def adm_stack_res_down(num_channels: int, dim_cond: int) -> nn.Module:
adm_stack = ADMStack()
adm_stack.append_residual_block(
adm_res_downsample_block(
num_channels=num_channels,
dim_cond=dim_cond,
)
)
return adm_stack


def adm_unet(
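A compact, self-contained sketch of the routing contract the new ADMStack enforces. The stand-in blocks below are hypothetical (they replace the real ADMResBlock/ADMAttentionBlock so the shapes stay unambiguous); only ADMStack and its append_* methods come from this PR:

import torch
from torch import nn
from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADMStack

class ToyResBlock(nn.Module):
    # Stand-in for ADMResBlock: consumes the residual (timestep) embedding.
    def __init__(self, channels: int, dim_cond: int):
        super().__init__()
        self.proj = nn.Linear(dim_cond, channels)

    def forward(self, x, cond):
        return x + self.proj(cond)[:, :, None, None]

class ToyAttnBlock(nn.Module):
    # Stand-in for ADMAttentionBlock: receives the (optional) context embedding.
    def forward(self, x, context=None):
        return x

stack = ADMStack()
stack.append_residual_block(ToyResBlock(8, 16))
stack.append_attention_block(ToyAttnBlock())
stack.append_simple_block(nn.Identity())

x = torch.randn(2, 8, 4, 4)
t_emb = torch.randn(2, 16)
out = stack(x, t_emb, None)  # residual block gets t_emb, attention block gets None, Identity gets neither
print(out.shape)  # torch.Size([2, 8, 4, 4])

Because the block kind is recorded at append time, forward no longer needs isinstance checks against ADMResBlock and ADMAttentionBlock, and the container stays a plain nn.Module with an ordinary forward, which is what the docstring cites as the fix for the FSDP/activation-checkpointing/PT2 issues.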
14 changes: 9 additions & 5 deletions torchmultimodal/diffusion_labs/models/adm_unet/res_block.py
@@ -39,6 +39,7 @@ class ADMResBlock(nn.Module):
Defaults to True.
pre_outconv_dropout (float): dropout probability before the second conv. Defaults to 0.1.
norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32.
norm_eps (float): Epsilon used in the GroupNorm layer. Defaults to 1e-5.

Args:
x (Tensor): input Tensor of shape [b, c, h, w]
@@ -62,6 +63,7 @@ def __init__(
scale_shift_conditional: bool = True,
pre_outconv_dropout: float = 0.1,
norm_groups: int = 32,
norm_eps: float = 1e-5,
):
super().__init__()

@@ -93,13 +95,15 @@ def __init__(
nn.Linear(dim_cond, cond_channels),
)
self.in_block = nn.Sequential(
Fp32GroupNorm(norm_groups, in_channels), # groups = 32 from code ref
Fp32GroupNorm(
norm_groups, in_channels, eps=norm_eps
), # groups = 32 from code ref
activation,
hidden_updownsample_layer,
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
)
self.out_group_norm = Fp32GroupNorm(norm_groups, out_channels, eps=norm_eps)
self.out_block = nn.Sequential(
Fp32GroupNorm(norm_groups, out_channels),
activation,
nn.Dropout(pre_outconv_dropout),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
@@ -126,12 +130,12 @@ def forward(
# Use half to multiply with hidden state and half to add.
# This is typically done after normalization.
if self.scale_shift_conditional:
h = self.out_block[0](h)
h = self.out_group_norm(h)
scale, shift = torch.chunk(t, 2, dim=1)
h = h * (1 + scale) + shift
h = self.out_block[1:](h)
h = self.out_block(h)
else:
h = self.out_block(h + t)
h = self.out_block(self.out_group_norm(h + t))
if self.rescale_skip_connection:
h = (skip + h) / 1.414
else:
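To make the hunk above easier to follow, here is the new control flow written out with comments; t is the conditioning tensor already projected earlier in forward, and broadcasting details are elided since they are unchanged by this PR:

if self.scale_shift_conditional:
    h = self.out_group_norm(h)               # normalize first; the norm now lives outside out_block
    scale, shift = torch.chunk(t, 2, dim=1)  # conditioning carries 2 * out_channels: half scale, half shift
    h = h * (1 + scale) + shift              # FiLM-style modulation of the normalized features
    h = self.out_block(h)                    # activation -> dropout -> conv, with no group norm inside
else:
    h = self.out_block(self.out_group_norm(h + t))  # additive conditioning, then the same norm and out_block

Numerically this matches the previous behavior (which achieved the same thing by slicing self.out_block into out_block[0] and out_block[1:]); the refactor only hoists the group norm into the standalone self.out_group_norm and adds the norm_eps knob.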