diff --git a/tests/diffusion_labs/test_adm.py b/tests/diffusion_labs/test_adm.py
index 2f7bc1f7..c1ad55a3 100644
--- a/tests/diffusion_labs/test_adm.py
+++ b/tests/diffusion_labs/test_adm.py
@@ -220,10 +220,13 @@ class TestADMStack:
     def model(self, params):
         in_dim, _, time_dim = params
         stack = ADMStack()
-        stack.append(ADMResBlock(in_dim, in_dim, time_dim, norm_groups=in_dim))
-        stack.append(ADMAttentionBlock(in_dim, time_dim, norm_groups=in_dim))
-        # To use the else statement in ADMStack
-        stack.append(nn.Identity())
+        stack.append_residual_block(
+            ADMResBlock(in_dim, in_dim, time_dim, norm_groups=in_dim)
+        )
+        stack.append_attention_block(
+            ADMAttentionBlock(in_dim, time_dim, norm_groups=in_dim)
+        )
+        stack.append_simple_block(nn.Identity())
         return stack
 
     def test_forward(self, model, x, t, c):
diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py
index 0b45893c..7d179378 100644
--- a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py
+++ b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py
@@ -4,19 +4,18 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Dict, List, Optional, Tuple
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn, Tensor
 from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
     adm_attn_block,
-    ADMAttentionBlock,
 )
 from torchmultimodal.diffusion_labs.models.adm_unet.res_block import (
     adm_res_block,
     adm_res_downsample_block,
     adm_res_upsample_block,
-    ADMResBlock,
 )
 from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput
 from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm
@@ -117,6 +116,7 @@ def __init__(
         variance_value_transform: Optional[Callable] = None,
     ):
         super().__init__()
+        torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
         if timestep_encoder is None:
             assert (
                 time_embed_dim is not None
@@ -198,16 +198,15 @@ def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
         # Keep track of output channels of every block for thru connections to decoder
         down_channels = []
         # Use ADMStack for conv layer so we can pass in conditional inputs and ignore them
-        init_conv: nn.ModuleList = ADMStack(
-            [
-                nn.Conv2d(
-                    self.in_channels,
-                    self.channels_per_layer[0],
-                    kernel_size=3,
-                    stride=1,
-                    padding=1,
-                )
-            ]
+        init_conv = ADMStack()
+        init_conv.append_simple_block(
+            nn.Conv2d(
+                self.in_channels,
+                self.channels_per_layer[0],
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
         )
 
         down_channels.append(self.channels_per_layer[0])
@@ -251,7 +250,7 @@ def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
         net = nn.ModuleList([init_conv] + stacks)
         return net, down_channels
 
-    def _create_bottleneck(self, num_channels: int) -> nn.ModuleList:
+    def _create_bottleneck(self, num_channels: int) -> nn.Module:
         in_resblock = adm_res_block(
             in_channels=num_channels,
             out_channels=num_channels,
@@ -265,7 +264,11 @@ def _create_bottleneck(self, num_channels: int) -> nn.ModuleList:
             out_channels=num_channels,
             dim_cond=self.dim_res_cond,
         )
-        return ADMStack([in_resblock, mid_attention, out_resblock])
+        adm_stack = ADMStack()
+        adm_stack.append_residual_block(in_resblock)
+        adm_stack.append_attention_block(mid_attention)
+        adm_stack.append_residual_block(out_resblock)
+        return adm_stack
 
     def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
         # reverse so it's easier to iterate when going up the decoder
@@ -303,15 +306,16 @@ def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
                 layer_in_channels = layer_out_channels
             # Now create the down/upsampling res block
            if layer_num < self.num_resize:
-                stacks[-1].append(
+                stacks[-1].append_residual_block(
                     adm_res_upsample_block(
                         num_channels=layer_out_channels,
                         dim_cond=self.dim_res_cond,
                     )
                 )
 
-        out_conv = ADMStack(
-            [
+        out_conv = ADMStack()
+        out_conv.append_simple_block(
+            nn.Sequential(
                 Fp32GroupNorm(32, up_channels_per_layer[-1]),
                 nn.SiLU(),
                 nn.Conv2d(
@@ -321,7 +325,7 @@ def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
                     stride=1,
                     padding=1,
                 ),
-            ]
+            )
         )
 
         net = nn.ModuleList(stacks + [out_conv])
@@ -415,38 +419,70 @@ def forward(
         )
 
 
-class ADMStack(nn.ModuleList):
-    """A container that acts as a ModuleList of ADM blocks and handles passing conditional inputs
-    correctly to the children ADMResBlocks and ADMAttentionBlocks.
+class ADMStackModuleType(Enum):
+    ResidualBlock = 0
+    AttentionBlock = 1
+    SimpleBlock = 2
+
+
+class ADMStack(nn.Module):
+    """A container that acts like a ModuleList of ADM blocks and handles passing timestep and
+    context embeddings correctly to its children. Usually, blocks such as residual blocks consume
+    timestep embeddings, while attention blocks consume optional contextual embeddings in addition
+    to the input x. This container allows us to wrap the modules so that they can be stacked in a
+    `nn.Sequential`, in order to simplify the code for the `forward` method.
+
+    We have to implement the stack in this way rather than inheriting from `nn.ModuleList` to
+    avoid FSDP/Activation Checkpointing/PT2 incompatibility issues.
+
+    Code ref: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py#L35
     """
 
+    def __init__(self) -> None:
+        super().__init__()
+        torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
+        self._module_list = nn.ModuleList()
+        self._module_types: List[ADMStackModuleType] = []
+
+    def append_attention_block(self, module: nn.Module) -> None:
+        self._module_list.append(module)
+        self._module_types.append(ADMStackModuleType.AttentionBlock)
+
+    def append_residual_block(self, module: nn.Module) -> None:
+        self._module_list.append(module)
+        self._module_types.append(ADMStackModuleType.ResidualBlock)
+
+    def append_simple_block(self, module: nn.Module) -> None:
+        self._module_list.append(module)
+        self._module_types.append(ADMStackModuleType.SimpleBlock)
+
     def forward(
         self,
         x: Tensor,
-        res_conditional_embedding: Tensor,
-        attn_conditional_embedding: Tensor,
+        residual_conditional_embedding: Tensor,
+        attention_conditional_embedding: Optional[Union[Tensor, Sequence[Tensor]]],
     ) -> Tensor:
         h = x
-        for block in self:
-            if isinstance(block, ADMResBlock):
-                h = block(h, res_conditional_embedding)
-            elif isinstance(block, ADMAttentionBlock):
-                h = block(h, attn_conditional_embedding)
+        for name, block in zip(self._module_types, self._module_list):  # noqa: B905
+            if name == ADMStackModuleType.ResidualBlock:
+                h = block(h, residual_conditional_embedding)
+            elif name == ADMStackModuleType.AttentionBlock:
+                h = block(h, attention_conditional_embedding)
             else:
                 h = block(h)
         return h
 
 
-def adm_stack_res(in_channels: int, out_channels: int, dim_cond: int) -> nn.ModuleList:
-    return ADMStack(
-        [
-            adm_res_block(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                dim_cond=dim_cond,
-            )
-        ]
+def adm_stack_res(in_channels: int, out_channels: int, dim_cond: int) -> nn.Module:
+    adm_stack = ADMStack()
+    adm_stack.append_residual_block(
+        adm_res_block(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            dim_cond=dim_cond,
+        )
     )
+    return adm_stack
 
 
 def adm_stack_res_attn(
@@ -454,31 +490,33 @@ def adm_stack_res_attn(
     out_channels: int,
     dim_res_cond: int,
     dim_attn_cond: Optional[int] = None,
-) -> nn.ModuleList:
-    return ADMStack(
-        [
-            adm_res_block(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                dim_cond=dim_res_cond,
-            ),
-            adm_attn_block(
-                num_channels=out_channels,
-                dim_cond=dim_attn_cond,
-            ),
-        ]
+) -> nn.Module:
+    adm_stack = ADMStack()
+    adm_stack.append_residual_block(
+        adm_res_block(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            dim_cond=dim_res_cond,
+        )
+    )
+    adm_stack.append_attention_block(
+        adm_attn_block(
+            num_channels=out_channels,
+            dim_cond=dim_attn_cond,
+        )
     )
+    return adm_stack
 
 
-def adm_stack_res_down(num_channels: int, dim_cond: int) -> nn.ModuleList:
-    return ADMStack(
-        [
-            adm_res_downsample_block(
-                num_channels=num_channels,
-                dim_cond=dim_cond,
-            )
-        ]
+def adm_stack_res_down(num_channels: int, dim_cond: int) -> nn.Module:
+    adm_stack = ADMStack()
+    adm_stack.append_residual_block(
+        adm_res_downsample_block(
+            num_channels=num_channels,
+            dim_cond=dim_cond,
+        )
     )
+    return adm_stack
 
 
 def adm_unet(
diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py b/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py
index f97426ac..cc1c4d32 100644
--- a/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py
+++ b/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py
@@ -39,6 +39,7 @@ class ADMResBlock(nn.Module):
             Defaults to True.
         pre_outconv_dropout (float): dropout probability before the second conv. Defaults to 0.1.
         norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32.
+        norm_eps (float): epsilon used in the GroupNorm layer. Defaults to 1e-5.
 
     Args:
         x (Tensor): input Tensor of shape [b, c, h, w]
@@ -62,6 +63,7 @@ def __init__(
         scale_shift_conditional: bool = True,
         pre_outconv_dropout: float = 0.1,
         norm_groups: int = 32,
+        norm_eps: float = 1e-5,
     ):
         super().__init__()
 
@@ -93,13 +95,15 @@ def __init__(
             nn.Linear(dim_cond, cond_channels),
         )
         self.in_block = nn.Sequential(
-            Fp32GroupNorm(norm_groups, in_channels),  # groups = 32 from code ref
+            Fp32GroupNorm(
+                norm_groups, in_channels, eps=norm_eps
+            ),  # groups = 32 from code ref
             activation,
             hidden_updownsample_layer,
             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
         )
+        self.out_group_norm = Fp32GroupNorm(norm_groups, out_channels, eps=norm_eps)
         self.out_block = nn.Sequential(
-            Fp32GroupNorm(norm_groups, out_channels),
             activation,
             nn.Dropout(pre_outconv_dropout),
             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
@@ -126,12 +130,12 @@ def forward(
         # Use half to multiply with hidden state and half to add.
         # This is typically done after normalization.
         if self.scale_shift_conditional:
-            h = self.out_block[0](h)
+            h = self.out_group_norm(h)
             scale, shift = torch.chunk(t, 2, dim=1)
             h = h * (1 + scale) + shift
-            h = self.out_block[1:](h)
+            h = self.out_block(h)
         else:
-            h = self.out_block(h + t)
+            h = self.out_block(self.out_group_norm(h + t))
         if self.rescale_skip_connection:
             h = (skip + h) / 1.414
         else:
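
Usage sketch (not part of the diff): the snippet below mirrors the updated TestADMStack fixture to show how the refactored ADMStack is populated and called. The concrete channel and embedding sizes are illustrative assumptions, and passing None as the attention conditional embedding assumes ADMAttentionBlock treats its context as optional, as the new ADMStack docstring and the Optional type in the new forward signature suggest.

import torch
from torch import nn

from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADMStack
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ADMAttentionBlock
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock

channels, cond_dim = 32, 128  # illustrative sizes, not taken from the diff

stack = ADMStack()
# Residual blocks are dispatched the timestep (residual) conditional embedding.
stack.append_residual_block(ADMResBlock(channels, channels, cond_dim, norm_groups=channels))
# Attention blocks are dispatched the optional context (attention) conditional embedding.
stack.append_attention_block(ADMAttentionBlock(channels, cond_dim, norm_groups=channels))
# Simple blocks receive only x; both embeddings are ignored.
stack.append_simple_block(nn.Identity())

x = torch.randn(2, channels, 8, 8)  # [b, c, h, w]
t_emb = torch.randn(2, cond_dim)    # residual_conditional_embedding
out = stack(x, t_emb, None)         # attention embedding assumed optional here
print(out.shape)                    # expected: torch.Size([2, 32, 8, 8])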