diff --git a/tests/models/clip/test_text_encoder.py b/tests/models/clip/test_text_encoder.py
index 2a5d3bb8..b501148c 100644
--- a/tests/models/clip/test_text_encoder.py
+++ b/tests/models/clip/test_text_encoder.py
@@ -17,7 +17,7 @@ class TestCLIPTextEncoder:
     def start(self):
         set_rng_seed(1234)
         context_length = 77
-        batch_size, embedding_dim, heads = 2, 4, 2
+        batch_size, embedding_dim, heads, width = 2, 4, 2, 512
 
         def build_text(text_length):
             return torch.randint(1, 10, (batch_size, text_length), dtype=torch.long)
@@ -27,11 +27,13 @@ def build_encoder(
             use_clip_init=True,
             context_length=context_length,
             heads=heads,
+            width=width,
         ):
             return CLIPTextEncoder(
                 embedding_dim=embedding_dim,
                 use_clip_init=use_clip_init,
                 context_length=context_length,
+                width=width,
                 heads=heads,
             )
 
@@ -117,6 +119,55 @@ def test_forward(self, start):
             actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4
         )
 
+    def test_forward_return_hidden_state(self, start):
+        build_encoder, build_text = start
+        text = build_text(text_length=3)
+
+        text_encoder = build_encoder(context_length=3, width=4)
+        assert isinstance(text_encoder, torch.nn.Module)
+
+        out = text_encoder(text, return_hidden_state=True)
+        assert (
+            hasattr(out, "projected_embeddings")
+            and hasattr(out, "hidden_state")
+            and len(out) == 2
+        )
+
+        actual_projected_embeddings = out.projected_embeddings
+        actual_hidden_state = out.hidden_state
+        expected_projected_embeddings = torch.Tensor(
+            [
+                [-0.3668, -1.5966, -0.3304, -0.5938],
+                [-0.7904, 0.8768, -0.9707, -0.7271],
+            ]
+        )
+        expected_hidden_state = torch.Tensor(
+            [
+                [
+                    [0.6348, -0.0414, -1.6042, 1.0108],
+                    [0.6205, -0.0303, -1.6066, 1.0164],
+                    [0.5916, -0.0017, -1.6133, 1.0234],
+                ],
+                [
+                    [0.5911, -0.0152, -1.6079, 1.0320],
+                    [0.1468, -1.6758, 0.7402, 0.7888],
+                    [0.6721, -0.2897, -1.4934, 1.1109],
+                ],
+            ]
+        )
+        assert_expected(
+            actual=actual_projected_embeddings,
+            expected=expected_projected_embeddings,
+            rtol=0,
+            atol=1e-4,
+        )
+        assert_expected(
+            actual=actual_hidden_state,
+            expected=expected_hidden_state,
+            rtol=0,
+            atol=1e-4,
+        )
+
     def test_forward_over_context_length(self, start):
         build_encoder, build_text = start
 
diff --git a/torchmultimodal/models/clip/text_encoder.py b/torchmultimodal/models/clip/text_encoder.py
index c1a626e0..10002e34 100644
--- a/torchmultimodal/models/clip/text_encoder.py
+++ b/torchmultimodal/models/clip/text_encoder.py
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 
+from typing import NamedTuple, Union
+
 import torch
 from torch import nn, Tensor
 
@@ -13,6 +15,11 @@ from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
 
 
+class CLIPTextEncoderOutput(NamedTuple):
+    projected_embeddings: Tensor
+    hidden_state: Tensor
+
+
 class CLIPTextEncoder(nn.Module):
     """CLIP text encoder class. Should be instantiated and passed to
     CLIP (models/clip.py)
 
@@ -108,7 +115,9 @@ def build_attention_mask(self) -> Tensor:
         ).triu(1)
         return mask
 
-    def forward(self, text: Tensor) -> Tensor:
+    def forward(
+        self, text: Tensor, return_hidden_state: bool = False
+    ) -> Union[Tensor, CLIPTextEncoderOutput]:
         if text.size(1) != self.context_length:
             raise ValueError(
                 f"length of input should be {self.context_length} but found {text.size(1)}"
@@ -120,11 +129,15 @@ def forward(self, text: Tensor) -> Tensor:
 
         # [n_ctx, bs, transformer.width] -> [bs, n_ctx, transformer.width]
         embeddings = torch.permute(embeddings, (1, 0, 2))
-        embeddings = self.ln_final(embeddings)
+        hidden_state = self.ln_final(embeddings)
 
         # take features from the eot embedding (the highest number in each sequence)
-        embeddings = self.projection(
-            embeddings[torch.arange(embeddings.shape[0]), text.argmax(dim=-1)]
+        projected_embeddings = self.projection(
+            hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)]
         )
-        # embeddings now has size [bs, embedding_dim]
-        return embeddings
+        # projected_embeddings now has size [bs, embedding_dim]
+        if return_hidden_state:
+            return CLIPTextEncoderOutput(
+                projected_embeddings=projected_embeddings, hidden_state=hidden_state
+            )
+        return projected_embeddings
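
Usage sketch (not part of the patch): a minimal example of the new return_hidden_state flag, assuming the small test-sized configuration exercised above (embedding_dim=4, width=4, heads=2, context_length=3). The constructor arguments, output fields, and shapes come from this diff; the random input values are illustrative only.

import torch
from torchmultimodal.models.clip.text_encoder import CLIPTextEncoder

encoder = CLIPTextEncoder(
    embedding_dim=4, use_clip_init=True, context_length=3, width=4, heads=2
)
text = torch.randint(1, 10, (2, 3), dtype=torch.long)  # [batch_size, context_length]

# Default call is unchanged: only the projected embeddings tensor is returned.
projected = encoder(text)  # shape [2, 4] = [batch_size, embedding_dim]

# With the new flag, a CLIPTextEncoderOutput NamedTuple is returned instead.
out = encoder(text, return_hidden_state=True)
out.projected_embeddings  # shape [2, 4] = [batch_size, embedding_dim]
out.hidden_state          # shape [2, 3, 4] = [batch_size, context_length, width]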