diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py
index 06237474..3330eabf 100644
--- a/torchmultimodal/modules/layers/transformer.py
+++ b/torchmultimodal/modules/layers/transformer.py
@@ -413,10 +413,7 @@ def _forward_prenorm(
         self_attn_output = attn_output + hidden_states
 
         # Optional cross-attention
-        if self.use_cross_attention:
-            assert (
-                encoder_hidden_states is not None
-            ), "encoder_hidden_states must be provided for cross attention"
+        if self.use_cross_attention and encoder_hidden_states is not None:
             assert hasattr(
                 self, "cross_attention_layernorm"
             ), "Cross-attention layernorm not initialized"
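
For context on the behavioral change: with the old guard, calling the layer with use_cross_attention=True but no encoder_hidden_states raised an AssertionError; with the new guard, the cross-attention branch is silently skipped. Below is a minimal, hypothetical sketch of that gating (ToyPreNormLayer and its forward are illustrative names, not the torchmultimodal API; the self- and cross-attention math is stubbed out).

```python
from typing import Optional

import torch
from torch import Tensor, nn


class ToyPreNormLayer(nn.Module):
    """Toy layer mirroring the pre-norm cross-attention gating in the diff."""

    def __init__(self, dim: int, use_cross_attention: bool = True) -> None:
        super().__init__()
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attention_layernorm = nn.LayerNorm(dim)

    def forward(
        self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None
    ) -> Tensor:
        self_attn_output = hidden_states  # self-attention omitted for brevity

        # New behavior: when encoder_hidden_states is not provided, the
        # cross-attention block is skipped instead of tripping an assertion.
        if self.use_cross_attention and encoder_hidden_states is not None:
            assert hasattr(
                self, "cross_attention_layernorm"
            ), "Cross-attention layernorm not initialized"
            normed = self.cross_attention_layernorm(self_attn_output)
            # (cross-attention over encoder_hidden_states would go here)
            self_attn_output = normed + self_attn_output
        return self_attn_output


# With the old condition this call would assert; now it just skips cross-attention.
layer = ToyPreNormLayer(dim=8)
out = layer(torch.randn(2, 4, 8))
```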