Commit

fix unit test error, add attribute latent_channels for vae
Taited committed Aug 10, 2023
1 parent bc9f677 commit 78316ce
Showing 4 changed files with 25 additions and 9 deletions.
mmagic/models/editors/ddpm/denoising_unet.py (9 changes: 7 additions & 2 deletions)
@@ -759,6 +759,9 @@ class DenoisingUnet(BaseModule):
         image_size (int | list[int]): The size of image to denoise.
         in_channels (int, optional): The input channels of the input image.
             Defaults as ``3``.
+        out_channels (int, optional): The output channels of the output
+            prediction. Defaults to ``None``, which means it is
+            automatically assigned based on ``var_mode``.
         base_channels (int, optional): The basic channel number of the
             generator. The other layers contain channels based on this number.
             Defaults to ``128``.
@@ -837,6 +840,7 @@ class DenoisingUnet(BaseModule):
     def __init__(self,
                  image_size,
                  in_channels=3,
+                 out_channels=None,
                  base_channels=128,
                  resblocks_per_downsample=3,
                  num_timesteps=1000,
@@ -886,8 +890,9 @@ def __init__(self,
         self.in_channels = in_channels
 
         # double output_channels to output mean and var at same time
-        out_channels = in_channels if 'FIXED' in self.var_mode.upper() \
-            else 2 * in_channels
+        if out_channels is None:
+            out_channels = in_channels if 'FIXED' in self.var_mode.upper() \
+                else 2 * in_channels
         self.out_channels = out_channels
 
         # check type of image_size
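The effect of the new argument, sketched as a standalone helper (the function name resolve_out_channels and the mode strings 'FIXED_SMALL' / 'LEARNED_RANGE' are illustrative stand-ins, not identifiers from the commit): fixed-variance modes keep the output at in_channels, learned-variance modes double it so the network can predict mean and variance together, and an explicit out_channels now bypasses the heuristic entirely.

def resolve_out_channels(in_channels, var_mode, out_channels=None):
    """Mirror the default logic added in this commit."""
    if out_channels is None:
        out_channels = in_channels if 'FIXED' in var_mode.upper() \
            else 2 * in_channels
    return out_channels

assert resolve_out_channels(3, 'FIXED_SMALL') == 3                    # mean only
assert resolve_out_channels(3, 'LEARNED_RANGE') == 6                  # mean + var
assert resolve_out_channels(9, 'LEARNED_RANGE', out_channels=4) == 4  # explicit override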
@@ -138,11 +138,11 @@ def infer(self,
 
         # 6. Prepare latent variables
         if hasattr(self.unet, 'module'):
-            num_channels_latents = self.vae.module.in_channels
-            num_channels_unet = self.unet.module.config.in_channels
+            num_channels_latents = self.vae.module.latent_channels
+            num_channels_unet = self.unet.module.in_channels
         else:
-            num_channels_latents = self.vae.config.latent_channels
-            num_channels_unet = self.unet.config.in_channels
+            num_channels_latents = self.vae.latent_channels
+            num_channels_unet = self.unet.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -175,11 +175,11 @@ def infer(self,
             num_channels_masked_image = masked_image_latents.shape[1]
             total_channels = num_channels_latents + \
                 num_channels_masked_image + num_channels_mask
-            if total_channels != self.unet.config.in_channels:
+            if total_channels != self.unet.in_channels:
                 raise ValueError(
                     'Incorrect configuration settings! The config of '
                     f'`pipeline.unet`: {self.unet.config} expects'
-                    f' {self.unet.config.in_channels} but received '
+                    f' {self.unet.in_channels} but received '
                     f'`num_channels_latents`: {num_channels_latents} +'
                     f' `num_channels_mask`: {num_channels_mask} + '
                     '`num_channels_masked_image`: '
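A back-of-the-envelope version of the channel check above, using the numbers from the test in this commit (a UNet built with in_channels=9 and, by assumption, a dummy VAE with 4 latent channels); the variable names mirror infer(), but the snippet is self-contained:

num_channels_latents = 4        # vae.latent_channels (assumed dummy VAE)
num_channels_mask = 1           # single-channel mask, as in the updated test
num_channels_masked_image = 4   # VAE-encoded masked image
total_channels = (num_channels_latents + num_channels_mask
                  + num_channels_masked_image)
assert total_channels == 9      # equals unet in_channels=9, so no ValueError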
mmagic/models/editors/stable_diffusion/vae.py (1 change: 1 addition & 0 deletions)
@@ -918,6 +918,7 @@ def __init__(
         super().__init__()
 
         self.block_out_channels = block_out_channels
+        self.latent_channels = latent_channels
 
         # pass init params to Encoder
         self.encoder = Encoder(
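A minimal sketch of how the new attribute is read downstream (the _FakeVAE class is a stand-in for illustration only; the unwrap line is the same hasattr pattern the inpainting infer() uses for DDP-wrapped modules):

class _FakeVAE:
    latent_channels = 4  # stored by the new assignment in __init__

vae = _FakeVAE()
module = vae.module if hasattr(vae, 'module') else vae  # unwrap if wrapped
assert module.latent_channels == 4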
@@ -21,6 +21,7 @@
     cross_attention_dim=768,
     num_heads=2,
     in_channels=9,
+    out_channels=4,
     layers_per_block=1,
     down_block_types=['CrossAttnDownBlock2D', 'DownBlock2D'],
     up_block_types=['UpBlock2D', 'CrossAttnUpBlock2D'],
@@ -100,9 +101,18 @@ def test_stable_diffusion():
     StableDiffuser = MODELS.build(Config(model))
     StableDiffuser.tokenizer = dummy_tokenizer()
     StableDiffuser.text_encoder = dummy_text_encoder()
+    config = getattr(StableDiffuser.vae, 'config', None)
+    if config is None:
+
+        class DummyConfig:
+            pass
+
+        config = DummyConfig()
+    setattr(config, 'scaling_factor', 1.2)
+    setattr(StableDiffuser.vae, 'config', config)
 
     image = torch.clip(torch.randn((1, 3, 64, 64)), -1, 1)
-    mask = torch.clip(torch.randn((1, 3, 64, 64)), 0, 1)
+    mask = torch.clip(torch.randn((1, 1, 64, 64)), 0, 1)
 
     with pytest.raises(Exception):
         StableDiffuser.infer('temp', image, mask, height=31, width=31)
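For context on why the test pins scaling_factor: Stable-Diffusion-style pipelines multiply encoded latents by this constant and divide by it again before decoding, so infer() fails if vae.config does not expose it. A small, self-contained round trip with the dummy value from the test (the random tensor stands in for a real vae.encode output):

import torch

scaling_factor = 1.2                  # dummy value set on vae.config above
latents = torch.randn(1, 4, 8, 8)     # stand-in for vae.encode(image).sample()
scaled = latents * scaling_factor     # what the pipeline carries while denoising
restored = scaled / scaling_factor    # undone again before vae.decode(...)
assert torch.allclose(restored, latents)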
