From c03924ed0f20115b1c87af2ec386560cae69f643 Mon Sep 17 00:00:00 2001 From: pengchen18 <> Date: Wed, 11 Oct 2023 18:08:40 -0700 Subject: [PATCH 1/2] port qformer layers Differential Revision: D50137871 fbshipit-source-id: 71e8249fca342c3ab66257c94aeebcbde118d2cc --- tests/models/blip2/__init__.py | 5 + tests/models/blip2/test_qformer_layers.py | 452 ++++++++++++++++++ tests/models/blip2/test_qformer_utils.py | 77 +++ torchmultimodal/models/blip2/__init__.py | 5 + .../models/blip2/qformer_layers.py | 387 +++++++++++++++ torchmultimodal/models/blip2/qformer_utils.py | 71 +++ 6 files changed, 997 insertions(+) create mode 100644 tests/models/blip2/__init__.py create mode 100644 tests/models/blip2/test_qformer_layers.py create mode 100644 tests/models/blip2/test_qformer_utils.py create mode 100644 torchmultimodal/models/blip2/__init__.py create mode 100644 torchmultimodal/models/blip2/qformer_layers.py create mode 100644 torchmultimodal/models/blip2/qformer_utils.py diff --git a/tests/models/blip2/__init__.py b/tests/models/blip2/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/tests/models/blip2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/models/blip2/test_qformer_layers.py b/tests/models/blip2/test_qformer_layers.py new file mode 100644 index 00000000..adbb0e01 --- /dev/null +++ b/tests/models/blip2/test_qformer_layers.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torch import nn +from torchmultimodal.models.blip2.qformer_layers import ( + QformerEmbedding, + QformerEncoder, + QformerLayer, +) + + +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(0) + + +class TestQformerWithMHA: + @pytest.fixture + def dim_q(self): + return 4 + + @pytest.fixture + def dim_kv(self): + return 2 + + @pytest.fixture + def dim_feedforward(self): + return 6 + + @pytest.fixture + def cross_attention_freq(self): + return 2 + + @pytest.fixture + def num_hidden_layers(self): + return 2 + + @pytest.fixture + def num_heads(self): + return 2 + + @pytest.fixture() + def input_ids(self): + return torch.LongTensor([[0, 1], [2, 3]]) + + @pytest.fixture() + def query_embeddings(self): + return torch.Tensor( + [ + [ + [0.6424, 0.6182, 0.5110, 0.7867], + [0.3907, 0.2057, 0.6909, 0.6334], + ], + [ + [0.6904, 0.4445, 0.4336, 0.4603], + [0.6318, 0.1163, 0.0340, 0.6871], + ], + ] + ) + + @pytest.fixture + def q(self): + return torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]]) + + @pytest.fixture + def kv(self): + return torch.Tensor([[[3, 2], [1, 1]]]) + + @pytest.fixture + def current_key_value(self): + return torch.Tensor( + [ + [ + [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]], + [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]], + ] + ] + ) + + @pytest.fixture + def past_key_value(self): + return torch.Tensor( + [ + [ + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + ] + ] + ) + + @pytest.fixture + def past_key_values(self, past_key_value, num_hidden_layers): + past_key_values = [] + for i in range(num_hidden_layers): + past_key_values.append((past_key_value, past_key_value)) + return past_key_values + + @pytest.fixture + def qformer_layer_self_attention_only(self, dim_q, dim_feedforward, num_heads): + qformer_layer = QformerLayer( + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + has_cross_attention=False, + ) + init_weights_with_constant(qformer_layer) + qformer_layer.eval() + return qformer_layer + + @pytest.fixture + def qformer_layer_with_cross_attention( + self, + dim_q, + dim_kv, + dim_feedforward, + num_heads, + ): + qformer_layer = QformerLayer( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + activation=nn.ReLU, + has_cross_attention=True, + ) + init_weights_with_constant(qformer_layer) + # modify query feedforward params to test cross attention case with different query lengths + init_weights_with_constant(qformer_layer.feedforward_query, 2.0) + init_weights_with_constant(qformer_layer.feedforward_layernorm_query, 2.0) + qformer_layer.eval() + return qformer_layer + + @pytest.fixture + def qformer_encoder( + self, + dim_q, + dim_kv, + dim_feedforward, + cross_attention_freq, + num_hidden_layers, + num_heads, + ): + qformer_encoder = QformerEncoder( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + cross_attention_freq=cross_attention_freq, + num_hidden_layers=num_hidden_layers, + ) + init_weights_with_constant(qformer_encoder) + qformer_encoder.eval() + return qformer_encoder + + def test_qformer_layer_self_attention_only( + self, qformer_layer_self_attention_only, current_key_value, past_key_value, q + ): + actual = qformer_layer_self_attention_only( + q, past_key_value=(past_key_value, past_key_value), use_cache=True + ) + expected = torch.Tensor( + [ + [ + [0.0955, 1.3015, 2.5076, 0.0955], + [2.3416, 1.4472, 0.5528, -0.3416], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_layer_with_cross_attention_only_query( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + # test with query length < attn_residual.shape[1] + actual = qformer_layer_with_cross_attention( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=2, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_layer_with_cross_attention_query_and_text( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + # test with query length >= attn_residual.shape[1] + actual = qformer_layer_with_cross_attention( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=3, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [2.0000, 2.0000, 2.0000, 2.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_encoder( + self, + qformer_encoder, + past_key_values, + current_key_value, + past_key_value, + q, + kv, + ): + actual = qformer_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + expected_hidden_state = torch.Tensor( + [ + [ + [0.0955, 1.3015, 2.5076, 0.0955], + [2.3416, 1.4472, 0.5528, -0.3416], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + expected_key_value = torch.Tensor( + [ + [ + [[5.0, 5.0], [5.0, 5.0], [5.0, 5.0]], + [[5.0, 5.0], [5.0, 5.0], [5.0, 5.0]], + ] + ] + ) + assert_expected(actual[0], expected_hidden_state, rtol=0, atol=1e-4) + assert_expected( + actual[1][0][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][0][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1][0], + torch.cat([past_key_value, expected_key_value], dim=2), + ) + assert_expected( + actual[1][1][1], + torch.cat([past_key_value, expected_key_value], dim=2), + ) + + def test_layer_scripting( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + scripted_model = torch.jit.script(qformer_layer_with_cross_attention) + actual = scripted_model( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=3, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [2.0000, 2.0000, 2.0000, 2.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_encoder_scripting( + self, + qformer_encoder, + past_key_values, + current_key_value, + past_key_value, + q, + kv, + ): + scripted_encoder = torch.jit.script(qformer_encoder) + actual = scripted_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + expected = qformer_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + assert_expected(actual[0], expected[0]) + assert_expected(actual[1], expected[1]) + assert len(actual) == len(expected) + + @pytest.fixture + def qformer_emb(self, dim_q): + return QformerEmbedding( + embedding_dim=dim_q, + max_position_embeddings=512, + vocab_size=20, + ) + + def test_qformer_embedding(self, input_ids, query_embeddings, qformer_emb): + actual = qformer_emb( + input_ids=input_ids, + query_embeddings=query_embeddings, + ) + expected_value = torch.Tensor( + [ + [ + [0.0287, -0.2175, -1.3081, 1.4969], + [-0.4602, -1.4116, 1.0838, 0.7880], + [-0.0600, 1.3838, -1.4382, 0.1144], + [1.1554, 0.0435, 0.3865, -1.5855], + ], + [ + [1.7251, -0.5904, -0.6931, -0.4416], + [0.8989, -0.8530, -1.1327, 1.0868], + [0.8951, -1.1037, -0.8854, 1.0940], + [-0.0748, -0.2439, 1.5529, -1.2342], + ], + ] + ) + # expected dim [bsz, num_token, embed_dim] + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_empty_input_ids( + self, + query_embeddings, + qformer_emb, + ): + actual = qformer_emb( + query_embeddings=query_embeddings, + ) + expected_value = torch.Tensor( + [ + [ + [0.0287, -0.2175, -1.3081, 1.4969], + [-0.4602, -1.4116, 1.0838, 0.7880], + ], + [ + [1.7251, -0.5904, -0.6931, -0.4416], + [0.8989, -0.8530, -1.1327, 1.0868], + ], + ] + ) + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_empty_query_embs( + self, + input_ids, + qformer_emb, + ): + actual = qformer_emb( + input_ids=input_ids, + ) + expected_value = torch.Tensor( + [ + [ + [-0.0600, 1.3838, -1.4382, 0.1144], + [1.1554, 0.0435, 0.3865, -1.5855], + ], + [ + [0.8951, -1.1037, -0.8854, 1.0940], + [-0.0748, -0.2439, 1.5529, -1.2342], + ], + ] + ) + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_invalid_input( + self, + qformer_emb, + ): + with pytest.raises(ValueError): + qformer_emb() + + def test_embedding_scripting(self, input_ids, qformer_emb, query_embeddings): + scripted_emb = torch.jit.script(qformer_emb) + actual = scripted_emb(input_ids=input_ids, query_embeddings=query_embeddings) + assert_expected( + actual, qformer_emb(input_ids=input_ids, query_embeddings=query_embeddings) + ) diff --git a/tests/models/blip2/test_qformer_utils.py b/tests/models/blip2/test_qformer_utils.py new file mode 100644 index 00000000..0fb07445 --- /dev/null +++ b/tests/models/blip2/test_qformer_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from tests.test_utils import assert_expected +from torch import Tensor +from torchmultimodal.models.blip2.qformer_utils import get_causal_mask + + +class TestExtendedAttnMaskForDecoder: + @pytest.fixture + def attention_mask(self): + return Tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]) + + @pytest.fixture + def input_shape(self): + return (2, 2) + + def test_extended_attention_mask(self, attention_mask): + actual_mask = get_causal_mask(attention_mask, attention_mask.shape) + expected = Tensor( + [ + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + ] + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) + + def test_extended_attention_mask_diff_input_size(self, attention_mask, input_shape): + actual_mask = get_causal_mask( + attention_mask, + input_shape, + ) + expected = Tensor( + Tensor( + [ + [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]], + ] + ) + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) + + def test_extended_attention_mask_with_query_embs(self, attention_mask, input_shape): + actual_mask = get_causal_mask(attention_mask, input_shape, has_query=True) + expected = Tensor( + Tensor( + [ + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + ] + ) + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) diff --git a/torchmultimodal/models/blip2/__init__.py b/torchmultimodal/models/blip2/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/models/blip2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/models/blip2/qformer_layers.py b/torchmultimodal/models/blip2/qformer_layers.py new file mode 100644 index 00000000..94e1d30c --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_layers.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional, Tuple + +import torch + +from torch import nn, Tensor + +from torchmultimodal.modules.layers.mlp import MLP +from torchmultimodal.modules.layers.multi_head_attention import ( + MHAWithCacheOutput, + MultiHeadAttentionWithCache, +) +from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm + + +class QformerLayer(nn.Module): + """ + Qformer layer module. + + This module is designed with a self-attention (SA) block and optionally includes a cross-attention (CA) block for queries. + The inputs for this module, referred to as hidden_states, can consist of either a query, text, or a combination of both. + Cross-attention is exclusively activated for queries (query_length > 0) with encoder_hidden_states derived from image inputs. + + The feedforward(ff) block will project the hidden states output by the layer before, + query output and text output are concatenated as overall output after separated handling for CA and ff. + + Args: + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + has_cross_attention (bool): whether a cross-attention layer is included + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA. + + """ + + def __init__( + self, + dim_q: int, + dim_feedforward: int, + num_heads: int, + attn_dropout: float = 0.0, + dropout: float = 0.0, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + has_cross_attention: bool = False, + dim_kv: Optional[int] = None, + ): + super().__init__() + self.self_attention = MultiHeadAttentionWithCache( + dim_q, dim_q, num_heads, attn_dropout + ) + self.self_attn_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + self.has_cross_attention = has_cross_attention + self.cross_attention: Optional[MultiHeadAttentionWithCache] = None + + if has_cross_attention: + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder, key and value caching should be disabled. + if dim_kv is None: + raise ValueError( + "key and value dim should be provided for cross attention." + ) + self.cross_attention = MultiHeadAttentionWithCache( + dim_q=dim_q, dim_kv=dim_kv, num_heads=num_heads, dropout=attn_dropout + ) + self.cross_attn_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.cross_attn_dropout = nn.Dropout(dropout) + + # feedforward block + self.feedforward = MLP( + dim_q, dim_q, dim_feedforward, dropout=0.0, activation=activation + ) + self.feedforward_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.feedforward_dropout = nn.Dropout(dropout) + + # query feedforward block + self.feedforward_query = MLP( + dim_q, dim_q, dim_feedforward, dropout=0.0, activation=activation + ) + self.feedforward_layernorm_query = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.feedforward_dropout_query = nn.Dropout(dropout) + + def _self_attention_block( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor, Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + x = hidden_states + attn_output = self.self_attention( + x, + x, + x, + attn_mask=attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + present_key_value: Optional[Tuple[Tensor, Tensor]] = None + if use_cache: + assert isinstance(attn_output, MHAWithCacheOutput) + attn_output_value = attn_output.attn_output + present_key_value = attn_output.past_key_value + else: + assert isinstance(attn_output, Tensor) + attn_output_value = attn_output + attn_output = self.dropout(attn_output_value) + + attn_residual = attn_output + x + attn_residual = self.self_attn_layernorm(attn_residual) + return attn_residual, present_key_value + + def _cross_attention_block( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + ) -> Tensor: + x = hidden_states + assert self.cross_attention is not None + # turn off cache for cross attention + cross_attn_output = self.cross_attention( + query=x, + key=encoder_hidden_states, + value=encoder_hidden_states, + use_cache=False, + ) + + if not torch.jit.isinstance(cross_attn_output, Tensor): + raise ValueError("cross-attention output must be Tensor.") + cross_attn_output = self.cross_attn_dropout(cross_attn_output) + cross_attn_residual = cross_attn_output + x + cross_attn_residual = self.cross_attn_layernorm(cross_attn_residual) + return cross_attn_residual + + def _feedforward_block(self, hidden_states: Tensor) -> Tensor: + h = self.feedforward(hidden_states) + h = self.feedforward_dropout(h) + h = self.feedforward_layernorm(h + hidden_states) + return h + + def _feedforward_query_block(self, hidden_states: Tensor) -> Tensor: + h = self.feedforward_query(hidden_states) + h = self.feedforward_dropout_query(h) + h = self.feedforward_layernorm_query(h + hidden_states) + return h + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor, Tensor]] = None, + query_length: int = 0, + use_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + """ + Inputs: + hidden_states (Tensor): input query of shape bsz x seq_len x embed_dim + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + attention_mask (Optional[Tensor]): attention mask, supported mask type is described in MultiHeadAttentionWithCache class + past_key_value (Optional[Tuple[Tensor, Tensor]]): cached key/value tuple for self-attention + query_length (Optional[int]): length of query embedding, used as condition + to determine query attention output and check text existance. + use_cache (bool): whether to use cache for key and value tensors + + Return: + A tuple includes: + layer_output (Tensor): layer output of shape bsz x seq_len x embed_dim + present_key_value (Optional[Tuple[Tensor, Tensor]]): key/value tuple for self-attention + """ + if past_key_value is not None and len(past_key_value) != 2: + raise ValueError( + "past_key_value should be 2-element tuple to represent self-attention cached key/values." + ) + attn_residual, present_key_value = self._self_attention_block( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + + if query_length > 0: + query_attn_output = attn_residual[:, :query_length, :] + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError( + "encoder_hidden_states must be given for cross-attention layers" + ) + cross_attn_output = self._cross_attention_block( + hidden_states=query_attn_output, + encoder_hidden_states=encoder_hidden_states, + ) + query_attn_output = cross_attn_output + + # add query feedforward block for self-attention or cross-attention + layer_output = self._feedforward_query_block(query_attn_output) + + # handle text input if present + if attn_residual.shape[1] > query_length: + layer_output_text = self._feedforward_block( + attn_residual[:, query_length:, :] + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + + else: + layer_output = self._feedforward_block(attn_residual) + + return (layer_output, present_key_value) + + +class QformerEncoder(nn.Module): + """ + Qformer encoder module including multiple Qformer layers. + + Args: + num_hidden_layers (int): number of Qformer layers inside encoder + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer in each Qformer layer + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + cross_attention_freq (int): frequency of adding cross attention in QFormer layers, default to 2. + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA. + + """ + + def __init__( + self, + num_hidden_layers: int, + dim_q: int, + dim_feedforward: int, + num_heads: int, + attn_dropout: float = 0.0, + dropout: float = 0.0, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + cross_attention_freq: int = 2, + dim_kv: Optional[int] = None, + ): + super().__init__() + layers = [] + for i in range(num_hidden_layers): + has_cross_attention = i % cross_attention_freq == 0 + layers.append( + QformerLayer( + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + layer_norm_eps=layer_norm_eps, + activation=activation, + has_cross_attention=has_cross_attention, + dim_kv=dim_kv, + ) + ) + self.layers = nn.ModuleList(layers) + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None, + query_length: int = 0, + use_cache: bool = False, + ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + """ + Inputs: + hidden_states (Tensor): input query of shape bsz x seq_len x embed_dim + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + attention_mask (Optional[Tensor]): attention mask, supported mask type is described in MultiHeadAttentionWithCache class + past_key_values (Optional[List[Tuple[Tensor, Tensor]]]): cached key/value tuple for self-attention + query_length (int): the length of input query, used for cross-attention + use_cache (bool): whether to use cache for key and value tensors + + Return: + A tuple includes: + the last hidden state: Tensor of shape bsz x seq_len x embed_dim + past_key_values (List[Optional[Tuple[Tensor, Tensor]]]]): cached key/values from Qformer layers + """ + current_key_values = torch.jit.annotate(List[Tuple[Tensor, Tensor]], []) + for i, layer_module in enumerate(self.layers): + past_key_value = past_key_values[i] if past_key_values is not None else None + hidden_states, current_key_value = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + query_length=query_length, + use_cache=use_cache, + ) + if use_cache: + assert isinstance(current_key_value, tuple) + current_key_values.append(current_key_value) + + return (hidden_states, current_key_values) + + +class QformerEmbedding(nn.Module): + """ + Qformer embedding module. + + Args: + embedding_dim (int): dim of embedding space + max_position_embeddings (int): max sequence length allowed for positional embeddings + vocab_size (int): size of vocabulary + pad_token_id (int): id used for padding token, default is 0. + dropout (float): dropout probability after embedding layers and layernorm. + layer_norm_eps (float): the epsilon used by the layer normalization layers. + """ + + def __init__( + self, + embedding_dim: int, + max_position_embeddings: int, + vocab_size: int, + pad_token_id: int = 0, + layer_norm_eps: float = 1e-12, + dropout: float = 0.0, + ): + super().__init__() + self.token_embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=pad_token_id + ) + self.position_embeddings = nn.Embedding(max_position_embeddings, embedding_dim) + self.layernorm = Fp32LayerNorm(embedding_dim, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(max_position_embeddings).expand((1, -1)) + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + query_embeddings: Optional[Tensor] = None, + past_seq_length: int = 0, + ) -> Tensor: + """ + Inputs: + input_ids (Optional[Tensor]): input token ids + position_ids (Optional[Tensor]): batches of of 1D integer tensors used to identify each token's position, + if no position_ids is provided, the IDs are automatically created as absolute positional embeddings. + query_embeddings (Optional[Tensor]): query embeddings for QFormer + past_seq_length (Optional[int]): sequence length cached by past_key_values. + + Returns: + embeddings (Tensor): concatenated embeddings of shape (bsz, num tokens, embedding dim), concatenation is along + the token dimension. + """ + if input_ids is None and query_embeddings is None: + raise ValueError("Either input_ids or query_embeddings must be passed.") + + seq_length = input_ids.size(1) if input_ids is not None else 0 + + embeddings = query_embeddings + + if input_ids is not None: + if position_ids is None: + position_ids = self.position_ids[ + :, past_seq_length : seq_length + past_seq_length + ].clone() + word_embeddings = self.token_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids.long()) + embeddings = word_embeddings + position_embeddings + + if query_embeddings is not None: + embeddings = torch.cat((query_embeddings, embeddings), dim=1) + + assert isinstance(embeddings, Tensor) + embeddings = self.layernorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings diff --git a/torchmultimodal/models/blip2/qformer_utils.py b/torchmultimodal/models/blip2/qformer_utils.py new file mode 100644 index 00000000..3b6022f3 --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from torch import Tensor +from torchmultimodal.utils.attention import get_causal_attention_mask + + +def get_causal_mask( + attention_mask: Tensor, + input_shape: Tuple[int, int], + has_query: bool = False, +) -> Tensor: + """A causal mask in addition to the padding mask for Q-Former for generation task. + when input seq_len is shorter than attn_mask, increasing causal_mask by prefix_seq_len with 1s; + if query is available, apply causal self-attention mask to control query-text interaction; + + Arguments: + attention_mask (Tensor) is a binary mask with 1 for unmasked and 0 for masked positions. + Attention_mask has size of [batch_size, attn_seq_len]. attn_seq_len can be only seq_len for text_token + or query_len + seq_len. + input_shape (tuple[int, int]): indicates input shape of (batch_size, input_seq_len) from embedding output. + If query_emb is used, input_seq_len is query_len + seq_len. + Input shape can be different from attention_mask shape for image caption and text generation tasks. + has_query (bool) indicating whether query is available in qformer input. + + Returns: + causal_mask (Tensor): mask size of [bsz, attn_seq_len, attn_seq_len] with query, + [bsz, input_seq_len, attn_seq_len] without query + + """ + device = attention_mask.device + batch_size, seq_len = input_shape + causal_mask = get_causal_attention_mask(seq_len).to(device) + causal_mask = causal_mask.repeat(batch_size, 1).view(batch_size, seq_len, seq_len) + # compare seq_len in input and attention mask + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: + # if query is available, apply causal self-attention mask to control query-text interaction. + # Allow queries attending each other but not the text tokens. + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + dim=1, + ) # mask size [bsz, attn_seq_len, input_seq_len] + # increase causal_mask by prefix_seq_len with 1s to attend self-attention + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + dim=-1, + ) # size of [bsz, attn_seq_len, attn_seq_len] with query, [bsz, input_seq_len, attn_seq_len] without query + return causal_mask From 773a6e6b99ef657da060ed912a2b9bbe256a05a0 Mon Sep 17 00:00:00 2001 From: Peng Chen Date: Wed, 11 Oct 2023 18:08:57 -0700 Subject: [PATCH 2/2] add qformer model to torchmm/models (#486) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/486 as title Differential Revision: D50145316 fbshipit-source-id: 7899538cbbf2903fea96bd951851673f851b0051 --- tests/models/blip2/test_qformer_model.py | 414 ++++++++++++++++++ torchmultimodal/models/blip2/qformer_model.py | 294 +++++++++++++ 2 files changed, 708 insertions(+) create mode 100644 tests/models/blip2/test_qformer_model.py create mode 100644 torchmultimodal/models/blip2/qformer_model.py diff --git a/tests/models/blip2/test_qformer_model.py b/tests/models/blip2/test_qformer_model.py new file mode 100644 index 00000000..ef5479d2 --- /dev/null +++ b/tests/models/blip2/test_qformer_model.py @@ -0,0 +1,414 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torch.nn import CrossEntropyLoss +from torchmultimodal.models.blip2.qformer_model import QformerForCLM, QformerModel + + +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(0) + + +class TestQformerModel: + @pytest.fixture + def attn_mask(self): + return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]]) + + @pytest.fixture + def dim_q(self): + return 4 + + @pytest.fixture + def dim_kv(self): + return 2 + + @pytest.fixture + def dim_feedforward(self): + return 6 + + @pytest.fixture + def cross_attention_freq(self): + return 2 + + @pytest.fixture + def num_hidden_layers(self): + return 2 + + @pytest.fixture + def num_heads(self): + return 2 + + @pytest.fixture() + def input_ids(self): + return torch.LongTensor([[0, 1], [2, 3]]) + + @pytest.fixture + def vocab_size(self): + return 20 + + @pytest.fixture() + def query_embeddings(self): + return torch.Tensor( + [ + [ + [0.6424, 0.6182, 0.5110, 0.7867], + [0.3907, 0.2057, 0.6909, 0.6334], + ], + [ + [0.6904, 0.4445, 0.4336, 0.4603], + [0.6318, 0.1163, 0.0340, 0.6871], + ], + ] + ) + + @pytest.fixture + def past_key_value(self): + return torch.Tensor( + [ + [ + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + ], + [ + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + ], + ] + ) + + @pytest.fixture + def past_key_values(self, past_key_value, num_hidden_layers): + past_key_values = [] + for i in range(num_hidden_layers): + past_key_values.append((past_key_value, past_key_value)) + return past_key_values + + @pytest.fixture + def kv(self): + return torch.Tensor([[[3, 2], [1, 1]], [[3, 2], [1, 1]]]) + + @pytest.fixture + def labels(self): + labels = torch.ones([2, 2]).long() + return labels[:, 1:].contiguous() + + @pytest.fixture + def loss_fct(self): + return CrossEntropyLoss(reduction="mean", label_smoothing=0.1) + + @pytest.fixture + def qformer_model( + self, + dim_q, + dim_kv, + dim_feedforward, + cross_attention_freq, + num_hidden_layers, + num_heads, + vocab_size, + ): + qformer_model = QformerModel( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + num_hidden_layers=num_hidden_layers, + max_position_embeddings=512, + vocab_size=vocab_size, + query_length=2, + ) + init_weights_with_constant(qformer_model) + qformer_model.eval() + return qformer_model + + @pytest.fixture + def qformer_model_for_clm( + self, + dim_q, + dim_kv, + dim_feedforward, + cross_attention_freq, + num_hidden_layers, + num_heads, + vocab_size, + ): + qformer_for_clm = QformerForCLM( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + num_hidden_layers=num_hidden_layers, + max_position_embeddings=512, + vocab_size=vocab_size, + ) + init_weights_with_constant(qformer_for_clm) + qformer_for_clm.eval() + return qformer_for_clm + + def test_qformer_model_with_attn_mask( + self, + input_ids, + attn_mask, + qformer_model, + query_embeddings, + num_hidden_layers, + kv, + ): + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + ) + expected_hidden_states = torch.Tensor( + [ + [ + [1.0287, 0.7825, -0.3081, 2.4969], + [0.5398, -0.4116, 2.0838, 1.7880], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + [ + [2.7251, 0.4096, 0.3069, 0.5584], + [1.8989, 0.1470, -0.1327, 2.0868], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + ] + ) + assert_expected(actual[0], expected_hidden_states, atol=1e-4, rtol=1e-4) + + assert_expected(len(actual[1]), num_hidden_layers) + assert_expected(len(actual[1][0]), 2) # 2-element tuple includes key and value + assert_expected( + actual[1][0][0].shape, torch.Size([2, 2, 4, 2]) + ) # bsz x num_heads x seq_len x head_dim + expected_cached_values = torch.Tensor( + [ + [ + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + ], + [ + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + ], + ] + ) + assert_expected(actual[1][0][0], expected_cached_values, atol=1e-4, rtol=1e-4) + + def test_qformer_model_with_past_key_values( + self, + input_ids, + qformer_model, + query_embeddings, + num_hidden_layers, + kv, + past_key_values, + ): + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + query_embeds=query_embeddings, + past_key_values=past_key_values, + use_cache=True, + ) + expected_hidden_states = torch.Tensor( + [ + [ + [1.0287, 0.7825, -0.3081, 2.4969], + [0.5398, -0.4116, 2.0838, 1.7880], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + [ + [2.7251, 0.4096, 0.3069, 0.5584], + [1.8989, 0.1470, -0.1327, 2.0868], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + ] + ) + assert_expected(actual[0], expected_hidden_states, atol=1e-4, rtol=1e-4) + + assert_expected(len(actual[1]), num_hidden_layers) + assert_expected(len(actual[1][0]), 2) # 2-element tuple includes key and value + assert_expected( + actual[1][0][0].shape, torch.Size([2, 2, 7, 2]) + ) # bsz x num_heads x seq_len x head_dim + expected_cached_values = torch.Tensor( + [ + [ + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + ], + [ + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + ], + ] + ) + assert_expected(actual[1][0][0], expected_cached_values, atol=1e-4, rtol=1e-4) + + def test_qformer_model_with_causal_mask( + self, + input_ids, + attn_mask, + kv, + qformer_model, + query_embeddings, + num_hidden_layers, + ): + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + use_causal_mask=True, + ) + expected_hidden_states = torch.Tensor( + [ + [ + [1.0287, 0.7825, -0.3081, 2.4969], + [0.5398, -0.4116, 2.0838, 1.7880], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + [ + [2.7251, 0.4096, 0.3069, 0.5584], + [1.8989, 0.1470, -0.1327, 2.0868], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + ] + ) + assert_expected(actual[0], expected_hidden_states, atol=1e-4, rtol=1e-4) + + def test_qformer_model_scripting( + self, qformer_model, input_ids, attn_mask, query_embeddings, kv + ): + scripted_model = torch.jit.script(qformer_model) + scripted_output = scripted_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + ) + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + ) + assert_expected(scripted_output[0], actual[0], atol=1e-4, rtol=1e-4) + assert_expected(scripted_output[1], actual[1], atol=1e-4, rtol=1e-4) + + def test_qformer_for_clm( + self, + qformer_model_for_clm, + query_embeddings, + input_ids, + kv, + attn_mask, + labels, + loss_fct, + vocab_size, + ): + actual_pred = qformer_model_for_clm( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=False, + ) + expected = torch.ones([2, 2, 20]) * 5 + assert_expected(actual_pred, expected, atol=1e-4, rtol=1e-4) + + def test_qformer_for_clm_scripting( + self, + qformer_model_for_clm, + query_embeddings, + input_ids, + kv, + attn_mask, + labels, + loss_fct, + vocab_size, + ): + scripted_model = torch.jit.script(qformer_model_for_clm) + actual_pred = scripted_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=False, + ) + expected = torch.ones([2, 2, 20]) * 5 + assert_expected(actual_pred, expected, atol=1e-4, rtol=1e-4) diff --git a/torchmultimodal/models/blip2/qformer_model.py b/torchmultimodal/models/blip2/qformer_model.py new file mode 100644 index 00000000..8ce3cf68 --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_model.py @@ -0,0 +1,294 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional, Tuple + +from torch import nn, Tensor +from torchmultimodal.models.blip2.qformer_layers import QformerEmbedding, QformerEncoder + +from torchmultimodal.models.blip2.qformer_utils import get_causal_mask + + +class QformerModel(nn.Module): + """ + Qformer model including Qformer embedding and Qformer encoder. + + Args: + num_hidden_layers (int): number of Qformer layers inside encoder + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + max_position_embeddings (int): max sequence length allowed for positional embeddings + vocab_size (int): size of vocabulary + pad_token_id (int): id used for padding token, default is 0. + query_length(int): query length in Qformer, used to compute cached query length. + default value is the same as num_query_token for Blip2 case (https://fburl.com/316803mo). + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA, default is None. + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer in each Qformer layer + cross_attention_freq (int): frequency of adding cross attention in QFormer layers, default to 2. + """ + + def __init__( + self, + num_hidden_layers: int, + dim_q: int, + dim_feedforward: int, + num_heads: int, + max_position_embeddings: int, + vocab_size: int, + pad_token_id: int = 0, + query_length: int = 32, + dim_kv: Optional[int] = None, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + attn_dropout: float = 0.0, + dropout: float = 0.0, + cross_attention_freq=2, + ): + super().__init__() + self.query_length = query_length + self.embeddings = QformerEmbedding( + embedding_dim=dim_q, + max_position_embeddings=max_position_embeddings, + vocab_size=vocab_size, + pad_token_id=pad_token_id, + layer_norm_eps=layer_norm_eps, + dropout=dropout, + ) + self.encoder = QformerEncoder( + num_hidden_layers=num_hidden_layers, + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + layer_norm_eps=layer_norm_eps, + activation=activation, + cross_attention_freq=cross_attention_freq, + dim_kv=dim_kv, + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + query_embeds: Optional[Tensor] = None, + encoder_hidden_states: Optional[Tensor] = None, + past_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None, + use_cache: bool = False, + use_causal_mask: bool = False, + ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + """ + Inputs: + input_ids (Optional[Tensor]): input token ids for QFormer + attention_mask (Optional[Tensor]): attention mask for QFormer + position_ids (Optional[Tensor]): position ids for QFormer + query_embeds (Optional[Tensor]): query embeddings for QFormer + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + past_key_values: (Optional[List[Tuple[Tensor, Tensor]]]): a list of num_layers elements, + each element is a 2-element tuple for cached key/value. + key/value is tensor with shape of (bsz x source_seq_len x embed_dim). + use_cache (bool): whether to use cache for key and value tensors + use_causal_mask (bool): apply causal mask if true, default to False + + Returns: + Qformer encoder output with a tuple of last hidden states and past_key_values if use_cache. + """ + past_seq_length = ( + # overall_seq_length - query_length + past_key_values[0][0].shape[2] - self.query_length + if past_key_values is not None + else 0 + ) + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeddings=query_embeds, + past_seq_length=past_seq_length, + ) + bsz, seq_len = embedding_output.size()[:-1] + + if attention_mask is not None: + if use_causal_mask: + # Apply a causal mask in addition to the padding mask and make attention mask broadcastable. + causal_mask = get_causal_mask( + attention_mask, + (bsz, seq_len), + has_query=(query_embeds is not None), + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype) + else: + attention_mask = attention_mask[:, None, None, :] + # create a tensor which is 0.0 for positions to attend and -10000.0 for masked position. + # use float mask to ensure mask values will be added to the attention weight + attention_mask = (1.0 - attention_mask) * -10000.0 + + return self.encoder( + hidden_states=embedding_output, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + query_length=query_length, + ) + + +class QformerPredictionHead(nn.Module): + """ + MLP head for computinng prediction score from QformerModel output + + Args: + dim_q (int): dimensionality of the query tensor + vocab_size (int): the size of vocabulary used by QFormer + layer_norm_eps (float): the epsilon used by the layer normalization layers, default is 1e-12 + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + """ + + def __init__( + self, + dim_q: int, + vocab_size: int, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.linear_1 = nn.Linear(dim_q, dim_q) + self.activation = activation() + self.layernorm = nn.LayerNorm(dim_q, eps=layer_norm_eps) + self.linear_2 = nn.Linear(dim_q, vocab_size) + + def forward(self, sequence_output: Tensor) -> Tensor: + """ + Inputs (Tensor): + sequence_output of shape bsz x seq_len x embed_dim + Returns: + prediction scores (Tensor) of shape: bsz x seq_len x vocab_size + """ + hidden_states = self.linear_1(sequence_output) + hidden_states = self.activation(hidden_states) + hidden_states = self.layernorm(hidden_states) + predictions = self.linear_2(hidden_states) + return predictions + + +class QformerForCLM(nn.Module): + """ + A QformerModel wrapper class for causal language modeling(clm). + + Args: + num_hidden_layers (int): number of Qformer layers inside encoder + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + max_position_embeddings (int): max sequence length allowed for positional embeddings + vocab_size (int): size of vocabulary + pad_token_id (int): id used for padding token, default is 0. + query_length(int): query length in Qformer, details see QformerModel class. + dim_kv (Optional[int]): dim_kv (Optional[int]): dimensions of the key and value tensors, this value is only used in CA. + Default is None. + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer in each Qformer layer + cross_attention_freq (int): frequency of adding cross attention in QFormer layers, default to 2 + """ + + def __init__( + self, + num_hidden_layers: int, + dim_q: int, + dim_feedforward: int, + num_heads: int, + max_position_embeddings: int, + vocab_size: int, + pad_token_id: int = 0, + query_length: int = 32, + dim_kv: Optional[int] = None, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.GELU, + attn_dropout: float = 0.0, + dropout: float = 0.0, + cross_attention_freq=2, + ) -> None: + super().__init__() + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + self.head = QformerPredictionHead( + dim_q=dim_q, + activation=activation, + layer_norm_eps=layer_norm_eps, + vocab_size=vocab_size, + ) + self.model = QformerModel( + num_hidden_layers=num_hidden_layers, + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + max_position_embeddings=max_position_embeddings, + vocab_size=vocab_size, + pad_token_id=pad_token_id, + query_length=query_length, + dim_kv=dim_kv, + layer_norm_eps=layer_norm_eps, + activation=activation, + attn_dropout=attn_dropout, + dropout=dropout, + cross_attention_freq=cross_attention_freq, + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + query_embeds: Optional[Tensor] = None, + encoder_hidden_states: Optional[Tensor] = None, + past_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None, + use_cache: bool = False, + ) -> Tensor: + """ + Inputs: + input_ids (Optional[Tensor]): input token ids for QFormer + attention_mask (Optional[Tensor]): attention mask for QFormer + position_ids (Optional[Tensor]): position ids for QFormer + query_embeds (Optional[Tensor]): query embeddings for QFormer + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + past_key_values: (Optional[List[Tuple[Tensor, Tensor]]]): cached key/value tuple for self-attention + use_cache (bool): whether to use cache for key and value tensors, + default to False for generation as cached values should be computed in previous training tasks. + + Returns: + prediction score (Tensor) computed for next word prediction of shape + bsz x seq_len x vocab_size + """ + # TODO: revisit if it's required for edge cases after BLIP-2 impl. + if past_key_values is not None: + assert query_embeds is None + + sequence_output, _ = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + use_causal_mask=True, # set causal mask for clm + ) + if query_embeds is not None: + sequence_output = sequence_output[:, query_embeds.shape[1] :, :] + + prediction_scores = self.head(sequence_output) + return prediction_scores