diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py
index e2eea49e..f377879d 100644
--- a/torchmultimodal/modules/layers/transformer.py
+++ b/torchmultimodal/modules/layers/transformer.py
@@ -553,6 +553,8 @@ class TransformerDecoder(nn.Module):
             If None, K and V are assumed to have dimension d_model. Defaults to None.
         final_layer_norm_eps (Optional[float]): epsilon used in final layer norm.
             Defaults to None (no final layer norm).
+        cross_attention_interval (int): apply cross attention every cross_attention_interval
+            layers. Not used if use_cross_attention is False. Defaults to 1.
     """
 
     def __init__(
@@ -568,6 +570,7 @@ def __init__(
         use_cross_attention: bool = True,
         dim_kv: Optional[int] = None,
         final_layer_norm_eps: Optional[float] = None,
+        cross_attention_interval: int = 1,
     ):
         super().__init__()
         self.layer = nn.ModuleList(
@@ -580,7 +583,7 @@ def __init__(
                     activation,
                     layer_norm_eps,
                     norm_first,
-                    use_cross_attention,
+                    use_cross_attention and (i % cross_attention_interval == 0),
                     dim_kv,
                 )
                 for i in range(n_layer)
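
For context, here is a minimal standalone sketch of the per-layer gating this change introduces: with `cross_attention_interval = k`, cross-attention is enabled on layers 0, k, 2k, ..., and the default of 1 preserves the previous behavior (cross-attention in every layer). The helper `cross_attention_flags` below is hypothetical and written only to illustrate the expression `use_cross_attention and (i % cross_attention_interval == 0)` from the diff; it is not part of torchmultimodal.

```python
# Hypothetical helper (not part of torchmultimodal) that mirrors the gating expression
# passed to each TransformerDecoderLayer in the diff above.
from typing import List


def cross_attention_flags(
    n_layer: int, use_cross_attention: bool, cross_attention_interval: int
) -> List[bool]:
    """Return, for each layer index i, whether that layer gets a cross-attention block."""
    return [
        use_cross_attention and (i % cross_attention_interval == 0)
        for i in range(n_layer)
    ]


# Default interval of 1 keeps the old behavior: every layer cross-attends.
assert cross_attention_flags(6, True, 1) == [True] * 6
# Interval of 2 enables cross-attention on layers 0, 2, and 4 only.
assert cross_attention_flags(6, True, 2) == [True, False, True, False, True, False]
# use_cross_attention=False still disables cross-attention everywhere, regardless of interval.
assert cross_attention_flags(6, False, 2) == [False] * 6
```

Assuming the remaining constructor arguments keep the defaults shown in the file (not reproduced in full here), a decoder that cross-attends only in every other layer would then be built roughly as `TransformerDecoder(n_layer=6, cross_attention_interval=2, ...)`.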