From b70dc30bd4858ef949fa71ba4f93024743059ba3 Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Fri, 4 Aug 2023 18:41:56 -0500 Subject: [PATCH 1/3] Support returning hidden state from CLIPTextEncoder --- tests/models/clip/test_text_encoder.py | 47 ++++++++++++++++++++- torchmultimodal/models/clip/text_encoder.py | 15 +++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/tests/models/clip/test_text_encoder.py b/tests/models/clip/test_text_encoder.py index 2a5d3bb8..2e2454ce 100644 --- a/tests/models/clip/test_text_encoder.py +++ b/tests/models/clip/test_text_encoder.py @@ -17,7 +17,7 @@ class TestCLIPTextEncoder: def start(self): set_rng_seed(1234) context_length = 77 - batch_size, embedding_dim, heads = 2, 4, 2 + batch_size, embedding_dim, heads, width = 2, 4, 2, 512 def build_text(text_length): return torch.randint(1, 10, (batch_size, text_length), dtype=torch.long) @@ -27,12 +27,16 @@ def build_encoder( use_clip_init=True, context_length=context_length, heads=heads, + width=width, + return_hidden_state=False, ): return CLIPTextEncoder( embedding_dim=embedding_dim, use_clip_init=use_clip_init, context_length=context_length, + width=width, heads=heads, + return_hidden_state=return_hidden_state, ) return build_encoder, build_text @@ -117,6 +121,47 @@ def test_forward(self, start): actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4 ) + def test_forward_return_hidden_state(self, start): + build_encoder, build_text = start + text = build_text(text_length=3) + + text_encoder = build_encoder( + context_length=3, width=4, return_hidden_state=True + ) + assert isinstance(text_encoder, torch.nn.Module) + + actual_clip_init, actual_hidden_state = text_encoder(text) + print(actual_hidden_state) + expected_clip_init = torch.Tensor( + [ + [-0.366838, -1.596611, -0.330413, -0.593790], + [-0.790419, 0.876780, -0.970667, -0.727134], + ] + ) + expected_hidden_state = torch.Tensor( + [ + [ + [6.348165e-01, -4.137459e-02, -1.604239e00, 1.010798e00], + [6.204837e-01, -3.028658e-02, -1.606570e00, 1.016373e00], + [5.915626e-01, -1.666874e-03, -1.613292e00, 1.023396e00], + ], + [ + [5.910631e-01, -1.515219e-02, -1.607913e00, 1.032002e00], + [1.467783e-01, -1.675803e00, 7.402021e-01, 7.888227e-01], + [6.721084e-01, -2.896671e-01, -1.493379e00, 1.110938e00], + ], + ] + ) + assert_expected( + actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4 + ) + assert_expected( + actual=actual_hidden_state, + expected=expected_hidden_state, + rtol=0, + atol=1e-4, + ) + def test_forward_over_context_length(self, start): build_encoder, build_text = start diff --git a/torchmultimodal/models/clip/text_encoder.py b/torchmultimodal/models/clip/text_encoder.py index c1a626e0..5094861c 100644 --- a/torchmultimodal/models/clip/text_encoder.py +++ b/torchmultimodal/models/clip/text_encoder.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. +from typing import Tuple, Union + import torch from torch import nn, Tensor @@ -28,6 +30,8 @@ class CLIPTextEncoder(nn.Module): heads (int): Number of heads in Transformer encoder. layers (int): Number of layers in Transformer encoder. use_clip_init (bool): Whether to use CLIP-specific initialization. + return_hidden_state (bool): Whether to return the last hidden state. + If True, forward returns a tuple of final embedding and last hidden state. Inputs: text (Tensor): Tensor containing text features. """ @@ -45,6 +49,7 @@ def __init__( heads: int = 8, layers: int = 12, use_clip_init: bool = True, + return_hidden_state: bool = False, ): super().__init__() torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") @@ -77,6 +82,8 @@ def __init__( if use_clip_init: self.initialize_parameters() + self.return_hidden_state = return_hidden_state + def initialize_parameters(self) -> None: # Initialize token and positional embeddings nn.init.normal_(self.token_embedding.weight, std=self.TOKEN_EMBEDDING_INIT_STD) @@ -108,7 +115,7 @@ def build_attention_mask(self) -> Tensor: ).triu(1) return mask - def forward(self, text: Tensor) -> Tensor: + def forward(self, text: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: if text.size(1) != self.context_length: raise ValueError( f"length of input should be {self.context_length} but found {text.size(1)}" @@ -120,11 +127,13 @@ def forward(self, text: Tensor) -> Tensor: # [n_ctx, bs, transformer.width] -> [bs, n_ctx, transformer.width] embeddings = torch.permute(embeddings, (1, 0, 2)) - embeddings = self.ln_final(embeddings) + hidden_state = self.ln_final(embeddings) # take features from the eot embedding (the highest number in each sequence) embeddings = self.projection( - embeddings[torch.arange(embeddings.shape[0]), text.argmax(dim=-1)] + hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)] ) # embeddings now has size [bs, embedding_dim] + if self.return_hidden_state: + return embeddings, hidden_state return embeddings From 0c4307ddf8303b21716d50e41b30c2d53b864a6d Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Mon, 7 Aug 2023 23:33:25 -0500 Subject: [PATCH 2/3] Make return_hidden_state an arg for forward and add CLIPTextEncoderOutput --- tests/models/clip/test_text_encoder.py | 40 ++++++++++++--------- torchmultimodal/models/clip/text_encoder.py | 26 ++++++++------ 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/tests/models/clip/test_text_encoder.py b/tests/models/clip/test_text_encoder.py index 2e2454ce..b501148c 100644 --- a/tests/models/clip/test_text_encoder.py +++ b/tests/models/clip/test_text_encoder.py @@ -28,7 +28,6 @@ def build_encoder( context_length=context_length, heads=heads, width=width, - return_hidden_state=False, ): return CLIPTextEncoder( embedding_dim=embedding_dim, @@ -36,7 +35,6 @@ def build_encoder( context_length=context_length, width=width, heads=heads, - return_hidden_state=return_hidden_state, ) return build_encoder, build_text @@ -125,35 +123,43 @@ def test_forward_return_hidden_state(self, start): build_encoder, build_text = start text = build_text(text_length=3) - text_encoder = build_encoder( - context_length=3, width=4, return_hidden_state=True - ) + text_encoder = build_encoder(context_length=3, width=4) assert isinstance(text_encoder, torch.nn.Module) - actual_clip_init, actual_hidden_state = text_encoder(text) - print(actual_hidden_state) - expected_clip_init = torch.Tensor( + out = text_encoder(text, return_hidden_state=True) + assert ( + hasattr(out, "projected_embeddings") + and hasattr(out, "hidden_state") + and len(out) == 2 + ) + + actual_projected_embeddings = out.projected_embeddings + actual_hidden_state = out.hidden_state + expected_projected_embeddings = torch.Tensor( [ - [-0.366838, -1.596611, -0.330413, -0.593790], - [-0.790419, 0.876780, -0.970667, -0.727134], + [-0.3668, -1.5966, -0.3304, -0.5938], + [-0.7904, 0.8768, -0.9707, -0.7271], ] ) expected_hidden_state = torch.Tensor( [ [ - [6.348165e-01, -4.137459e-02, -1.604239e00, 1.010798e00], - [6.204837e-01, -3.028658e-02, -1.606570e00, 1.016373e00], - [5.915626e-01, -1.666874e-03, -1.613292e00, 1.023396e00], + [0.6348, -0.0414, -1.6042, 1.0108], + [0.6205, -0.0303, -1.6066, 1.0164], + [0.5916, -0.0017, -1.6133, 1.0234], ], [ - [5.910631e-01, -1.515219e-02, -1.607913e00, 1.032002e00], - [1.467783e-01, -1.675803e00, 7.402021e-01, 7.888227e-01], - [6.721084e-01, -2.896671e-01, -1.493379e00, 1.110938e00], + [0.5911, -0.0152, -1.6079, 1.0320], + [0.1468, -1.6758, 0.7402, 0.7888], + [0.6721, -0.2897, -1.4934, 1.1109], ], ] ) assert_expected( - actual=actual_clip_init, expected=expected_clip_init, rtol=0, atol=1e-4 + actual=actual_projected_embeddings, + expected=expected_projected_embeddings, + rtol=0, + atol=1e-4, ) assert_expected( actual=actual_hidden_state, diff --git a/torchmultimodal/models/clip/text_encoder.py b/torchmultimodal/models/clip/text_encoder.py index 5094861c..d347f245 100644 --- a/torchmultimodal/models/clip/text_encoder.py +++ b/torchmultimodal/models/clip/text_encoder.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Tuple, Union +from typing import NamedTuple, Union import torch from torch import nn, Tensor @@ -15,6 +15,11 @@ from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm +class CLIPTextEncoderOutput(NamedTuple): + projected_embeddings: torch.Tensor + hidden_state: torch.Tensor + + class CLIPTextEncoder(nn.Module): """CLIP text encoder class. Should be instantiated and passed to CLIP (models/clip.py) @@ -30,8 +35,6 @@ class CLIPTextEncoder(nn.Module): heads (int): Number of heads in Transformer encoder. layers (int): Number of layers in Transformer encoder. use_clip_init (bool): Whether to use CLIP-specific initialization. - return_hidden_state (bool): Whether to return the last hidden state. - If True, forward returns a tuple of final embedding and last hidden state. Inputs: text (Tensor): Tensor containing text features. """ @@ -49,7 +52,6 @@ def __init__( heads: int = 8, layers: int = 12, use_clip_init: bool = True, - return_hidden_state: bool = False, ): super().__init__() torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") @@ -82,8 +84,6 @@ def __init__( if use_clip_init: self.initialize_parameters() - self.return_hidden_state = return_hidden_state - def initialize_parameters(self) -> None: # Initialize token and positional embeddings nn.init.normal_(self.token_embedding.weight, std=self.TOKEN_EMBEDDING_INIT_STD) @@ -115,7 +115,9 @@ def build_attention_mask(self) -> Tensor: ).triu(1) return mask - def forward(self, text: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def forward( + self, text: Tensor, return_hidden_state: bool = False + ) -> Union[Tensor, CLIPTextEncoderOutput]: if text.size(1) != self.context_length: raise ValueError( f"length of input should be {self.context_length} but found {text.size(1)}" @@ -129,11 +131,13 @@ def forward(self, text: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: embeddings = torch.permute(embeddings, (1, 0, 2)) hidden_state = self.ln_final(embeddings) # take features from the eot embedding (the highest number in each sequence) - embeddings = self.projection( + projected_embeddings = self.projection( hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)] ) # embeddings now has size [bs, embedding_dim] - if self.return_hidden_state: - return embeddings, hidden_state - return embeddings + if return_hidden_state: + return CLIPTextEncoderOutput( + projected_embeddings=projected_embeddings, hidden_state=hidden_state + ) + return projected_embeddings From 977e72a187aef240ac7dacba972011d94e300989 Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Tue, 8 Aug 2023 01:29:13 -0500 Subject: [PATCH 3/3] torch.Tensor -> Tensor --- torchmultimodal/models/clip/text_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmultimodal/models/clip/text_encoder.py b/torchmultimodal/models/clip/text_encoder.py index d347f245..10002e34 100644 --- a/torchmultimodal/models/clip/text_encoder.py +++ b/torchmultimodal/models/clip/text_encoder.py @@ -16,8 +16,8 @@ class CLIPTextEncoderOutput(NamedTuple): - projected_embeddings: torch.Tensor - hidden_state: torch.Tensor + projected_embeddings: Tensor + hidden_state: Tensor class CLIPTextEncoder(nn.Module):