
Support returning hidden state in CLIPTextEncoder #442

Closed · wants to merge 3 commits

Conversation


@22quinn (Contributor) commented Aug 4, 2023

Summary:
Added a return_hidden_state arg to CLIPTextEncoder. If set to True, forward returns a tuple of the final embedding and the last hidden state.

Test plan:
pytest tests/models/clip/test_text_encoder.py

cc @abhinavarora. Moved the code here for easier review and testing.
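
For context, a rough usage sketch of the behaviour described in the summary (flag placement follows the summary as written; the reviewers below suggest passing it to forward instead, so the merged API may differ):

import torch
from torchmultimodal.models.clip.text_encoder import CLIPTextEncoder

# All constructor arguments other than return_hidden_state are left at their defaults here.
encoder = CLIPTextEncoder(return_hidden_state=True)
# [batch_size, context_length] token ids; the shape assumes the default context length of 77
text = torch.randint(0, 100, (2, 77))
embeddings, hidden_state = encoder(text)
# embeddings: [batch_size, embedding_dim]; hidden_state: [batch_size, context_length, width]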

@facebook-github-bot added the CLA Signed label on Aug 4, 2023 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@codecov-commenter commented Aug 7, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.05% 🎉

Comparison is base (81e281c) 68.72% compared to head (977e72a) 68.77%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #442      +/-   ##
==========================================
+ Coverage   68.72%   68.77%   +0.05%     
==========================================
  Files         169      169              
  Lines       11374    11393      +19     
==========================================
+ Hits         7817     7836      +19     
  Misses       3557     3557              
Files Changed                                  Coverage Δ
tests/models/clip/test_text_encoder.py         100.00% <100.00%> (ø)
torchmultimodal/models/clip/text_encoder.py    100.00% <100.00%> (ø)


        )
        # embeddings now has size [bs, embedding_dim]

        if self.return_hidden_state:
            return embeddings, hidden_state
nit: use named tuple
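
A possible shape for that suggestion (a sketch only; the class and field names here are illustrative, not taken from the PR):

from typing import NamedTuple

from torch import Tensor


class CLIPTextEncoderOutput(NamedTuple):
    # Projected EOT-token embedding, [batch_size, embedding_dim]
    embeddings: Tensor
    # Output of the final LayerNorm, [batch_size, context_length, width]
    hidden_state: Tensor

forward could then end with return CLIPTextEncoderOutput(embeddings=embeddings, hidden_state=hidden_state), which keeps tuple unpacking working while making the fields self-documenting.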

assert isinstance(text_encoder, torch.nn.Module)

actual_clip_init, actual_hidden_state = text_encoder(text)
print(actual_hidden_state)
remove

expected_hidden_state = torch.Tensor(
    [
        [
            [6.348165e-01, -4.137459e-02, -1.604239e00, 1.010798e00],
nit: can we limit to 4 decimal places given atol is being set
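
For illustration, rounding the expected values to four decimals still passes a comparison with an absolute tolerance of 1e-4 (values taken from the first row above; the tolerance used in the actual test may differ):

import torch

actual = torch.tensor([6.348165e-01, -4.137459e-02, -1.604239e00, 1.010798e00])
expected = torch.tensor([0.6348, -0.0414, -1.6042, 1.0108])
# Passes: every element differs from its four-decimal rounding by well under 1e-4.
torch.testing.assert_close(actual, expected, rtol=0.0, atol=1e-4)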

@@ -77,6 +82,8 @@ def __init__(
        if use_clip_init:
            self.initialize_parameters()

        self.return_hidden_state = return_hidden_state
you can make this an arg to fwd

@abhinavarora (Contributor) commented Aug 7, 2023
+1 to what Ankita said! We don't need this to be a module-level property.
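
A minimal, self-contained sketch of that pattern, with the flag passed to forward instead of stored on the module (a toy module for illustration, not the actual CLIPTextEncoder code):

from typing import Tuple, Union

from torch import Tensor, nn


class EncoderWithOptionalHiddenState(nn.Module):
    """Toy module illustrating the suggestion; layers and shapes are placeholders."""

    def __init__(self, width: int = 8, embedding_dim: int = 4) -> None:
        super().__init__()
        self.ln_final = nn.LayerNorm(width)
        self.projection = nn.Linear(width, embedding_dim, bias=False)

    def forward(
        self, features: Tensor, return_hidden_state: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        # features: [batch_size, context_length, width]
        hidden_state = self.ln_final(features)
        # Pool the last position as a simple stand-in for the EOT-token selection.
        projected_embeddings = self.projection(hidden_state[:, -1, :])
        if return_hidden_state:
            return projected_embeddings, hidden_state
        return projected_embeddings

Calling enc(x) returns only the embeddings, while enc(x, return_hidden_state=True) returns both, so no state needs to live on the module.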

@@ -120,11 +127,13 @@ def forward(self, text: Tensor) -> Tensor:

        # [n_ctx, bs, transformer.width] -> [bs, n_ctx, transformer.width]
        embeddings = torch.permute(embeddings, (1, 0, 2))
-       embeddings = self.ln_final(embeddings)
+       hidden_state = self.ln_final(embeddings)
        # take features from the eot embedding (the highest number in each sequence)
        embeddings = self.projection(
I would suggest renaming this to something like projected_embeddings to avoid re-using the embeddings name as it was used before.
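
Putting the two suggestions together, the tail of forward might then read roughly as follows (an illustrative fragment, not the merged diff):

        hidden_state = self.ln_final(embeddings)
        # take features from the eot embedding (the highest number in each sequence)
        projected_embeddings = self.projection(
            ...  # same EOT-token selection as before, applied to hidden_state
        )
        if return_hidden_state:
            return projected_embeddings, hidden_state
        return projected_embeddings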

@22quinn (Contributor, Author) commented Aug 8, 2023

Thanks both for the review! I've addressed all comments. buck test looks good.

buck test @mode/dev-nosan //torchmultimodal/tests
Tests finished: Pass 317. Fail 0. Fatal 0. Skip 6. Build failure 0

@ankitade (Contributor) commented Aug 8, 2023

Can you import the diff and land it there?

@facebook-github-bot commented

@abhinavarora has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot commented

@abhinavarora merged this pull request in a1cc8f3.

Labels: CLA Signed, Merged · 5 participants