Support returning hidden state in CLIPTextEncoder #442

Closed · wants to merge 3 commits · showing changes from 1 commit
tests/models/clip/test_text_encoder.py (46 additions, 1 deletion)
@@ -17,7 +17,7 @@ class TestCLIPTextEncoder:
     def start(self):
         set_rng_seed(1234)
         context_length = 77
-        batch_size, embedding_dim, heads = 2, 4, 2
+        batch_size, embedding_dim, heads, width = 2, 4, 2, 512

         def build_text(text_length):
             return torch.randint(1, 10, (batch_size, text_length), dtype=torch.long)
@@ -27,12 +27,16 @@ def build_encoder(
             use_clip_init=True,
             context_length=context_length,
             heads=heads,
+            width=width,
+            return_hidden_state=False,
         ):
             return CLIPTextEncoder(
                 embedding_dim=embedding_dim,
                 use_clip_init=use_clip_init,
                 context_length=context_length,
+                width=width,
                 heads=heads,
+                return_hidden_state=return_hidden_state,
             )

         return build_encoder, build_text
@@ -117,6 +121,47 @@ def test_forward(self, start):
             actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4
         )

+    def test_forward_return_hidden_state(self, start):
+        build_encoder, build_text = start
+        text = build_text(text_length=3)
+
+        text_encoder = build_encoder(
+            context_length=3, width=4, return_hidden_state=True
+        )
+        assert isinstance(text_encoder, torch.nn.Module)
+
+        actual_clip_init, actual_hidden_state = text_encoder(text)
+        print(actual_hidden_state)

Review comment (Contributor): remove

+        expected_clip_init = torch.Tensor(
+            [
+                [-0.366838, -1.596611, -0.330413, -0.593790],
+                [-0.790419, 0.876780, -0.970667, -0.727134],
+            ]
+        )
+        expected_hidden_state = torch.Tensor(
+            [
+                [
+                    [6.348165e-01, -4.137459e-02, -1.604239e00, 1.010798e00],

Review comment (Contributor): nit: can we limit to 4 decimal places given atol is being set

+                    [6.204837e-01, -3.028658e-02, -1.606570e00, 1.016373e00],
+                    [5.915626e-01, -1.666874e-03, -1.613292e00, 1.023396e00],
+                ],
+                [
+                    [5.910631e-01, -1.515219e-02, -1.607913e00, 1.032002e00],
+                    [1.467783e-01, -1.675803e00, 7.402021e-01, 7.888227e-01],
+                    [6.721084e-01, -2.896671e-01, -1.493379e00, 1.110938e00],
+                ],
+            ]
+        )
+        assert_expected(
+            actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4
+        )
+        assert_expected(
+            actual=actual_hidden_state,
+            expected=expected_hidden_state,
+            rtol=0,
+            atol=1e-4,
+        )
+
     def test_forward_over_context_length(self, start):
         build_encoder, build_text = start

torchmultimodal/models/clip/text_encoder.py (12 additions, 3 deletions)
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.


+from typing import Tuple, Union
+
 import torch
 from torch import nn, Tensor

@@ -28,6 +30,8 @@ class CLIPTextEncoder(nn.Module):
         heads (int): Number of heads in Transformer encoder.
         layers (int): Number of layers in Transformer encoder.
         use_clip_init (bool): Whether to use CLIP-specific initialization.
+        return_hidden_state (bool): Whether to return the last hidden state.
+            If True, forward returns a tuple of final embedding and last hidden state.

     Inputs: text (Tensor): Tensor containing text features.
     """
@@ -45,6 +49,7 @@ def __init__(
         heads: int = 8,
         layers: int = 12,
         use_clip_init: bool = True,
+        return_hidden_state: bool = False,
     ):
         super().__init__()
         torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
@@ -77,6 +82,8 @@ def __init__(
         if use_clip_init:
             self.initialize_parameters()

+        self.return_hidden_state = return_hidden_state
+

Review comment (Contributor): you can make this an arg to fwd

Review comment (abhinavarora, Contributor, Aug 7, 2023): +1 to what Ankita said! We don't need this to be a module-level property.

     def initialize_parameters(self) -> None:
         # Initialize token and positional embeddings
         nn.init.normal_(self.token_embedding.weight, std=self.TOKEN_EMBEDDING_INIT_STD)
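The suggestion above (passing return_hidden_state to forward rather than storing it on the module) could look roughly like the toy sketch below. This is illustrative only and not part of this PR; the module and its dimensions are hypothetical stand-ins for CLIPTextEncoder.

# Sketch of the reviewers' suggestion: take the flag as a forward() argument
# instead of a module attribute. Toy module for illustration only.
from typing import Tuple, Union

from torch import Tensor, nn


class ToyEncoder(nn.Module):
    def __init__(self, width: int = 4, embedding_dim: int = 4):
        super().__init__()
        self.ln_final = nn.LayerNorm(width)
        self.projection = nn.Linear(width, embedding_dim, bias=False)

    def forward(
        self, x: Tensor, return_hidden_state: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        # x: [batch_size, seq_len, width]
        hidden_state = self.ln_final(x)
        # project the last token's features as a stand-in for the EOT token
        embeddings = self.projection(hidden_state[:, -1])
        if return_hidden_state:
            return embeddings, hidden_state
        return embeddings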
@@ -108,7 +115,7 @@ def build_attention_mask(self) -> Tensor:
         ).triu(1)
         return mask

-    def forward(self, text: Tensor) -> Tensor:
+    def forward(self, text: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         if text.size(1) != self.context_length:
             raise ValueError(
                 f"length of input should be {self.context_length} but found {text.size(1)}"
@@ -120,11 +127,13 @@ def forward(self, text: Tensor) -> Tensor:

         # [n_ctx, bs, transformer.width] -> [bs, n_ctx, transformer.width]
         embeddings = torch.permute(embeddings, (1, 0, 2))
-        embeddings = self.ln_final(embeddings)
+        hidden_state = self.ln_final(embeddings)
         # take features from the eot embedding (the highest number in each sequence)
         embeddings = self.projection(

Review comment (Contributor): I would suggest renaming this to something like projected_embeddings to avoid re-using the embeddings name as it was used before.

-            embeddings[torch.arange(embeddings.shape[0]), text.argmax(dim=-1)]
+            hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)]
         )
         # embeddings now has size [bs, embedding_dim]

+        if self.return_hidden_state:
+            return embeddings, hidden_state

Review comment (Contributor): nit: use named tuple

         return embeddings
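On the named-tuple nit above, one possible shape is sketched below. The type and field names are hypothetical and not part of this PR.

# Illustrative only: a named return type for the (embeddings, hidden_state) pair.
from typing import NamedTuple

from torch import Tensor


class CLIPTextEncoderOutput(NamedTuple):
    projected_embeddings: Tensor
    hidden_state: Tensor

# forward() could then end with:
#     if self.return_hidden_state:
#         return CLIPTextEncoderOutput(embeddings, hidden_state)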
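For reference, a usage sketch of the encoder with this patch applied. It assumes the PR branch of torchmultimodal is installed; the tiny dimensions mirror the unit test above rather than a real CLIP configuration.

import torch
from torchmultimodal.models.clip.text_encoder import CLIPTextEncoder

# Tiny dimensions matching the unit test; real CLIP configs are much larger.
encoder = CLIPTextEncoder(
    embedding_dim=4,
    width=4,
    heads=2,
    context_length=3,
    return_hidden_state=True,
)
text = torch.randint(1, 10, (2, 3), dtype=torch.long)  # [batch_size, context_length]

embeddings, hidden_state = encoder(text)
print(embeddings.shape)    # torch.Size([2, 4]):    [batch_size, embedding_dim]
print(hidden_state.shape)  # torch.Size([2, 3, 4]): [batch_size, context_length, width]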