diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py index 74dc39c9..1ffe0f45 100644 --- a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py +++ b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py @@ -95,7 +95,7 @@ class ADMUNet(nn.Module): Expected shape of tensors are [b, c], where c is the embedding dim of the Tensor. """ - DEFAULT_EMBED_NAME = "clip_image" + DEFAULT_EMBED_NAME = "context" def __init__( self,