Fix attn_mask shape in MHA docstrings
Summary: n/a

Reviewed By: ankitade, pikapecan

Differential Revision: D47998161

fbshipit-source-id: b2c1f48f6215a79e32ef38d83b971ff6f730d45c
ebsmothers authored and facebook-github-bot committed Aug 2, 2023
1 parent 1aa2ed2 commit 4e719a4
Showing 1 changed file with 2 additions and 2 deletions.
torchmultimodal/modules/layers/multi_head_attention.py: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def forward(
     """
     Args:
         query (Tensor): input query of shape bsz x seq_len x embed_dim
-        attn_mask (optional Tensor): attention mask of shape bsz x seq_len x seq_len. Two types of masks are supported.
+        attn_mask (optional Tensor): attention mask of shape bsz x num_heads x seq_len x seq_len. Two types of masks are supported.
            A boolean mask where a value of True indicates that the element should take part in attention.
            A float mask of the same type as query that is added to the attention score.
        is_causal (bool): If true, does causal attention masking. attn_mask should be set to None if this is set to True
@@ -124,7 +124,7 @@ def forward(
        query (Tensor): input query of shape bsz x target_seq_len x embed_dim
        key (Tensor): key of shape bsz x source_seq_len x embed_dim
        value (Tensor): value of shape bsz x source_seq_len x embed_dim
-       attn_mask (optional Tensor): Attention mask of shape bsz x target_seq_len x source_seq_len.
+       attn_mask (optional Tensor): Attention mask of shape bsz x num_heads x target_seq_len x source_seq_len.
            Two types of masks are supported. A boolean mask where a value of True
            indicates that the element *should* take part in attention.
            A float mask of the same type as query, key, value that is added to the attention score.
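For context, here is a minimal sketch of how masks with the corrected shape (bsz x num_heads x seq_len x seq_len) can be built and used. It relies on torch.nn.functional.scaled_dot_product_attention rather than the torchmultimodal module edited in this commit, and the tensor names and sizes are illustrative assumptions, not the library's API.

# Illustrative sketch (not part of the commit): the two supported mask types
# with shape bsz x num_heads x seq_len x seq_len, applied via PyTorch's
# scaled_dot_product_attention.
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16  # hypothetical sizes

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Boolean mask: True means the element should take part in attention.
bool_mask = torch.ones(bsz, num_heads, seq_len, seq_len, dtype=torch.bool).tril()
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

# Float mask: added to the attention scores (-inf blocks a position).
float_mask = torch.zeros(bsz, num_heads, seq_len, seq_len)
float_mask = float_mask.masked_fill(~bool_mask, float("-inf"))
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# Causal masking without an explicit mask; attn_mask must then be None,
# mirroring the is_causal note in the docstring above.
out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

assert out_bool.shape == (bsz, num_heads, seq_len, head_dim)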
