implement MaMMUT (facebookresearch#520)
Summary:

Implement MaMMUT, mostly based on the current CoCa code as well as https://github.com/lucidrains/MaMMUT-pytorch.

Reviewed By: ebsmothers

Differential Revision: D52823194
zhangtemplar authored and facebook-github-bot committed Feb 16, 2024
1 parent 1cccc58 commit a204344
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torchmultimodal/modules/layers/transformer.py
@@ -553,6 +553,8 @@ class TransformerDecoder(nn.Module):
             If None, K and V are assumed to have dimension d_model. Defaults to None.
         final_layer_norm_eps (Optional[float]): epsilon used in final layer norm.
             Defaults to None (no final layer norm).
+        cross_attention_interval (int): interval of layers at which to apply
+            cross-attention. Not used if use_cross_attention is False. Defaults to 1.
     """
 
     def __init__(
@@ -568,6 +570,7 @@ def __init__(
         use_cross_attention: bool = True,
         dim_kv: Optional[int] = None,
         final_layer_norm_eps: Optional[float] = None,
+        cross_attention_interval: int = 1,
     ):
         super().__init__()
         self.layer = nn.ModuleList(
@@ -580,7 +583,7 @@ def __init__(
                     activation,
                     layer_norm_eps,
                     norm_first,
-                    use_cross_attention,
+                    use_cross_attention and (i % cross_attention_interval == 0),
                     dim_kv,
                 )
                 for i in range(n_layer)
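
For reference, here is a minimal, illustrative sketch (not part of the commit) of the per-layer gating this diff introduces: layer i receives a cross-attention block only when use_cross_attention and (i % cross_attention_interval == 0) is True. The concrete values below (n_layer = 6, cross_attention_interval = 2) are hypothetical.

# Sketch of the gating logic added in this commit (values are hypothetical).
n_layer = 6
use_cross_attention = True
cross_attention_interval = 2  # new parameter; the decoder's default is 1

layer_has_cross_attention = [
    use_cross_attention and (i % cross_attention_interval == 0)
    for i in range(n_layer)
]
print(layer_has_cross_attention)
# [True, False, True, False, True, False]

With the default cross_attention_interval = 1, every layer keeps cross-attention, so existing CoCa-style decoders behave as before, while MaMMUT-style decoders can interleave cross-attention at every N-th layer.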
