From 3b6375f3c09b16215d3ef2e19dbacbbc56567bd5 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Mon, 12 Feb 2024 16:08:32 -0800 Subject: [PATCH] create configuration file for MaMMUT training (#521) Summary: Mostly based on original coca and https://github.com/lucidrains/MaMMUT-pytorch Update the logics of loading checkpoint for MaMMUT text decoder as well. Differential Revision: D52891614 Privacy Context Container: 303860477774201 --- torchmultimodal/modules/layers/transformer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index 06237474..3330eabf 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -413,10 +413,7 @@ def _forward_prenorm( self_attn_output = attn_output + hidden_states # Optional cross-attention - if self.use_cross_attention: - assert ( - encoder_hidden_states is not None - ), "encoder_hidden_states must be provided for cross attention" + if self.use_cross_attention and encoder_hidden_states is not None: assert hasattr( self, "cross_attention_layernorm" ), "Cross-attention layernorm not initialized"