diff --git a/tests/modules/layers/test_multi_head_attention.py b/tests/modules/layers/test_multi_head_attention.py
index 486b77ef..62be41bc 100644
--- a/tests/modules/layers/test_multi_head_attention.py
+++ b/tests/modules/layers/test_multi_head_attention.py
@@ -70,6 +70,10 @@ def dim_kv(self):
     def q(self):
         return torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]])
 
+    @pytest.fixture
+    def kv(self):
+        return torch.Tensor([[[3, 2], [1, 1]]])
+
     @pytest.fixture
     def current_key_value(self):
         return torch.Tensor(
@@ -106,6 +110,13 @@ def multi_head_cross_attn(self, dim_q, dim_kv):
         mha.eval()
         return mha
 
+    @pytest.fixture
+    def multi_head_cross_attn_without_bias(self, dim_q, dim_kv):
+        mha = MultiHeadAttentionWithCache(dim_q, dim_kv, num_heads=2, add_bias=False)
+        init_weights_with_constant(mha)
+        mha.eval()
+        return mha
+
     def test_multi_head_self_attention_use_cache(
         self,
         multi_head_self_attn_use_cache,
@@ -136,8 +147,7 @@ def test_multi_head_self_attention_use_cache(
             torch.cat([past_key_value, current_key_value], dim=2),
         )
 
-    def test_multi_head_cross_attention(self, multi_head_cross_attn, q):
-        kv = torch.Tensor([[[3, 2], [1, 1]]])
+    def test_multi_head_cross_attention(self, multi_head_cross_attn, q, kv):
         actual = multi_head_cross_attn(q, kv, kv)
         expected = torch.tensor(
             [
@@ -150,6 +160,21 @@ def test_multi_head_cross_attention(self, multi_head_cross_attn, q):
         )
         assert_expected(actual, expected, rtol=0, atol=1e-4)
 
+    def test_multi_head_cross_attention_without_bias(
+        self, multi_head_cross_attn_without_bias, q, kv
+    ):
+        actual = multi_head_cross_attn_without_bias(q, kv, kv)
+        expected = torch.tensor(
+            [
+                [
+                    [21.0, 21.0, 21.0, 21.0],
+                    [21.0, 21.0, 21.0, 21.0],
+                    [21.0, 21.0, 21.0, 21.0],
+                ],
+            ]
+        )
+        assert_expected(actual, expected, rtol=0, atol=1e-4)
+
     def test_scripting(
         self,
         multi_head_self_attn_use_cache,
diff --git a/torchmultimodal/modules/layers/multi_head_attention.py b/torchmultimodal/modules/layers/multi_head_attention.py
index 3da103d3..47b38969 100644
--- a/torchmultimodal/modules/layers/multi_head_attention.py
+++ b/torchmultimodal/modules/layers/multi_head_attention.py
@@ -89,6 +89,8 @@ class MultiHeadAttentionWithCache(nn.Module):
             same as dim_q for SA; equals to encoder dimension for cross-attention
         num_heads (int): number of attention heads
         dropout (float): dropout rate
+        add_bias (bool): if true, adds a learnable bias to query, key, value.
+            Defaults to True.
     """
 
     def __init__(
@@ -97,12 +99,13 @@ def __init__(
         dim_kv: int,
         num_heads: int,
         dropout: float = 0.0,
+        add_bias: bool = True,
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
-        self.q_proj = nn.Linear(dim_q, dim_q)
-        self.k_proj = nn.Linear(dim_kv, dim_q)
-        self.v_proj = nn.Linear(dim_kv, dim_q)
+        self.q_proj = nn.Linear(dim_q, dim_q, bias=add_bias)
+        self.k_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+        self.v_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
         self.output_proj = nn.Linear(dim_q, dim_q)
         self.dropout = dropout
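
For reference, a minimal usage sketch of the new add_bias flag (not part of the diff; it assumes the import path matches the file edited above and that forward takes query, key, value positionally, as the cross-attention test does):

# Illustrative sketch only: construct the layer with the q/k/v projection biases disabled.
import torch
from torchmultimodal.modules.layers.multi_head_attention import (
    MultiHeadAttentionWithCache,
)

dim_q, dim_kv = 4, 2
attn = MultiHeadAttentionWithCache(dim_q, dim_kv, num_heads=2, add_bias=False)
attn.eval()

q = torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]])  # (batch, seq_q, dim_q)
kv = torch.Tensor([[[3, 2], [1, 1]]])  # (batch, seq_kv, dim_kv)

with torch.no_grad():
    out = attn(q, kv, kv)  # cross-attention: query from q, key/value from kv
print(out.shape)  # torch.Size([1, 3, 4]) -- output keeps the query length and dim_q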