diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py
index fc831cc9..3af2319d 100644
--- a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py
+++ b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py
@@ -192,7 +192,7 @@ def _create_attn_cond_proj(
             }
         )
 
-    def _create_downsampling_encoder(self) -> Tuple[nn.Module, List]:
+    def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List]:
         # Keep track of output channels of every block for thru connections to decoder
         down_channels = []
         # Use ADMStack for conv layer so we can pass in conditional inputs and ignore them
@@ -268,7 +268,7 @@ def _create_bottleneck(self, num_channels: int) -> nn.Module:
         adm_stack.append_residual_block(out_resblock)
         return adm_stack
 
-    def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.Module:
+    def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
         # reverse so it's easier to iterate when going up the decoder
         up_channels_per_layer = list(reversed(self.channels_per_layer))
         up_attention_for_layer = list(reversed(self.use_attention_for_layer))
@@ -430,21 +430,21 @@ class ADMStack(nn.Module):
     Code ref: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py#L35
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
         self._module_list = nn.ModuleList()
-        self._module_types = []
+        self._module_types: List[ADMStackModuleType] = []
 
-    def append_attention_block(self, module: nn.Module):
+    def append_attention_block(self, module: nn.Module) -> None:
         self._module_list.append(module)
         self._module_types.append(ADMStackModuleType.AttentionBlock)
 
-    def append_residual_block(self, module: nn.Module):
+    def append_residual_block(self, module: nn.Module) -> None:
         self._module_list.append(module)
         self._module_types.append(ADMStackModuleType.ResidualBlock)
 
-    def append_simple_block(self, module: nn.Module):
+    def append_simple_block(self, module: nn.Module) -> None:
         self._module_list.append(module)
         self._module_types.append(ADMStackModuleType.SimpleBlock)
 
@@ -453,7 +453,7 @@ def forward(
         x: Tensor,
         residual_conditional_embedding: Tensor,
         attention_conditional_embedding: Optional[Union[Tensor, Sequence[Tensor]]],
-    ):
+    ) -> Tensor:
         h = x
         for name, block in zip(self._module_types, self._module_list):  # noqa: B905
             if name == ADMStackModuleType.ResidualBlock:
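
Context for reviewers: the pattern these annotations describe is one nn.ModuleList of heterogeneous blocks plus a parallel list of type tags, so that forward() can route conditional embeddings only to the blocks that accept them (the same idea as TimestepEmbedSequential in the improved-diffusion code ref from the docstring). Below is a minimal, hypothetical sketch of that dispatch pattern, not the library's implementation: the MiniADMStack name, the block constructors, and the additive conditioning are illustrative assumptions.

from enum import Enum
from typing import List, Optional, Sequence, Union

import torch
from torch import Tensor, nn


class ADMStackModuleType(Enum):
    # Mirrors the three block kinds the real stack distinguishes.
    AttentionBlock = "attention"
    ResidualBlock = "residual"
    SimpleBlock = "simple"


class MiniADMStack(nn.Module):
    # Hypothetical stand-in for ADMStack: one nn.ModuleList plus a
    # parallel list of type tags so forward() can route conditioning.

    def __init__(self) -> None:
        super().__init__()
        self._module_list = nn.ModuleList()
        self._module_types: List[ADMStackModuleType] = []

    def append_residual_block(self, module: nn.Module) -> None:
        self._module_list.append(module)
        self._module_types.append(ADMStackModuleType.ResidualBlock)

    def append_simple_block(self, module: nn.Module) -> None:
        self._module_list.append(module)
        self._module_types.append(ADMStackModuleType.SimpleBlock)

    def forward(
        self,
        x: Tensor,
        residual_conditional_embedding: Tensor,
        attention_conditional_embedding: Optional[Union[Tensor, Sequence[Tensor]]] = None,
    ) -> Tensor:
        h = x
        for name, block in zip(self._module_types, self._module_list):
            if name == ADMStackModuleType.ResidualBlock:
                # Illustrative conditioning only: the real residual blocks
                # take the embedding as a separate argument.
                h = block(h + residual_conditional_embedding)
            else:
                # Simple blocks (e.g. plain convs) ignore conditional inputs.
                h = block(h)
        return h


stack = MiniADMStack()
stack.append_simple_block(nn.Conv2d(8, 8, kernel_size=3, padding=1))
stack.append_residual_block(nn.Conv2d(8, 8, kernel_size=3, padding=1))

x = torch.randn(2, 8, 16, 16)
emb = torch.randn(2, 8, 1, 1)  # broadcasts over the spatial dims
out = stack(x, residual_conditional_embedding=emb)
print(out.shape)  # torch.Size([2, 8, 16, 16])

The tightened return types in the hunks above (nn.ModuleList instead of nn.Module, and the explicit List[ADMStackModuleType]) matter for exactly this dispatch: the two lists must stay index-aligned, and the annotations let a type checker catch a caller that treats the encoder/decoder as a single module.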