diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py
index 47f42d57..06237474 100644
--- a/torchmultimodal/modules/layers/transformer.py
+++ b/torchmultimodal/modules/layers/transformer.py
@@ -554,6 +554,8 @@ class TransformerDecoder(nn.Module):
             If None, K and V are assumed to have dimension d_model. Defaults to None.
         final_layer_norm_eps (Optional[float]): epsilon used in final layer norm.
             Defaults to None (no final layer norm).
+        cross_attention_interval (int): interval at which decoder layers apply cross
+            attention. Not used if use_cross_attention is False. Defaults to 1.
     """

     def __init__(
@@ -569,6 +571,7 @@ def __init__(
         use_cross_attention: bool = True,
         dim_kv: Optional[int] = None,
         final_layer_norm_eps: Optional[float] = None,
+        cross_attention_interval: int = 1,
     ):
         super().__init__()
         self.layer = nn.ModuleList(
@@ -581,7 +584,7 @@ def __init__(
                 activation,
                 layer_norm_eps,
                 norm_first,
-                use_cross_attention,
+                use_cross_attention and (i % cross_attention_interval == 0),
                 dim_kv,
             )
             for i in range(n_layer)
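
To illustrate the gating expression added in the last hunk, here is a minimal standalone sketch (not the library API) that reproduces `use_cross_attention and (i % cross_attention_interval == 0)` and shows which layer indices end up with a cross-attention block for a given interval. The helper name `cross_attention_layers` is hypothetical; only the parameter names mirror the patch.

```python
def cross_attention_layers(
    n_layer: int,
    use_cross_attention: bool = True,
    cross_attention_interval: int = 1,
) -> list:
    """Indices of decoder layers that would be built with cross attention."""
    return [
        i
        for i in range(n_layer)
        # Same gating expression as in the patched TransformerDecoder constructor.
        if use_cross_attention and (i % cross_attention_interval == 0)
    ]


if __name__ == "__main__":
    # interval=1 (the default) keeps cross attention in every layer,
    # so existing configurations are unaffected by this change.
    print(cross_attention_layers(n_layer=6))                              # [0, 1, 2, 3, 4, 5]
    # interval=2 applies cross attention only to every other layer.
    print(cross_attention_layers(n_layer=6, cross_attention_interval=2))  # [0, 2, 4]
    # use_cross_attention=False disables it everywhere, regardless of interval.
    print(cross_attention_layers(n_layer=6, use_cross_attention=False))   # []
```

Defaulting `cross_attention_interval` to 1 keeps the change backward compatible: the extra condition is a no-op unless a caller explicitly asks for a sparser interval.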