From 8f726e052c2e6d33e9e828b8de1c92cbb0afa6ae Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Thu, 15 Feb 2024 17:33:24 -0800
Subject: [PATCH] implement MaMMUT (#520)

Summary: Implement MaMMUT, based mostly on the existing CoCa code as
well as https://github.com/lucidrains/MaMMUT-pytorch.

Differential Revision: D52823194

Privacy Context Container: 303860477774201
---
 torchmultimodal/modules/layers/transformer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py
index e2eea49e..f377879d 100644
--- a/torchmultimodal/modules/layers/transformer.py
+++ b/torchmultimodal/modules/layers/transformer.py
@@ -553,6 +553,8 @@ class TransformerDecoder(nn.Module):
             If None, K and V are assumed to have dimension d_model. Defaults to None.
         final_layer_norm_eps (Optional[float]): epsilon used in final layer norm.
             Defaults to None (no final layer norm).
+        cross_attention_interval (int): interval (in layers) at which to apply
+            cross-attention. Not used if use_cross_attention is False. Defaults to 1.
     """
 
     def __init__(
@@ -568,6 +570,7 @@ def __init__(
         use_cross_attention: bool = True,
         dim_kv: Optional[int] = None,
         final_layer_norm_eps: Optional[float] = None,
+        cross_attention_interval: int = 1,
     ):
         super().__init__()
         self.layer = nn.ModuleList(
@@ -580,7 +583,7 @@ def __init__(
                     activation,
                     layer_norm_eps,
                     norm_first,
-                    use_cross_attention,
+                    use_cross_attention and (i % cross_attention_interval == 0),
                     dim_kv,
                 )
                 for i in range(n_layer)
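
For reference, the gating expression in the last hunk builds cross-attention only
into every cross_attention_interval-th decoder layer (layer indices 0, k, 2k, ...
for interval k). Below is a minimal sketch of that effect; the commented-out
constructor call is hypothetical, and any keyword names beyond n_layer, d_model,
use_cross_attention, and cross_attention_interval are assumptions rather than
taken from this diff:

    # Mirrors the per-layer flag computed in the patched __init__:
    #     use_cross_attention and (i % cross_attention_interval == 0)
    n_layer = 6
    use_cross_attention = True
    cross_attention_interval = 2

    per_layer_cross_attention = [
        use_cross_attention and (i % cross_attention_interval == 0)
        for i in range(n_layer)
    ]
    print(per_layer_cross_attention)  # [True, False, True, False, True, False]

    # Hypothetical usage of the new argument (argument names not visible in the
    # hunks above are assumptions):
    # decoder = TransformerDecoder(
    #     n_layer=n_layer,
    #     d_model=512,
    #     n_head=8,
    #     dim_feedforward=2048,
    #     use_cross_attention=True,
    #     cross_attention_interval=cross_attention_interval,
    # )

With cross_attention_interval=1 (the default) every layer keeps cross-attention,
so existing callers such as CoCa are unaffected.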