From 2cbab1f089ae84b557bf19bd91ff03d4efd4b1d9 Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Fri, 16 Feb 2024 12:05:33 -0800
Subject: [PATCH] create configuration file for MaMMUT training (#521)

Summary:
Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/521

Mostly based on the original CoCa config and https://github.com/lucidrains/MaMMUT-pytorch.

Also updates the checkpoint-loading logic for the MaMMUT text decoder.

Differential Revision: D52891614

Privacy Context Container: 303860477774201

fbshipit-source-id: 192a1826fd59a80bf99e8545408e19938069a599
---
 torchmultimodal/modules/layers/transformer.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py
index f377879d..0f02c7dd 100644
--- a/torchmultimodal/modules/layers/transformer.py
+++ b/torchmultimodal/modules/layers/transformer.py
@@ -412,10 +412,7 @@ def _forward_prenorm(
         self_attn_output = attn_output + hidden_states
 
         # Optional cross-attention
-        if self.use_cross_attention:
-            assert (
-                encoder_hidden_states is not None
-            ), "encoder_hidden_states must be provided for cross attention"
+        if self.use_cross_attention and encoder_hidden_states is not None:
             assert hasattr(
                 self, "cross_attention_layernorm"
             ), "Cross-attention layernorm not initialized"
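
Note: below is a minimal, hypothetical sketch (a toy module, not the actual torchmultimodal layer or MaMMUT config) illustrating the behavior change in the hunk above: with the relaxed condition, a layer built with cross-attention no longer fails an assertion when encoder_hidden_states is omitted; it simply skips cross-attention, which lets the same decoder serve both a text-only pass and a multimodal pass. Names such as ToyDecoderLayer are invented for illustration.

    # toy_cross_attention_skip.py -- illustrative only, assumptions noted above
    from typing import Optional

    import torch
    from torch import nn


    class ToyDecoderLayer(nn.Module):
        def __init__(self, dim: int, num_heads: int, use_cross_attention: bool = True) -> None:
            super().__init__()
            self.use_cross_attention = use_cross_attention
            self.self_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.self_attention_layernorm = nn.LayerNorm(dim)
            if use_cross_attention:
                self.cross_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
                self.cross_attention_layernorm = nn.LayerNorm(dim)

        def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # Pre-norm self-attention with a residual connection.
            x = self.self_attention_layernorm(hidden_states)
            attn_output, _ = self.self_attention(x, x, x)
            self_attn_output = attn_output + hidden_states

            # Optional cross-attention: runs only when the layer was built with
            # cross-attention AND encoder states were actually provided, mirroring
            # the condition introduced by this patch (previously, passing no
            # encoder states to a cross-attention layer raised an assertion).
            if self.use_cross_attention and encoder_hidden_states is not None:
                y = self.cross_attention_layernorm(self_attn_output)
                cross_output, _ = self.cross_attention(
                    y, encoder_hidden_states, encoder_hidden_states
                )
                return cross_output + self_attn_output
            return self_attn_output


    # The same layer instance can now handle both paths:
    layer = ToyDecoderLayer(dim=32, num_heads=4)
    tokens = torch.randn(2, 8, 32)
    image_feats = torch.randn(2, 16, 32)
    _ = layer(tokens)                                     # text-only (unimodal) pass
    _ = layer(tokens, encoder_hidden_states=image_feats)  # multimodal pass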