Update on "[diffusion_labs] Move attention and res blocks under modules"
Differential Revision: [D50288362](https://our.internmc.facebook.com/intern/diff/D50288362)

[ghstack-poisoned]
ebsmothers committed Oct 16, 2023
2 parents 6f0a65f + d57bdfa commit 81274af
Showing 2 changed files with 63 additions and 53 deletions.
112 changes: 61 additions & 51 deletions torchmultimodal/diffusion_labs/models/adm_unet/adm.py
@@ -21,8 +21,6 @@
SinusoidalPositionEmbeddings,
)

-DEFAULT_EMBED_NAME = "clip_image"
-

class ADMUNet(nn.Module):
"""Ablated Diffusion Model as described in "Diffusion Models Beat GANs on Image Synthesis" (https://arxiv.org/abs/2105.05233)
@@ -95,15 +93,19 @@ class ADMUNet(nn.Module):
Expected shape of tensors are [b, c], where c is the embedding dim of the Tensor.
"""

+    DEFAULT_EMBED_NAME = "clip_image"
+
def __init__(
self,
*,
channels_per_layer: List[int],
num_resize: int,
num_res_per_layer: int,
use_attention_for_layer: List[bool],
dim_res_cond: int,
dim_attn_cond: Optional[int] = None,
embed_dim: Optional[int] = None,
+        embed_name: str = DEFAULT_EMBED_NAME,
in_channels: int = 3,
out_channels: int = 3,
time_embed_dim: Optional[int] = None,
@@ -124,6 +126,7 @@ def __init__(
)
else:
self.timestep_encoder = timestep_encoder
+        self.embed_name = embed_name
if res_cond_proj is None and embed_dim is not None and dim_res_cond is not None:
res_cond_proj = self._create_res_cond_proj(embed_dim, dim_res_cond)
else:
@@ -136,6 +139,7 @@ def __init__(
attn_cond_proj = self._create_attn_cond_proj(embed_dim, dim_attn_cond)
else:
self.attn_cond_proj = attn_cond_proj

self.predict_variance_value = predict_variance_value
self.variance_value_transform = variance_value_transform or nn.Identity()

@@ -174,14 +178,14 @@ def _create_timestep_encoder(
def _create_res_cond_proj(
self, embed_dim: int, cond_embed_dim: int
) -> nn.ModuleDict:
-        return nn.ModuleDict({DEFAULT_EMBED_NAME: nn.Linear(embed_dim, cond_embed_dim)})
+        return nn.ModuleDict({self.embed_name: nn.Linear(embed_dim, cond_embed_dim)})

def _create_attn_cond_proj(
self, embed_dim: int, cond_embed_dim: int
) -> nn.ModuleDict:
return nn.ModuleDict(
{
-                DEFAULT_EMBED_NAME: nn.Sequential(
+                self.embed_name: nn.Sequential(
nn.Linear(
embed_dim, cond_embed_dim * 4
), # four tokens of context as per paper ref
@@ -327,47 +331,6 @@ def _create_upsampling_decoder(self, down_channels: List[int]) -> nn.ModuleList:
net = nn.ModuleList(stacks + [out_conv])
return net

-    def forward(
-        self,
-        x: Tensor,
-        timestep: Tensor,
-        conditional_inputs: Optional[Dict[str, Tensor]] = None,
-    ) -> DiffusionOutput:
-        (
-            res_conditional_embedding,
-            attn_conditional_embedding,
-        ) = self._get_conditional_projections(timestep, conditional_inputs)
-
-        hidden_states = []
-        h = x
-        for block in self.down:
-            h = block(h, res_conditional_embedding, attn_conditional_embedding)
-            hidden_states.append(h)
-        h = self.bottleneck(h, res_conditional_embedding, attn_conditional_embedding)
-        for block in self.up:
-            if hidden_states:
-                h = torch.cat([h, hidden_states.pop()], dim=1)
-            h = block(h, res_conditional_embedding, attn_conditional_embedding)
-
-        # If model is predicting variance, then it should be configured to output double the channels as input
-        if self.predict_variance_value:
-            if h.shape[1] != x.shape[1] * 2:
-                raise ValueError(
-                    f"unet is not configured to predict variance values. "
-                    f"Expected output channel dim to be {x.shape[1] * 2}, got {h.shape[1]}"
-                )
-            # Split in half in channel dim
-            prediction, variance_value = torch.chunk(h, 2, dim=1)
-            variance_value = self.variance_value_transform(variance_value)
-            return DiffusionOutput(
-                prediction=prediction,
-                variance_value=variance_value,
-            )
-        else:
-            return DiffusionOutput(
-                prediction=h,
-            )
-
def _get_conditional_projections(
self,
timestep: Tensor,
@@ -408,6 +371,53 @@ def _get_conditional_projections(
attn_cond = torch.concat(attn_cond, dim=1) if attn_cond else None
return res_cond, attn_cond

+    def _get_variance_value(
+        self, x: Tensor, h: Tensor
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        # If model is predicting variance, then it should be configured to output double the channels as input
+        if self.predict_variance_value:
+            if h.shape[1] != x.shape[1] * 2:
+                raise ValueError(
+                    f"unet is not configured to predict variance values. "
+                    f"Expected output channel dim to be {x.shape[1] * 2}, got {h.shape[1]}"
+                )
+            # Split in half in channel dim
+            prediction, variance_value = torch.chunk(h, 2, dim=1)
+            variance_value = self.variance_value_transform(variance_value)
+        else:
+            prediction = h
+            variance_value = None
+        return prediction, variance_value
+
+    def forward(
+        self,
+        x: Tensor,
+        timestep: Tensor,
+        conditional_inputs: Optional[Dict[str, Tensor]] = None,
+    ) -> DiffusionOutput:
+        (
+            res_conditional_embedding,
+            attn_conditional_embedding,
+        ) = self._get_conditional_projections(timestep, conditional_inputs)
+
+        hidden_states = []
+        h = x
+        for block in self.down:
+            h = block(h, res_conditional_embedding, attn_conditional_embedding)
+            hidden_states.append(h)
+        h = self.bottleneck(h, res_conditional_embedding, attn_conditional_embedding)
+        for block in self.up:
+            if hidden_states:
+                h = torch.cat([h, hidden_states.pop()], dim=1)
+            h = block(h, res_conditional_embedding, attn_conditional_embedding)
+
+        prediction, variance_value = self._get_variance_value(x, h)
+
+        return DiffusionOutput(
+            prediction=prediction,
+            variance_value=variance_value,
+        )


class ADMStackModuleType(Enum):
ResidualBlock = 0
@@ -514,16 +524,16 @@ def adm_unet(
# ADM args
time_embed_dim: int = 512,
cond_embed_dim: int = 2048,
-    clip_embed_dim: int = 768,
-    clip_embed_name: str = DEFAULT_EMBED_NAME,
+    embed_dim: int = 768,
+    embed_name: str = "clip_image",
predict_variance_value: bool = True,
# ADMUNet args
image_channels: int = 4,
depth: int = 512,
num_resize: int = 3,
num_res_per_layer: int = 3,
) -> nn.Module:
"""Constructs the DALLE-2 base conditional UNet
"""Constructs a conditional ADM U-Net
Consists of an ADM UNet diffusion model conditioned on CLIP image embeddings.
@@ -533,8 +543,8 @@
Args:
time_embed_dim (int): desired dimensionality of timestep embedding
cond_embed_dim (int): desired dimensionality of conditional input embeddings
-        clip_embed_dim (int): expected dimensionality of CLIP image embeddings
-        clip_embed_name (str): name of CLIP embedding conditional input
+        embed_dim (int): expected dimensionality of CLIP image embeddings
+        embed_name (str): name of CLIP embedding conditional input
predict_variance_value (bool): if True, will double UNet's output channel dim to predict variance values of
diffusion process
image_channels (int): channel dim of input images
@@ -562,7 +572,7 @@
use_attention_for_layer=use_attention_per_layer,
dim_res_cond=cond_embed_dim,
dim_attn_cond=cond_embed_dim,
-        embed_dim=clip_embed_dim,
+        embed_dim=embed_dim,
in_channels=in_channels,
out_channels=out_channels,
time_embed_dim=time_embed_dim,
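For context, a minimal usage sketch of the renamed embed_name argument (not part of this diff; the 64x64 input size and the timestep shape are assumptions, other values follow the adm_unet defaults shown above). The key passed as embed_name must match the key used in conditional_inputs at call time, and conditional tensors are shaped [b, c] per the ADMUNet docstring:

import torch
from torchmultimodal.diffusion_labs.models.adm_unet.adm import adm_unet

unet = adm_unet(
    embed_dim=768,            # expected dim of the conditional embedding
    embed_name="clip_image",  # key used to look up the embedding in forward()
    predict_variance_value=True,
)

x = torch.randn(2, 4, 64, 64)            # [b, image_channels, h, w]; 64x64 is an assumed size
timestep = torch.randint(0, 1000, (2,))  # one timestep per batch element (assumed shape)
cond = {"clip_image": torch.randn(2, 768)}  # key matches embed_name

out = unet(x, timestep, conditional_inputs=cond)
# With predict_variance_value=True, the UNet outputs 2x the input channels
# and forward() splits them into prediction and variance_value.
assert out.prediction.shape == x.shape
assert out.variance_value.shape == x.shape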
4 changes: 2 additions & 2 deletions
@@ -76,8 +76,8 @@ def dalle2_decoder(
diffusion_model = adm_unet(
time_embed_dim=time_embed_dim,
cond_embed_dim=cond_embed_dim,
-        clip_embed_dim=clip_embed_dim,
-        clip_embed_name=clip_embed_name,
+        embed_dim=clip_embed_dim,
+        embed_name=clip_embed_name,
predict_variance_value=predict_variance_value,
image_channels=image_channels,
depth=depth,
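And a standalone sketch of the channel split that the extracted _get_variance_value helper performs when predict_variance_value=True (shapes here are assumed for illustration):

import torch

b, c, h, w = 2, 4, 64, 64
x = torch.randn(b, c, h, w)        # UNet input
out = torch.randn(b, 2 * c, h, w)  # UNet output with doubled channel dim

# Split in half along the channel dim, as in _get_variance_value above.
prediction, variance_value = torch.chunk(out, 2, dim=1)
assert prediction.shape == x.shape and variance_value.shape == x.shape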
