diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index f377879d..0f02c7dd 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -412,10 +412,7 @@ def _forward_prenorm( self_attn_output = attn_output + hidden_states # Optional cross-attention - if self.use_cross_attention: - assert ( - encoder_hidden_states is not None - ), "encoder_hidden_states must be provided for cross attention" + if self.use_cross_attention and encoder_hidden_states is not None: assert hasattr( self, "cross_attention_layernorm" ), "Cross-attention layernorm not initialized"