Skip to content

Commit

Permalink
Support returning hidden state from CLIPTextEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
22quinn committed Aug 4, 2023
1 parent 81e281c commit b70dc30
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
47 changes: 46 additions & 1 deletion tests/models/clip/test_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestCLIPTextEncoder:
def start(self):
set_rng_seed(1234)
context_length = 77
batch_size, embedding_dim, heads = 2, 4, 2
batch_size, embedding_dim, heads, width = 2, 4, 2, 512

def build_text(text_length):
return torch.randint(1, 10, (batch_size, text_length), dtype=torch.long)
Expand All @@ -27,12 +27,16 @@ def build_encoder(
use_clip_init=True,
context_length=context_length,
heads=heads,
width=width,
return_hidden_state=False,
):
return CLIPTextEncoder(
embedding_dim=embedding_dim,
use_clip_init=use_clip_init,
context_length=context_length,
width=width,
heads=heads,
return_hidden_state=return_hidden_state,
)

return build_encoder, build_text
Expand Down Expand Up @@ -117,6 +121,47 @@ def test_forward(self, start):
actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4
)

def test_forward_return_hidden_state(self, start):
    """Check the encoder output when ``return_hidden_state=True``.

    Builds a small encoder (``context_length=3``, ``width=4``) through the
    ``start`` fixture, runs it on a length-3 text batch, and compares both
    returned tensors against pinned reference values.
    """
    build_encoder, build_text = start
    text = build_text(text_length=3)

    text_encoder = build_encoder(
        context_length=3, width=4, return_hidden_state=True
    )
    assert isinstance(text_encoder, torch.nn.Module)

    # With return_hidden_state=True the forward result unpacks into the
    # projected embedding and the last hidden state.
    actual_clip_init, actual_hidden_state = text_encoder(text)

    # Reference projected embedding: one row of 4 values per batch item.
    expected_clip_init = torch.Tensor(
        [
            [-0.366838, -1.596611, -0.330413, -0.593790],
            [-0.790419, 0.876780, -0.970667, -0.727134],
        ]
    )
    # Reference hidden state: for each batch item, 3 rows of 4 values.
    expected_hidden_state = torch.Tensor(
        [
            [
                [6.348165e-01, -4.137459e-02, -1.604239e00, 1.010798e00],
                [6.204837e-01, -3.028658e-02, -1.606570e00, 1.016373e00],
                [5.915626e-01, -1.666874e-03, -1.613292e00, 1.023396e00],
            ],
            [
                [5.910631e-01, -1.515219e-02, -1.607913e00, 1.032002e00],
                [1.467783e-01, -1.675803e00, 7.402021e-01, 7.888227e-01],
                [6.721084e-01, -2.896671e-01, -1.493379e00, 1.110938e00],
            ],
        ]
    )
    assert_expected(
        actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4
    )
    assert_expected(
        actual=actual_hidden_state,
        expected=expected_hidden_state,
        rtol=0,
        atol=1e-4,
    )

def test_forward_over_context_length(self, start):
build_encoder, build_text = start

Expand Down
15 changes: 12 additions & 3 deletions torchmultimodal/models/clip/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.


from typing import Tuple, Union

import torch
from torch import nn, Tensor

Expand All @@ -28,6 +30,8 @@ class CLIPTextEncoder(nn.Module):
heads (int): Number of heads in Transformer encoder.
layers (int): Number of layers in Transformer encoder.
use_clip_init (bool): Whether to use CLIP-specific initialization.
return_hidden_state (bool): Whether to return the last hidden state.
If True, forward returns a tuple of final embedding and last hidden state.
Inputs: text (Tensor): Tensor containing text features.
"""
Expand All @@ -45,6 +49,7 @@ def __init__(
heads: int = 8,
layers: int = 12,
use_clip_init: bool = True,
return_hidden_state: bool = False,
):
super().__init__()
torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
Expand Down Expand Up @@ -77,6 +82,8 @@ def __init__(
if use_clip_init:
self.initialize_parameters()

self.return_hidden_state = return_hidden_state

def initialize_parameters(self) -> None:
# Initialize token and positional embeddings
nn.init.normal_(self.token_embedding.weight, std=self.TOKEN_EMBEDDING_INIT_STD)
Expand Down Expand Up @@ -108,7 +115,7 @@ def build_attention_mask(self) -> Tensor:
).triu(1)
return mask

def forward(self, text: Tensor) -> Tensor:
def forward(self, text: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if text.size(1) != self.context_length:
raise ValueError(
f"length of input should be {self.context_length} but found {text.size(1)}"
Expand All @@ -120,11 +127,13 @@ def forward(self, text: Tensor) -> Tensor:

# [n_ctx, bs, transformer.width] -> [bs, n_ctx, transformer.width]
embeddings = torch.permute(embeddings, (1, 0, 2))
embeddings = self.ln_final(embeddings)
hidden_state = self.ln_final(embeddings)
# take features from the eot embedding (the highest number in each sequence)
embeddings = self.projection(
embeddings[torch.arange(embeddings.shape[0]), text.argmax(dim=-1)]
hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)]
)
# embeddings now has size [bs, embedding_dim]

if self.return_hidden_state:
return embeddings, hidden_state
return embeddings

0 comments on commit b70dc30

Please sign in to comment.