-
Notifications
You must be signed in to change notification settings - Fork 141
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add blip2 layer under torchmm/models
Summary: as title Differential Revision: D50145708 fbshipit-source-id: 1a0cba1551bbd8804dc0fc1cdf74850e8e055679
- Loading branch information
1 parent
83dae46
commit 6dcd8bb
Showing
2 changed files
with
294 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from tests.test_utils import assert_expected, init_weights_with_constant | ||
from torchmultimodal.models.blip2.blip2 import BLIP2 | ||
from torchmultimodal.models.blip2.qformer_model import QformerForCLM | ||
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer | ||
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings | ||
from torchmultimodal.modules.layers.transformer import TransformerEncoder | ||
|
||
|
||
@pytest.fixture | ||
def dim_q(): | ||
return 4 | ||
|
||
|
||
@pytest.fixture | ||
def dim_kv(): | ||
return 2 | ||
|
||
|
||
@pytest.fixture | ||
def dim_feedforward(): | ||
return 6 | ||
|
||
|
||
@pytest.fixture | ||
def num_hidden_layers(): | ||
return 2 | ||
|
||
|
||
@pytest.fixture | ||
def num_heads(): | ||
return 2 | ||
|
||
|
||
@pytest.fixture | ||
def vocab_size(): | ||
return 20 | ||
|
||
|
||
@pytest.fixture | ||
def qformer_model_for_clm( | ||
dim_q, | ||
dim_kv, | ||
dim_feedforward, | ||
num_hidden_layers, | ||
num_heads, | ||
vocab_size, | ||
): | ||
qformer_for_clm = QformerForCLM( | ||
dim_q=dim_q, | ||
dim_kv=dim_kv, | ||
dim_feedforward=dim_feedforward, | ||
num_heads=num_heads, | ||
attn_dropout=0.0, | ||
dropout=0.0, | ||
num_hidden_layers=num_hidden_layers, | ||
max_position_embeddings=512, | ||
vocab_size=vocab_size, | ||
) | ||
return qformer_for_clm | ||
|
||
|
||
@pytest.fixture | ||
def vit(): | ||
embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2) | ||
encoder = TransformerEncoder( | ||
n_layer=1, | ||
d_model=2, | ||
n_head=1, | ||
dim_feedforward=1, | ||
activation=nn.GELU, | ||
norm_first=True, | ||
final_layer_norm_eps=1e-5, | ||
) | ||
image_encoder = VisionTransformer( | ||
embeddings=embedding, | ||
encoder=encoder, | ||
) | ||
init_weights_with_constant(image_encoder) | ||
image_encoder.eval() | ||
return image_encoder | ||
|
||
|
||
@pytest.fixture | ||
def blip2(dim_q, dim_kv, qformer_model_for_clm, vit): | ||
blip2 = BLIP2( | ||
dim_q=dim_q, | ||
image_encoder_embedding_dim=dim_kv, | ||
qformer=qformer_model_for_clm, | ||
vision_encoder=vit, | ||
embedding_dim=4, | ||
decoder_bos_token_id=19, | ||
) | ||
init_weights_with_constant(blip2) | ||
blip2.eval() | ||
return blip2 | ||
|
||
|
||
@pytest.fixture | ||
def attn_mask(): | ||
return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]]) | ||
|
||
|
||
class TestBLIP2: | ||
def test_blip2(self, blip2, attn_mask): | ||
image = torch.ones(2, 3, 2, 2) | ||
input_ids = torch.ones(2, 4).long() | ||
output = blip2(image, input_ids, attn_mask) | ||
assert_expected( | ||
output.image_features, torch.ones([2, 32, 4]) * 0.5, rtol=0, atol=1e-4 | ||
) | ||
assert_expected( | ||
output.text_features, torch.ones([2, 4]) * 0.5, rtol=0, atol=1e-4 | ||
) | ||
assert_expected( | ||
output.image_embeddings, torch.ones([2, 5, 2]), rtol=0, atol=1e-4 | ||
) | ||
assert_expected( | ||
output.prediction_scores, torch.ones([2, 4, 20]) * 5, rtol=0, atol=1e-4 | ||
) | ||
|
||
def test_blip2_scripting(self, blip2, attn_mask): | ||
image = torch.ones(2, 3, 2, 2) | ||
input_ids = torch.ones(2, 4).long() | ||
scripted_model = torch.jit.script(blip2) | ||
actual = scripted_model(image, input_ids, attn_mask) | ||
expected = blip2(image, input_ids, attn_mask) | ||
assert_expected(actual, expected) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from typing import NamedTuple, Optional | ||
|
||
import torch | ||
|
||
from torch import nn, Tensor | ||
from torch.nn import functional as F | ||
from torchmultimodal.modules.layers.transformer import TransformerOutput | ||
|
||
|
||
class Blip2Output(NamedTuple): | ||
""" | ||
BLIP2 model output for loss computation. | ||
image_embeddings(Tensor): normalized image embeddings returned by the visual encoder | ||
with shape [bsz x seq_len x embed_dim]. | ||
image_features(Tensor): Image features after qformer and projection (for stage 1 training) | ||
with shape [bsz, num_query_tokens, embed_dim] | ||
image_qformer_output(Tensor) : last hidden state for qformer output by given image input | ||
text_features(Optional[Tensor]): Text features after qformer and projection if text input is provided | ||
with shape [bsz, embed_dim] | ||
prediction_scores (Optional[Tensor]): computed for next word prediction | ||
with shape of [bsz, seq_len, vocab_size] | ||
""" | ||
|
||
image_embeddings: Tensor | ||
image_features: Tensor | ||
image_qformer_output: Tensor | ||
text_features: Optional[Tensor] = None | ||
prediction_scores: Optional[Tensor] = None | ||
|
||
|
||
class BLIP2(nn.Module): | ||
""" | ||
BLIP2(https://arxiv.org/pdf/2301.12597.pdf) provides a pre-training strategy to bootstrap vision-language | ||
pre-training from frozen image encoders and frozen large language models(LLM). BLIP-2 bridges the modality gap | ||
and facilitates cross-modal alignment via Querying Transformer (Q-former). Q-former is a lightweight transformer | ||
which has a set of learnable query vectors to extract visual features from the frozen image encoder. | ||
Args: | ||
qformer(nn.Module): Querying Transformer (Q-former) | ||
visual_encoder(nn.Module): Frozen image encoder | ||
dim_q(int) : Dimension of query tensor, this value should be the same as dim_q in qformer. | ||
image_encoder_embedding_dim(int): Embedding dimension for image encoder, | ||
this value should be the same as dim_kv in qformer. | ||
freeze_visual_encoder(bool): Whether to freeze the visual encoder, default to True | ||
cross_attention_freq(int): Frequency of adding cross-attention block in Qformer, default to 2 | ||
embedding_dim(int): Embedding dimension | ||
num_query_token(int): Number of query tokens in Qformer, default to 32 | ||
init_query_tokens(bool): whether init query token params, default to True | ||
decoder_bos_token_id(Optional[int]): bos_token_id used in decoder, default to None | ||
""" | ||
|
||
def __init__( | ||
self, | ||
qformer: nn.Module, | ||
vision_encoder: nn.Module, | ||
dim_q: int, | ||
image_encoder_embedding_dim: int, | ||
freeze_vision_encoder: bool = True, | ||
cross_attention_freq: int = 2, | ||
embedding_dim: int = 256, | ||
num_query_token: int = 32, | ||
init_query_tokens: bool = True, | ||
decoder_bos_token_id: Optional[int] = None, | ||
): | ||
super().__init__() | ||
self.vision_encoder = vision_encoder | ||
if freeze_vision_encoder: | ||
for param in self.vision_encoder.parameters(): | ||
param.requires_grad = False | ||
self.vision_encoder = self.vision_encoder.eval() | ||
|
||
self.qformer = qformer | ||
self.decoder_bos_token_id = decoder_bos_token_id | ||
self.dim_q = dim_q | ||
self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, self.dim_q)) | ||
if init_query_tokens: | ||
self.query_tokens.data.normal_(mean=0.0, std=0.02) | ||
|
||
self.vision_proj = nn.Linear(self.dim_q, embedding_dim) | ||
self.text_proj = nn.Linear(self.dim_q, embedding_dim) | ||
self.ln_vision = nn.LayerNorm(image_encoder_embedding_dim) | ||
|
||
def forward( | ||
self, | ||
image: Tensor, | ||
input_ids: Optional[Tensor] = None, | ||
attention_mask: Optional[Tensor] = None, | ||
) -> Blip2Output: | ||
""" | ||
Args: | ||
image(Tensor): Image input tensor with shape [B, C, H, W] | ||
input_ids(Optional[Tensor]): Text input tensor with shape [bsz, seq_len] | ||
attention_mask(Optional[Tensor]): Attention mask tensor with shape [bsz, seq_len] | ||
Returns: | ||
return BLIP2 model output(Blip2Output). | ||
""" | ||
vision_encoder_output = self.vision_encoder(image) | ||
if isinstance(vision_encoder_output, TransformerOutput): | ||
vision_encoder_output = vision_encoder_output.last_hidden_state | ||
assert vision_encoder_output is not None | ||
image_embeds = self.ln_vision(vision_encoder_output) | ||
# query tokens: [batch_size, num_query_token, encoder_hidden_size] | ||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) | ||
query_output = self.qformer.model( | ||
query_embeds=query_tokens, | ||
encoder_hidden_states=image_embeds, | ||
use_cache=True, | ||
) | ||
|
||
# image_feats: [batch_size, num_query_token, embedding_dim] | ||
image_feats = F.normalize(self.vision_proj(query_output[0]), dim=-1) | ||
|
||
text_feats: Optional[Tensor] = None | ||
prediction_scores: Optional[Tensor] = None | ||
if input_ids is not None: | ||
text_output = self.qformer.model( | ||
input_ids, | ||
attention_mask=attention_mask, | ||
use_cache=False, | ||
) | ||
text_feats = F.normalize(self.text_proj(text_output[0][:, 0, :]), dim=-1) | ||
|
||
decoder_input_ids = input_ids.clone() | ||
if self.decoder_bos_token_id is not None: | ||
# pyre-ignore | ||
decoder_input_ids[:, 0] = self.decoder_bos_token_id | ||
|
||
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( | ||
input_ids.device | ||
) | ||
if attention_mask is not None: | ||
attention_mask = torch.cat([query_atts, attention_mask], dim=1) | ||
|
||
# set use_cache = False since past_key_values should be cached in previous steps. | ||
prediction_scores = self.qformer( | ||
input_ids=decoder_input_ids, | ||
attention_mask=attention_mask, | ||
past_key_values=query_output[1], | ||
use_cache=False, | ||
) | ||
|
||
return Blip2Output( | ||
image_embeddings=image_embeds, | ||
image_features=image_feats, | ||
image_qformer_output=query_output[0], | ||
text_features=text_feats, | ||
prediction_scores=prediction_scores, | ||
) |