Support returning hidden state in CLIPTextEncoder (#442)
Summary:
Added a `return_hidden_state` argument to `CLIPTextEncoder`. If set to True, `forward` returns a named tuple (`CLIPTextEncoderOutput`) of the final projected embedding and the last hidden state.
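
For illustration, a minimal usage sketch of the new flag (not part of this commit; the constructor values are illustrative and simply mirror the unit test below):

```python
import torch

from torchmultimodal.models.clip.text_encoder import CLIPTextEncoder

# Test-sized configuration, borrowed from the unit test; not production settings.
encoder = CLIPTextEncoder(embedding_dim=4, context_length=3, width=4, heads=2)
tokens = torch.randint(1, 10, (2, 3), dtype=torch.long)  # [batch_size, context_length]

# Default behavior is unchanged: a single Tensor of projected embeddings.
projected = encoder(tokens)  # [batch_size, embedding_dim]

# With the new flag, forward returns both the projection and the last hidden state.
out = encoder(tokens, return_hidden_state=True)
projected, hidden = out.projected_embeddings, out.hidden_state
# hidden has shape [batch_size, context_length, width]
```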

Pull Request resolved: #442

Test Plan:
pytest tests/models/clip/test_text_encoder.py
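The new case can also be run on its own via pytest's keyword filter, e.g. `pytest tests/models/clip/test_text_encoder.py -k test_forward_return_hidden_state`.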

cc abhinavarora. Moved the code here for easier review and testing.

Reviewed By: ankitade

Differential Revision: D48196191

Pulled By: abhinavarora

fbshipit-source-id: 856819378b607f67aab571b7e787206b6db7b7cc
22quinn authored and facebook-github-bot committed Aug 10, 2023
1 parent 81e281c commit a1cc8f3
Showing 2 changed files with 65 additions and 6 deletions.
48 changes: 47 additions & 1 deletion tests/models/clip/test_text_encoder.py
@@ -17,7 +17,7 @@ class TestCLIPTextEncoder:
def start(self):
set_rng_seed(1234)
context_length = 77
- batch_size, embedding_dim, heads = 2, 4, 2
+ batch_size, embedding_dim, heads, width = 2, 4, 2, 512

def build_text(text_length):
return torch.randint(1, 10, (batch_size, text_length), dtype=torch.long)
@@ -27,11 +27,13 @@ def build_encoder(
use_clip_init=True,
context_length=context_length,
heads=heads,
width=width,
):
return CLIPTextEncoder(
embedding_dim=embedding_dim,
use_clip_init=use_clip_init,
context_length=context_length,
width=width,
heads=heads,
)

@@ -117,6 +119,50 @@ def test_forward(self, start):
actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4
)

def test_forward_return_hidden_state(self, start):
build_encoder, build_text = start
text = build_text(text_length=3)

text_encoder = build_encoder(context_length=3, width=4)
assert isinstance(text_encoder, torch.nn.Module)

out = text_encoder(text, return_hidden_state=True)

actual_projected_embeddings = out.projected_embeddings
actual_hidden_state = out.hidden_state
expected_projected_embeddings = torch.Tensor(
[
[-0.3668, -1.5966, -0.3304, -0.5938],
[-0.7904, 0.8768, -0.9707, -0.7271],
]
)
expected_hidden_state = torch.Tensor(
[
[
[0.6348, -0.0414, -1.6042, 1.0108],
[0.6205, -0.0303, -1.6066, 1.0164],
[0.5916, -0.0017, -1.6133, 1.0234],
],
[
[0.5911, -0.0152, -1.6079, 1.0320],
[0.1468, -1.6758, 0.7402, 0.7888],
[0.6721, -0.2897, -1.4934, 1.1109],
],
]
)
assert_expected(
actual=actual_projected_embeddings,
expected=expected_projected_embeddings,
rtol=0,
atol=1e-4,
)
assert_expected(
actual=actual_hidden_state,
expected=expected_hidden_state,
rtol=0,
atol=1e-4,
)

def test_forward_over_context_length(self, start):
build_encoder, build_text = start

23 changes: 18 additions & 5 deletions torchmultimodal/models/clip/text_encoder.py
@@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.


from typing import NamedTuple, Union

import torch
from torch import nn, Tensor

@@ -13,6 +15,11 @@
from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm


class CLIPTextEncoderOutput(NamedTuple):
projected_embeddings: Tensor
hidden_state: Tensor


class CLIPTextEncoder(nn.Module):
"""CLIP text encoder class. Should be instantiated and passed to
CLIP (models/clip.py)
@@ -108,7 +115,9 @@ def build_attention_mask(self) -> Tensor:
).triu(1)
return mask

- def forward(self, text: Tensor) -> Tensor:
+ def forward(
+     self, text: Tensor, return_hidden_state: bool = False
+ ) -> Union[Tensor, CLIPTextEncoderOutput]:
if text.size(1) != self.context_length:
raise ValueError(
f"length of input should be {self.context_length} but found {text.size(1)}"
@@ -120,11 +129,15 @@ def forward(self, text: Tensor) -> Tensor:

# [n_ctx, bs, transformer.width] -> [bs, n_ctx, transformer.width]
embeddings = torch.permute(embeddings, (1, 0, 2))
- embeddings = self.ln_final(embeddings)
+ hidden_state = self.ln_final(embeddings)
# take features from the eot embedding (the highest number in each sequence)
- embeddings = self.projection(
-     embeddings[torch.arange(embeddings.shape[0]), text.argmax(dim=-1)]
+ projected_embeddings = self.projection(
+     hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)]
)
# embeddings now has size [bs, embedding_dim]

- return embeddings
+ if return_hidden_state:
+     return CLIPTextEncoderOutput(
+         projected_embeddings=projected_embeddings, hidden_state=hidden_state
+     )
+ return projected_embeddings
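
A note on the output type: because `CLIPTextEncoderOutput` is a `NamedTuple`, callers can read the two fields by name or unpack them positionally, while the default call path still returns a bare `Tensor`, so existing callers are unaffected. A minimal sketch, reusing the hypothetical `encoder` and `tokens` from the example above:

```python
out = encoder(tokens, return_hidden_state=True)

# Field access by name
pooled = out.projected_embeddings  # [batch_size, embedding_dim], projected EOT features
hidden = out.hidden_state          # [batch_size, context_length, width], output of ln_final

# Positional unpacking also works, as with any NamedTuple
pooled, hidden = out
```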
