diff --git a/tests/models/blip2/__init__.py b/tests/models/blip2/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/tests/models/blip2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/models/blip2/test_qformer_layers.py b/tests/models/blip2/test_qformer_layers.py new file mode 100644 index 00000000..adbb0e01 --- /dev/null +++ b/tests/models/blip2/test_qformer_layers.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torch import nn +from torchmultimodal.models.blip2.qformer_layers import ( + QformerEmbedding, + QformerEncoder, + QformerLayer, +) + + +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(0) + + +class TestQformerWithMHA: + @pytest.fixture + def dim_q(self): + return 4 + + @pytest.fixture + def dim_kv(self): + return 2 + + @pytest.fixture + def dim_feedforward(self): + return 6 + + @pytest.fixture + def cross_attention_freq(self): + return 2 + + @pytest.fixture + def num_hidden_layers(self): + return 2 + + @pytest.fixture + def num_heads(self): + return 2 + + @pytest.fixture() + def input_ids(self): + return torch.LongTensor([[0, 1], [2, 3]]) + + @pytest.fixture() + def query_embeddings(self): + return torch.Tensor( + [ + [ + [0.6424, 0.6182, 0.5110, 0.7867], + [0.3907, 0.2057, 0.6909, 0.6334], + ], + [ + [0.6904, 0.4445, 0.4336, 0.4603], + [0.6318, 0.1163, 0.0340, 0.6871], + ], + ] + ) + + @pytest.fixture + def q(self): + return torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]]) + + @pytest.fixture + def kv(self): + return torch.Tensor([[[3, 2], [1, 1]]]) + + @pytest.fixture + def current_key_value(self): + return torch.Tensor( + [ + [ + [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]], + [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]], + ] + ] + ) + + @pytest.fixture + def past_key_value(self): + return torch.Tensor( + [ + [ + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + ] + ] + ) + + @pytest.fixture + def past_key_values(self, past_key_value, num_hidden_layers): + past_key_values = [] + for i in range(num_hidden_layers): + past_key_values.append((past_key_value, past_key_value)) + return past_key_values + + @pytest.fixture + def qformer_layer_self_attention_only(self, dim_q, dim_feedforward, num_heads): + qformer_layer = QformerLayer( + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + has_cross_attention=False, + ) + init_weights_with_constant(qformer_layer) + qformer_layer.eval() + return qformer_layer + + @pytest.fixture + def qformer_layer_with_cross_attention( + self, + dim_q, + dim_kv, + dim_feedforward, + num_heads, + ): + qformer_layer = QformerLayer( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + activation=nn.ReLU, + has_cross_attention=True, + ) + init_weights_with_constant(qformer_layer) + # modify query feedforward params to test cross attention case with different query lengths + init_weights_with_constant(qformer_layer.feedforward_query, 2.0) + init_weights_with_constant(qformer_layer.feedforward_layernorm_query, 2.0) + qformer_layer.eval() + return qformer_layer + + @pytest.fixture + def qformer_encoder( + self, + dim_q, + dim_kv, + dim_feedforward, + cross_attention_freq, + num_hidden_layers, + num_heads, + ): + qformer_encoder = QformerEncoder( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + cross_attention_freq=cross_attention_freq, + num_hidden_layers=num_hidden_layers, + ) + init_weights_with_constant(qformer_encoder) + qformer_encoder.eval() + return qformer_encoder + + def test_qformer_layer_self_attention_only( + self, qformer_layer_self_attention_only, current_key_value, past_key_value, q + ): + actual = qformer_layer_self_attention_only( + q, past_key_value=(past_key_value, past_key_value), use_cache=True + ) + expected = torch.Tensor( + [ + [ + [0.0955, 1.3015, 2.5076, 0.0955], + [2.3416, 1.4472, 0.5528, -0.3416], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_layer_with_cross_attention_only_query( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + # test with query length < attn_residual.shape[1] + actual = qformer_layer_with_cross_attention( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=2, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_layer_with_cross_attention_query_and_text( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + # test with query length >= attn_residual.shape[1] + actual = qformer_layer_with_cross_attention( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=3, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [2.0000, 2.0000, 2.0000, 2.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_encoder( + self, + qformer_encoder, + past_key_values, + current_key_value, + past_key_value, + q, + kv, + ): + actual = qformer_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + expected_hidden_state = torch.Tensor( + [ + [ + [0.0955, 1.3015, 2.5076, 0.0955], + [2.3416, 1.4472, 0.5528, -0.3416], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + expected_key_value = torch.Tensor( + [ + [ + [[5.0, 5.0], [5.0, 5.0], [5.0, 5.0]], + [[5.0, 5.0], [5.0, 5.0], [5.0, 5.0]], + ] + ] + ) + assert_expected(actual[0], expected_hidden_state, rtol=0, atol=1e-4) + assert_expected( + actual[1][0][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][0][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1][0], + torch.cat([past_key_value, expected_key_value], dim=2), + ) + assert_expected( + actual[1][1][1], + torch.cat([past_key_value, expected_key_value], dim=2), + ) + + def test_layer_scripting( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + scripted_model = torch.jit.script(qformer_layer_with_cross_attention) + actual = scripted_model( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=3, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [2.0000, 2.0000, 2.0000, 2.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_encoder_scripting( + self, + qformer_encoder, + past_key_values, + current_key_value, + past_key_value, + q, + kv, + ): + scripted_encoder = torch.jit.script(qformer_encoder) + actual = scripted_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + expected = qformer_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + assert_expected(actual[0], expected[0]) + assert_expected(actual[1], expected[1]) + assert len(actual) == len(expected) + + @pytest.fixture + def qformer_emb(self, dim_q): + return QformerEmbedding( + embedding_dim=dim_q, + max_position_embeddings=512, + vocab_size=20, + ) + + def test_qformer_embedding(self, input_ids, query_embeddings, qformer_emb): + actual = qformer_emb( + input_ids=input_ids, + query_embeddings=query_embeddings, + ) + expected_value = torch.Tensor( + [ + [ + [0.0287, -0.2175, -1.3081, 1.4969], + [-0.4602, -1.4116, 1.0838, 0.7880], + [-0.0600, 1.3838, -1.4382, 0.1144], + [1.1554, 0.0435, 0.3865, -1.5855], + ], + [ + [1.7251, -0.5904, -0.6931, -0.4416], + [0.8989, -0.8530, -1.1327, 1.0868], + [0.8951, -1.1037, -0.8854, 1.0940], + [-0.0748, -0.2439, 1.5529, -1.2342], + ], + ] + ) + # expected dim [bsz, num_token, embed_dim] + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_empty_input_ids( + self, + query_embeddings, + qformer_emb, + ): + actual = qformer_emb( + query_embeddings=query_embeddings, + ) + expected_value = torch.Tensor( + [ + [ + [0.0287, -0.2175, -1.3081, 1.4969], + [-0.4602, -1.4116, 1.0838, 0.7880], + ], + [ + [1.7251, -0.5904, -0.6931, -0.4416], + [0.8989, -0.8530, -1.1327, 1.0868], + ], + ] + ) + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_empty_query_embs( + self, + input_ids, + qformer_emb, + ): + actual = qformer_emb( + input_ids=input_ids, + ) + expected_value = torch.Tensor( + [ + [ + [-0.0600, 1.3838, -1.4382, 0.1144], + [1.1554, 0.0435, 0.3865, -1.5855], + ], + [ + [0.8951, -1.1037, -0.8854, 1.0940], + [-0.0748, -0.2439, 1.5529, -1.2342], + ], + ] + ) + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_invalid_input( + self, + qformer_emb, + ): + with pytest.raises(ValueError): + qformer_emb() + + def test_embedding_scripting(self, input_ids, qformer_emb, query_embeddings): + scripted_emb = torch.jit.script(qformer_emb) + actual = scripted_emb(input_ids=input_ids, query_embeddings=query_embeddings) + assert_expected( + actual, qformer_emb(input_ids=input_ids, query_embeddings=query_embeddings) + ) diff --git a/tests/models/blip2/test_qformer_utils.py b/tests/models/blip2/test_qformer_utils.py new file mode 100644 index 00000000..0fb07445 --- /dev/null +++ b/tests/models/blip2/test_qformer_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from tests.test_utils import assert_expected +from torch import Tensor +from torchmultimodal.models.blip2.qformer_utils import get_causal_mask + + +class TestExtendedAttnMaskForDecoder: + @pytest.fixture + def attention_mask(self): + return Tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]) + + @pytest.fixture + def input_shape(self): + return (2, 2) + + def test_extended_attention_mask(self, attention_mask): + actual_mask = get_causal_mask(attention_mask, attention_mask.shape) + expected = Tensor( + [ + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + ] + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) + + def test_extended_attention_mask_diff_input_size(self, attention_mask, input_shape): + actual_mask = get_causal_mask( + attention_mask, + input_shape, + ) + expected = Tensor( + Tensor( + [ + [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]], + ] + ) + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) + + def test_extended_attention_mask_with_query_embs(self, attention_mask, input_shape): + actual_mask = get_causal_mask(attention_mask, input_shape, has_query=True) + expected = Tensor( + Tensor( + [ + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + ] + ) + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) diff --git a/torchmultimodal/models/blip2/__init__.py b/torchmultimodal/models/blip2/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/models/blip2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/models/blip2/qformer_layers.py b/torchmultimodal/models/blip2/qformer_layers.py new file mode 100644 index 00000000..94e1d30c --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_layers.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional, Tuple + +import torch + +from torch import nn, Tensor + +from torchmultimodal.modules.layers.mlp import MLP +from torchmultimodal.modules.layers.multi_head_attention import ( + MHAWithCacheOutput, + MultiHeadAttentionWithCache, +) +from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm + + +class QformerLayer(nn.Module): + """ + Qformer layer module. + + This module is designed with a self-attention (SA) block and optionally includes a cross-attention (CA) block for queries. + The inputs for this module, referred to as hidden_states, can consist of either a query, text, or a combination of both. + Cross-attention is exclusively activated for queries (query_length > 0) with encoder_hidden_states derived from image inputs. + + The feedforward(ff) block will project the hidden states output by the layer before, + query output and text output are concatenated as overall output after separated handling for CA and ff. + + Args: + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + has_cross_attention (bool): whether a cross-attention layer is included + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA. + + """ + + def __init__( + self, + dim_q: int, + dim_feedforward: int, + num_heads: int, + attn_dropout: float = 0.0, + dropout: float = 0.0, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + has_cross_attention: bool = False, + dim_kv: Optional[int] = None, + ): + super().__init__() + self.self_attention = MultiHeadAttentionWithCache( + dim_q, dim_q, num_heads, attn_dropout + ) + self.self_attn_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + self.has_cross_attention = has_cross_attention + self.cross_attention: Optional[MultiHeadAttentionWithCache] = None + + if has_cross_attention: + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder, key and value caching should be disabled. + if dim_kv is None: + raise ValueError( + "key and value dim should be provided for cross attention." + ) + self.cross_attention = MultiHeadAttentionWithCache( + dim_q=dim_q, dim_kv=dim_kv, num_heads=num_heads, dropout=attn_dropout + ) + self.cross_attn_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.cross_attn_dropout = nn.Dropout(dropout) + + # feedforward block + self.feedforward = MLP( + dim_q, dim_q, dim_feedforward, dropout=0.0, activation=activation + ) + self.feedforward_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.feedforward_dropout = nn.Dropout(dropout) + + # query feedforward block + self.feedforward_query = MLP( + dim_q, dim_q, dim_feedforward, dropout=0.0, activation=activation + ) + self.feedforward_layernorm_query = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.feedforward_dropout_query = nn.Dropout(dropout) + + def _self_attention_block( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor, Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + x = hidden_states + attn_output = self.self_attention( + x, + x, + x, + attn_mask=attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + present_key_value: Optional[Tuple[Tensor, Tensor]] = None + if use_cache: + assert isinstance(attn_output, MHAWithCacheOutput) + attn_output_value = attn_output.attn_output + present_key_value = attn_output.past_key_value + else: + assert isinstance(attn_output, Tensor) + attn_output_value = attn_output + attn_output = self.dropout(attn_output_value) + + attn_residual = attn_output + x + attn_residual = self.self_attn_layernorm(attn_residual) + return attn_residual, present_key_value + + def _cross_attention_block( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + ) -> Tensor: + x = hidden_states + assert self.cross_attention is not None + # turn off cache for cross attention + cross_attn_output = self.cross_attention( + query=x, + key=encoder_hidden_states, + value=encoder_hidden_states, + use_cache=False, + ) + + if not torch.jit.isinstance(cross_attn_output, Tensor): + raise ValueError("cross-attention output must be Tensor.") + cross_attn_output = self.cross_attn_dropout(cross_attn_output) + cross_attn_residual = cross_attn_output + x + cross_attn_residual = self.cross_attn_layernorm(cross_attn_residual) + return cross_attn_residual + + def _feedforward_block(self, hidden_states: Tensor) -> Tensor: + h = self.feedforward(hidden_states) + h = self.feedforward_dropout(h) + h = self.feedforward_layernorm(h + hidden_states) + return h + + def _feedforward_query_block(self, hidden_states: Tensor) -> Tensor: + h = self.feedforward_query(hidden_states) + h = self.feedforward_dropout_query(h) + h = self.feedforward_layernorm_query(h + hidden_states) + return h + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor, Tensor]] = None, + query_length: int = 0, + use_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + """ + Inputs: + hidden_states (Tensor): input query of shape bsz x seq_len x embed_dim + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + attention_mask (Optional[Tensor]): attention mask, supported mask type is described in MultiHeadAttentionWithCache class + past_key_value (Optional[Tuple[Tensor, Tensor]]): cached key/value tuple for self-attention + query_length (Optional[int]): length of query embedding, used as condition + to determine query attention output and check text existance. + use_cache (bool): whether to use cache for key and value tensors + + Return: + A tuple includes: + layer_output (Tensor): layer output of shape bsz x seq_len x embed_dim + present_key_value (Optional[Tuple[Tensor, Tensor]]): key/value tuple for self-attention + """ + if past_key_value is not None and len(past_key_value) != 2: + raise ValueError( + "past_key_value should be 2-element tuple to represent self-attention cached key/values." + ) + attn_residual, present_key_value = self._self_attention_block( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + + if query_length > 0: + query_attn_output = attn_residual[:, :query_length, :] + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError( + "encoder_hidden_states must be given for cross-attention layers" + ) + cross_attn_output = self._cross_attention_block( + hidden_states=query_attn_output, + encoder_hidden_states=encoder_hidden_states, + ) + query_attn_output = cross_attn_output + + # add query feedforward block for self-attention or cross-attention + layer_output = self._feedforward_query_block(query_attn_output) + + # handle text input if present + if attn_residual.shape[1] > query_length: + layer_output_text = self._feedforward_block( + attn_residual[:, query_length:, :] + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + + else: + layer_output = self._feedforward_block(attn_residual) + + return (layer_output, present_key_value) + + +class QformerEncoder(nn.Module): + """ + Qformer encoder module including multiple Qformer layers. + + Args: + num_hidden_layers (int): number of Qformer layers inside encoder + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer in each Qformer layer + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + cross_attention_freq (int): frequency of adding cross attention in QFormer layers, default to 2. + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA. + + """ + + def __init__( + self, + num_hidden_layers: int, + dim_q: int, + dim_feedforward: int, + num_heads: int, + attn_dropout: float = 0.0, + dropout: float = 0.0, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + cross_attention_freq: int = 2, + dim_kv: Optional[int] = None, + ): + super().__init__() + layers = [] + for i in range(num_hidden_layers): + has_cross_attention = i % cross_attention_freq == 0 + layers.append( + QformerLayer( + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + layer_norm_eps=layer_norm_eps, + activation=activation, + has_cross_attention=has_cross_attention, + dim_kv=dim_kv, + ) + ) + self.layers = nn.ModuleList(layers) + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None, + query_length: int = 0, + use_cache: bool = False, + ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + """ + Inputs: + hidden_states (Tensor): input query of shape bsz x seq_len x embed_dim + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + attention_mask (Optional[Tensor]): attention mask, supported mask type is described in MultiHeadAttentionWithCache class + past_key_values (Optional[List[Tuple[Tensor, Tensor]]]): cached key/value tuple for self-attention + query_length (int): the length of input query, used for cross-attention + use_cache (bool): whether to use cache for key and value tensors + + Return: + A tuple includes: + the last hidden state: Tensor of shape bsz x seq_len x embed_dim + past_key_values (List[Optional[Tuple[Tensor, Tensor]]]]): cached key/values from Qformer layers + """ + current_key_values = torch.jit.annotate(List[Tuple[Tensor, Tensor]], []) + for i, layer_module in enumerate(self.layers): + past_key_value = past_key_values[i] if past_key_values is not None else None + hidden_states, current_key_value = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + query_length=query_length, + use_cache=use_cache, + ) + if use_cache: + assert isinstance(current_key_value, tuple) + current_key_values.append(current_key_value) + + return (hidden_states, current_key_values) + + +class QformerEmbedding(nn.Module): + """ + Qformer embedding module. + + Args: + embedding_dim (int): dim of embedding space + max_position_embeddings (int): max sequence length allowed for positional embeddings + vocab_size (int): size of vocabulary + pad_token_id (int): id used for padding token, default is 0. + dropout (float): dropout probability after embedding layers and layernorm. + layer_norm_eps (float): the epsilon used by the layer normalization layers. + """ + + def __init__( + self, + embedding_dim: int, + max_position_embeddings: int, + vocab_size: int, + pad_token_id: int = 0, + layer_norm_eps: float = 1e-12, + dropout: float = 0.0, + ): + super().__init__() + self.token_embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=pad_token_id + ) + self.position_embeddings = nn.Embedding(max_position_embeddings, embedding_dim) + self.layernorm = Fp32LayerNorm(embedding_dim, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(max_position_embeddings).expand((1, -1)) + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + query_embeddings: Optional[Tensor] = None, + past_seq_length: int = 0, + ) -> Tensor: + """ + Inputs: + input_ids (Optional[Tensor]): input token ids + position_ids (Optional[Tensor]): batches of of 1D integer tensors used to identify each token's position, + if no position_ids is provided, the IDs are automatically created as absolute positional embeddings. + query_embeddings (Optional[Tensor]): query embeddings for QFormer + past_seq_length (Optional[int]): sequence length cached by past_key_values. + + Returns: + embeddings (Tensor): concatenated embeddings of shape (bsz, num tokens, embedding dim), concatenation is along + the token dimension. + """ + if input_ids is None and query_embeddings is None: + raise ValueError("Either input_ids or query_embeddings must be passed.") + + seq_length = input_ids.size(1) if input_ids is not None else 0 + + embeddings = query_embeddings + + if input_ids is not None: + if position_ids is None: + position_ids = self.position_ids[ + :, past_seq_length : seq_length + past_seq_length + ].clone() + word_embeddings = self.token_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids.long()) + embeddings = word_embeddings + position_embeddings + + if query_embeddings is not None: + embeddings = torch.cat((query_embeddings, embeddings), dim=1) + + assert isinstance(embeddings, Tensor) + embeddings = self.layernorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings diff --git a/torchmultimodal/models/blip2/qformer_utils.py b/torchmultimodal/models/blip2/qformer_utils.py new file mode 100644 index 00000000..3b6022f3 --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from torch import Tensor +from torchmultimodal.utils.attention import get_causal_attention_mask + + +def get_causal_mask( + attention_mask: Tensor, + input_shape: Tuple[int, int], + has_query: bool = False, +) -> Tensor: + """A causal mask in addition to the padding mask for Q-Former for generation task. + when input seq_len is shorter than attn_mask, increasing causal_mask by prefix_seq_len with 1s; + if query is available, apply causal self-attention mask to control query-text interaction; + + Arguments: + attention_mask (Tensor) is a binary mask with 1 for unmasked and 0 for masked positions. + Attention_mask has size of [batch_size, attn_seq_len]. attn_seq_len can be only seq_len for text_token + or query_len + seq_len. + input_shape (tuple[int, int]): indicates input shape of (batch_size, input_seq_len) from embedding output. + If query_emb is used, input_seq_len is query_len + seq_len. + Input shape can be different from attention_mask shape for image caption and text generation tasks. + has_query (bool) indicating whether query is available in qformer input. + + Returns: + causal_mask (Tensor): mask size of [bsz, attn_seq_len, attn_seq_len] with query, + [bsz, input_seq_len, attn_seq_len] without query + + """ + device = attention_mask.device + batch_size, seq_len = input_shape + causal_mask = get_causal_attention_mask(seq_len).to(device) + causal_mask = causal_mask.repeat(batch_size, 1).view(batch_size, seq_len, seq_len) + # compare seq_len in input and attention mask + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: + # if query is available, apply causal self-attention mask to control query-text interaction. + # Allow queries attending each other but not the text tokens. + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + dim=1, + ) # mask size [bsz, attn_seq_len, input_seq_len] + # increase causal_mask by prefix_seq_len with 1s to attend self-attention + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + dim=-1, + ) # size of [bsz, attn_seq_len, attn_seq_len] with query, [bsz, input_seq_len, attn_seq_len] without query + return causal_mask