Fixes for CoCa cascaded attention poolers #518

Closed
wants to merge 2 commits into from

@ebsmothers ebsmothers commented Jan 3, 2024

A couple of fixes to CoCa's attention pooling, as pointed out in #517. Specifically, when cascaded attention pooling is used, the input dim of the contrastive pooler must match the output dim of the captioning pooler. We should also set n_queries=1 for the contrastive pooler so that the pooled embeddings can be fed directly into the contrastive loss (after appropriate normalization).
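For illustration, the shape contract behind both fixes can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not the torchmultimodal implementation: `AttentionPooler` below is a hypothetical stand-in built on `nn.MultiheadAttention`, and every dim, head count, and query count is an illustrative value chosen to keep the example small.

```python
import torch
from torch import nn


class AttentionPooler(nn.Module):
    """Hypothetical pooler: learned queries cross-attend to the input tokens."""

    def __init__(self, input_dim: int, output_dim: int, n_queries: int, n_heads: int = 4):
        super().__init__()
        # One learned query per pooled output token.
        self.query = nn.Parameter(torch.randn(n_queries, output_dim))
        # kdim/vdim let keys and values live in a different dim than the queries.
        self.attn = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=n_heads,
            kdim=input_dim,
            vdim=input_dim,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, input_dim) -> (batch, n_queries, output_dim)
        q = self.query.unsqueeze(0).expand(x.size(0), -1, -1)
        pooled, _ = self.attn(q, x, x, need_weights=False)
        return self.norm(pooled)


# Illustrative sizes (assumptions, not the CoCa config).
vision_dim, caption_dim, contrastive_dim = 64, 32, 32

captioning_pooler = AttentionPooler(vision_dim, caption_dim, n_queries=16)
# Cascaded: the contrastive pooler consumes the captioning pooler's output,
# so its input_dim is caption_dim (not vision_dim), and it uses a single query.
contrastive_pooler = AttentionPooler(caption_dim, contrastive_dim, n_queries=1)

image_tokens = torch.randn(2, 50, vision_dim)
captioning_tokens = captioning_pooler(image_tokens)
contrastive_embed = contrastive_pooler(captioning_tokens)
print(captioning_tokens.shape, contrastive_embed.shape)
```

With `vision_dim != caption_dim`, constructing the second pooler with `input_dim=vision_dim` would leave its key/value projections sized for the wrong input, which is the mismatch described above.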

Test:

import torch
from torchmultimodal.models.coca.coca_model import coca_vit_l_14

model = coca_vit_l_14()
bs, c, h, w, seq_len, vocab_size = 2, 3, 224, 224, 77, 49408
images = torch.randn(bs, c, h, w)
texts = torch.randint(0, vocab_size, (bs, seq_len))
out = model(images, texts)
# Pooled image embedding now has a single query token per image.
print(out.image_pooled_output.shape, out.multimodal_embeddings.shape)
...
torch.Size([2, 1, 768]) torch.Size([2, 76, 49408])
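
As a rough sketch of what "after appropriate normalization" can look like for a pooled output of shape (bs, 1, dim): squeeze the single query axis, L2-normalize, and compute a symmetric cross-entropy over the similarity matrix. The tensors here are random stand-ins rather than real model outputs, the text embedding and the temperature value are assumptions, and this generic InfoNCE-style loss is not necessarily the loss torchmultimodal uses.

```python
import torch
import torch.nn.functional as F

bs, dim = 4, 768
image_pooled = torch.randn(bs, 1, dim)  # stand-in shaped like image_pooled_output
text_pooled = torch.randn(bs, dim)      # hypothetical pooled text embedding

# n_queries=1 means the query axis can simply be squeezed away.
image_emb = F.normalize(image_pooled.squeeze(1), dim=-1)
text_emb = F.normalize(text_pooled, dim=-1)

temperature = 0.07  # assumed value for illustration
logits = image_emb @ text_emb.t() / temperature  # (bs, bs) similarity matrix

# Matching image/text pairs sit on the diagonal.
labels = torch.arange(bs)
loss = 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))
print(logits.shape, loss.item())
```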

Add new unit test:

python -m pytest -v tests/models/coca/test_coca_model.py
...
===== 4 passed in 3.18s ======

@facebook-github-bot added the CLA Signed label Jan 3, 2024

codecov-commenter commented Jan 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (fc92cea) 75.57% compared to head (2a2e742) 75.61%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #518      +/-   ##
==========================================
+ Coverage   75.57%   75.61%   +0.04%     
==========================================
  Files         234      234              
  Lines       16113    16122       +9     
==========================================
+ Hits        12177    12191      +14     
+ Misses       3936     3931       -5     

@pbontrager pbontrager left a comment

This looks like it addresses the issue. I'll approve it as it looks good, but it might be worth updating a unit test to ensure an issue like this isn't missed going forward.

@facebook-github-bot

@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot

@ebsmothers merged this pull request in 63c629a.

Labels: CLA Signed, Merged
4 participants