Update on "ADMStack and ADMResBlock changes"
Differential Revision: [D50288363](https://our.internmc.facebook.com/intern/diff/D50288363)

[ghstack-poisoned]
ebsmothers committed Oct 13, 2023
Commit 556c53d (1 parent: be428bd)
Showing 1 changed file with 8 additions and 8 deletions:
torchmultimodal/diffusion_labs/models/adm_unet/adm.py
@@ -192,7 +192,7 @@ def _create_attn_cond_proj(
             }
         )
 
-    def _create_downsampling_encoder(self) -> Tuple[nn.Module, List]:
+    def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
         # Keep track of output channels of every block for thru connections to decoder
         down_channels = []
         # Use ADMStack for conv layer so we can pass in conditional inputs and ignore them
@@ -268,7 +268,7 @@ def _create_bottleneck(self, num_channels: int) -> nn.Module:
         adm_stack.append_residual_block(out_resblock)
         return adm_stack
 
-    def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.Module:
+    def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
         # reverse so it's easier to iterate when going up the decoder
         up_channels_per_layer = list(reversed(self.channels_per_layer))
         up_attention_for_layer = list(reversed(self.use_attention_for_layer))
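For context on the two builder methods above: the tightened return annotations document that these helpers return an nn.ModuleList (so the appended blocks are registered as submodules and can be iterated directly) together with the per-block output channel counts that wire the encoder's skip connections into the decoder. The following is a minimal, hypothetical sketch of that shape, with a placeholder conv block standing in for an ADMStack; it is an illustration of the pattern, not the actual torchmultimodal implementation.

from typing import List, Tuple

import torch.nn as nn


def create_downsampling_encoder_sketch(
    in_channels: int, channels_per_layer: List[int]
) -> Tuple[nn.ModuleList, List[int]]:
    # Collect one block per layer and record each block's output channels,
    # which the decoder later consumes for its skip connections.
    blocks = nn.ModuleList()
    down_channels: List[int] = []
    prev = in_channels
    for ch in channels_per_layer:
        # Placeholder for an ADMStack in the real model.
        blocks.append(nn.Conv2d(prev, ch, kernel_size=3, padding=1))
        down_channels.append(ch)
        prev = ch
    return blocks, down_channels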
@@ -430,21 +430,21 @@ class ADMStack(nn.Module):
     Code ref: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py#L35
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
         self._module_list = nn.ModuleList()
-        self._module_types = []
+        self._module_types: List[ADMStackModuleType] = []
 
-    def append_attention_block(self, module: nn.Module):
+    def append_attention_block(self, module: nn.Module) -> None:
         self._module_list.append(module)
         self._module_types.append(ADMStackModuleType.AttentionBlock)
 
-    def append_residual_block(self, module: nn.Module):
+    def append_residual_block(self, module: nn.Module) -> None:
         self._module_list.append(module)
         self._module_types.append(ADMStackModuleType.ResidualBlock)
 
-    def append_simple_block(self, module: nn.Module):
+    def append_simple_block(self, module: nn.Module) -> None:
         self._module_list.append(module)
         self._module_types.append(ADMStackModuleType.SimpleBlock)

@@ -453,7 +453,7 @@ def forward(
         x: Tensor,
         residual_conditional_embedding: Tensor,
         attention_conditional_embedding: Optional[Union[Tensor, Sequence[Tensor]]],
-    ):
+    ) -> Tensor:
         h = x
         for name, block in zip(self._module_types, self._module_list): # noqa: B905
             if name == ADMStackModuleType.ResidualBlock:
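To make the ADMStack changes easier to read in isolation, here is a self-contained sketch of the pattern the annotated methods describe: blocks are appended together with a type tag, and forward routes the residual or attention conditional embedding only to blocks of the matching type before returning the resulting Tensor, which is what the new -> Tensor annotation documents. The dispatch for attention and simple blocks is folded below the visible hunk, so those branches are an assumption here; this is an illustration under that assumption, not the torchmultimodal source.

from enum import Enum
from typing import List, Optional, Sequence, Union

from torch import Tensor, nn


class ADMStackModuleType(Enum):
    ResidualBlock = "residual"
    AttentionBlock = "attention"
    SimpleBlock = "simple"


class ADMStackSketch(nn.Module):
    """Simplified stand-in for ADMStack: tracks each block's type so forward
    can pass every block only the conditional input it expects."""

    def __init__(self) -> None:
        super().__init__()
        self._module_list = nn.ModuleList()
        self._module_types: List[ADMStackModuleType] = []

    def append_residual_block(self, module: nn.Module) -> None:
        self._module_list.append(module)
        self._module_types.append(ADMStackModuleType.ResidualBlock)

    def append_attention_block(self, module: nn.Module) -> None:
        self._module_list.append(module)
        self._module_types.append(ADMStackModuleType.AttentionBlock)

    def append_simple_block(self, module: nn.Module) -> None:
        self._module_list.append(module)
        self._module_types.append(ADMStackModuleType.SimpleBlock)

    def forward(
        self,
        x: Tensor,
        residual_conditional_embedding: Tensor,
        attention_conditional_embedding: Optional[Union[Tensor, Sequence[Tensor]]],
    ) -> Tensor:
        h = x
        for name, block in zip(self._module_types, self._module_list):
            if name == ADMStackModuleType.ResidualBlock:
                # Residual blocks receive the timestep/class embedding.
                h = block(h, residual_conditional_embedding)
            elif name == ADMStackModuleType.AttentionBlock:
                # Attention blocks receive the cross-attention conditioning (assumed branch).
                h = block(h, attention_conditional_embedding)
            else:
                # Simple blocks (e.g. plain convs) take no conditioning (assumed branch).
                h = block(h)
        return h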
