Commit

port blip2 layer to tmm/models
Differential Revision: D50145708

fbshipit-source-id: ea27aa1c71b5485663a2391a3258f79215eb139a
pengchen18 authored and facebook-github-bot committed Oct 11, 2023
1 parent 923e3c9 commit f1c891c
Showing 2 changed files with 294 additions and 0 deletions.
137 changes: 137 additions & 0 deletions tests/models/blip2/test_blip2.py
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch
import torch.nn as nn
from tests.test_utils import assert_expected, init_weights_with_constant
from torchmultimodal.models.blip2.blip2 import BLIP2
from torchmultimodal.models.blip2.qformer_model import QformerForCLM
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoder


@pytest.fixture
def dim_q():
    return 4


@pytest.fixture
def dim_kv():
    return 2


@pytest.fixture
def dim_feedforward():
    return 6


@pytest.fixture
def num_hidden_layers():
    return 2


@pytest.fixture
def num_heads():
    return 2


@pytest.fixture
def vocab_size():
    return 20


@pytest.fixture
def qformer_model_for_clm(
    dim_q,
    dim_kv,
    dim_feedforward,
    num_hidden_layers,
    num_heads,
    vocab_size,
):
    qformer_for_clm = QformerForCLM(
        dim_q=dim_q,
        dim_kv=dim_kv,
        dim_feedforward=dim_feedforward,
        num_heads=num_heads,
        attn_dropout=0.0,
        dropout=0.0,
        num_hidden_layers=num_hidden_layers,
        max_position_embeddings=512,
        vocab_size=vocab_size,
    )
    return qformer_for_clm


@pytest.fixture
def vit():
    embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2)
    encoder = TransformerEncoder(
        n_layer=1,
        d_model=2,
        n_head=1,
        dim_feedforward=1,
        activation=nn.GELU,
        norm_first=True,
        final_layer_norm_eps=1e-5,
    )
    image_encoder = VisionTransformer(
        embeddings=embedding,
        encoder=encoder,
    )
    init_weights_with_constant(image_encoder)
    image_encoder.eval()
    return image_encoder


@pytest.fixture
def blip2(dim_q, dim_kv, qformer_model_for_clm, vit):
    blip2 = BLIP2(
        dim_q=dim_q,
        image_encoder_embedding_dim=dim_kv,
        qformer=qformer_model_for_clm,
        vision_encoder=vit,
        embedding_dim=4,
        decoder_bos_token_id=19,
    )
    init_weights_with_constant(blip2)
    blip2.eval()
    return blip2


@pytest.fixture
def attn_mask():
    return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]])


class TestBLIP2:
    def test_blip2(self, blip2, attn_mask):
        image = torch.ones(2, 3, 2, 2)
        input_ids = torch.ones(2, 4).long()
        output = blip2(image, input_ids, attn_mask)
        assert_expected(
            output.image_features, torch.ones([2, 32, 4]) * 0.5, rtol=0, atol=1e-4
        )
        assert_expected(
            output.text_features, torch.ones([2, 4]) * 0.5, rtol=0, atol=1e-4
        )
        assert_expected(
            output.image_embeddings, torch.ones([2, 5, 2]), rtol=0, atol=1e-4
        )
        assert_expected(
            output.prediction_scores, torch.ones([2, 4, 20]) * 5, rtol=0, atol=1e-4
        )

    def test_blip2_scripting(self, blip2, attn_mask):
        image = torch.ones(2, 3, 2, 2)
        input_ids = torch.ones(2, 4).long()
        scripted_model = torch.jit.script(blip2)
        actual = scripted_model(image, input_ids, attn_mask)
        expected = blip2(image, input_ids, attn_mask)
        assert_expected(actual, expected)
157 changes: 157 additions & 0 deletions torchmultimodal/models/blip2/blip2.py
@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import NamedTuple, Optional

import torch

from torch import nn, Tensor
from torch.nn import functional as F
from torchmultimodal.modules.layers.transformer import TransformerOutput


class Blip2Output(NamedTuple):
    """
    BLIP2 model output for loss computation.

    image_embeddings(Tensor): normalized image embeddings returned by the vision encoder
        with shape [bsz, seq_len, embed_dim].
    image_features(Tensor): image features after qformer and projection (for stage 1 training)
        with shape [bsz, num_query_tokens, embed_dim].
    image_qformer_output(Tensor): last hidden state of the qformer output for the given image input.
    text_features(Optional[Tensor]): text features after qformer and projection, if text input is provided,
        with shape [bsz, embed_dim].
    prediction_scores(Optional[Tensor]): scores computed for next-word prediction
        with shape [bsz, seq_len, vocab_size].
    """

    image_embeddings: Tensor
    image_features: Tensor
    image_qformer_output: Tensor
    text_features: Optional[Tensor] = None
    prediction_scores: Optional[Tensor] = None


class BLIP2(nn.Module):
    """
    BLIP2 (https://arxiv.org/pdf/2301.12597.pdf) provides a pre-training strategy to bootstrap vision-language
    pre-training from frozen image encoders and frozen large language models (LLMs). BLIP-2 bridges the modality gap
    and facilitates cross-modal alignment via the Querying Transformer (Q-Former), a lightweight transformer
    with a set of learnable query vectors used to extract visual features from the frozen image encoder.

    Args:
        qformer(nn.Module): Querying Transformer (Q-Former)
        vision_encoder(nn.Module): frozen image encoder
        dim_q(int): dimension of the query tensor; this value should be the same as dim_q in the qformer.
        image_encoder_embedding_dim(int): embedding dimension of the image encoder;
            this value should be the same as dim_kv in the qformer.
        freeze_vision_encoder(bool): whether to freeze the vision encoder, defaults to True
        cross_attention_freq(int): frequency of adding cross-attention blocks in the qformer, defaults to 2
        embedding_dim(int): embedding dimension
        num_query_token(int): number of query tokens in the qformer, defaults to 32
        init_query_tokens(bool): whether to initialize the query token parameters, defaults to True
        decoder_bos_token_id(Optional[int]): bos_token_id used in the decoder, defaults to None
    """

    def __init__(
        self,
        qformer: nn.Module,
        vision_encoder: nn.Module,
        dim_q: int,
        image_encoder_embedding_dim: int,
        freeze_vision_encoder: bool = True,
        cross_attention_freq: int = 2,
        embedding_dim: int = 256,
        num_query_token: int = 32,
        init_query_tokens: bool = True,
        decoder_bos_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        if freeze_vision_encoder:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
            self.vision_encoder = self.vision_encoder.eval()

        self.qformer = qformer
        self.decoder_bos_token_id = decoder_bos_token_id
        self.dim_q = dim_q
        self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, self.dim_q))
        if init_query_tokens:
            self.query_tokens.data.normal_(mean=0.0, std=0.02)

        self.vision_proj = nn.Linear(self.dim_q, embedding_dim)
        self.text_proj = nn.Linear(self.dim_q, embedding_dim)
        self.ln_vision = nn.LayerNorm(image_encoder_embedding_dim)

    def forward(
        self,
        image: Tensor,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Blip2Output:
        """
        Args:
            image(Tensor): image input tensor with shape [B, C, H, W]
            input_ids(Optional[Tensor]): text input tensor with shape [bsz, seq_len]
            attention_mask(Optional[Tensor]): attention mask tensor with shape [bsz, seq_len]

        Returns:
            BLIP2 model output (Blip2Output).
        """
        vision_encoder_output = self.vision_encoder(image)
        if isinstance(vision_encoder_output, TransformerOutput):
            vision_encoder_output = vision_encoder_output.last_hidden_state
        assert vision_encoder_output is not None
        image_embeds = self.ln_vision(vision_encoder_output)
        # query tokens: [batch_size, num_query_token, encoder_hidden_size]
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.qformer.model(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            use_cache=True,
        )

        # image_feats: [batch_size, num_query_token, embedding_dim]
        image_feats = F.normalize(self.vision_proj(query_output[0]), dim=-1)

        text_feats: Optional[Tensor] = None
        prediction_scores: Optional[Tensor] = None
        if input_ids is not None:
            text_output = self.qformer.model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
            text_feats = F.normalize(self.text_proj(text_output[0][:, 0, :]), dim=-1)

            decoder_input_ids = input_ids.clone()
            if self.decoder_bos_token_id is not None:
                # pyre-ignore
                decoder_input_ids[:, 0] = self.decoder_bos_token_id

            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                input_ids.device
            )
            if attention_mask is not None:
                attention_mask = torch.cat([query_atts, attention_mask], dim=1)

            # set use_cache=False since past_key_values were already cached by the query pass above.
            prediction_scores = self.qformer(
                input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                past_key_values=query_output[1],
                use_cache=False,
            )

        return Blip2Output(
            image_embeddings=image_embeds,
            image_features=image_feats,
            image_qformer_output=query_output[0],
            text_features=text_feats,
            prediction_scores=prediction_scores,
        )
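
For orientation, here is a minimal end-to-end sketch of how the ported module is wired together, assembled from the unit-test fixtures above. The tiny dimensions are illustrative only, and unlike the test the weights here are left randomly initialized rather than constant-initialized.

import torch
import torch.nn as nn
from torchmultimodal.models.blip2.blip2 import BLIP2
from torchmultimodal.models.blip2.qformer_model import QformerForCLM
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoder

# Tiny vision encoder (frozen inside BLIP2 by default).
vit = VisionTransformer(
    embeddings=PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2),
    encoder=TransformerEncoder(
        n_layer=1,
        d_model=2,
        n_head=1,
        dim_feedforward=1,
        activation=nn.GELU,
        norm_first=True,
        final_layer_norm_eps=1e-5,
    ),
)

# Q-Former configured for causal language modeling.
qformer = QformerForCLM(
    dim_q=4,
    dim_kv=2,
    dim_feedforward=6,
    num_heads=2,
    attn_dropout=0.0,
    dropout=0.0,
    num_hidden_layers=2,
    max_position_embeddings=512,
    vocab_size=20,
)

model = BLIP2(
    dim_q=4,
    image_encoder_embedding_dim=2,
    qformer=qformer,
    vision_encoder=vit,
    embedding_dim=4,
    decoder_bos_token_id=19,
)
model.eval()

image = torch.ones(2, 3, 2, 2)       # [bsz, channels, height, width]
input_ids = torch.ones(2, 4).long()  # [bsz, seq_len]
attention_mask = torch.ones(2, 4)    # [bsz, seq_len]
with torch.no_grad():
    out = model(image, input_ids, attention_mask)
# out is a Blip2Output carrying image_embeddings, image_features,
# image_qformer_output, text_features, and prediction_scores.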
