Add CoCa model #506
Conversation
…parameter to float (#510)

Summary: GitHub Actions for #506 show failures in `test_contrastive_loss_with_temperature.py` even though the changes in that PR do not touch any contrastive loss components. Running e.g. `python -m pytest -v tests/modules/losses/test_contrastive_loss_with_temperature.py`, even on that PR, there are no failures. But when we run the full test suite, two of the test cases in `test_contrastive_loss_with_temperature.py` fail.

This is because of how we define the default value of `logit_scale` in `ContrastiveLossWithTemperature`. We set the default to an `nn.Parameter`, which is initialized the first time the class gets imported. But then this parameter is already defined outside of the test class and so we lose isolation of our test cases.

The fix is to use a float as the default instead. Since this gets cast to an `nn.Parameter` on init anyway, there will be no difference from the user's perspective. But this way we isolate the parameter to an instance of the class instead of creating a global parameter on import.

Tested on top of #506. Before the change:

```
python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_local_loss - AssertionError: actual: 2.032681941986084, expected: 9.8753
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_loss_with_ce_kwargs - AssertionError: actual: 2.1044366359710693, expected: 10.2524
================== 2 failed, 6 passed, 2 skipped in 3.00s ==================
```

After the change:

```
python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
================== 8 passed, 2 skipped in 2.87s ==================
```

Pull Request resolved: #510
Reviewed By: kartikayk
Differential Revision: D50974788
Pulled By: ebsmothers
fbshipit-source-id: 6b1c2ed98583a0efd4a41894ef7c151189d51f31
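The failure mode described above is the classic Python evaluated-once-default pitfall: a default argument is constructed at import time, so every instance shares the same object, and mutating it in one test leaks into the next. A minimal sketch of the mechanism (class names here are hypothetical; the real code shares an `nn.Parameter`, but a plain list demonstrates the same behavior without torch):

```python
import math

class LossWithSharedDefault:
    # BUG: the default list is built once, at class-definition time,
    # and shared by every instance that does not pass logit_scale.
    def __init__(self, logit_scale=[math.log(1 / 0.07)]):
        self.logit_scale = logit_scale

class LossWithFloatDefault:
    # FIX: an immutable float default; each instance wraps it in a
    # fresh object (the real fix casts to nn.Parameter in __init__).
    def __init__(self, logit_scale: float = math.log(1 / 0.07)):
        self.logit_scale = [logit_scale]

a = LossWithSharedDefault()
a.logit_scale[0] = 0.0          # e.g. a test that updates the temperature
b = LossWithSharedDefault()
assert b.logit_scale[0] == 0.0  # leaked: b sees a's mutation

c = LossWithFloatDefault()
c.logit_scale[0] = 0.0
d = LossWithFloatDefault()
assert d.logit_scale[0] == math.log(1 / 0.07)  # isolated
```

This is why the single-file pytest run passed while the full suite failed: test order determined whether the shared default had already been mutated.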
This looks good as an implementation of CoCa and as a tested internal model. I think it would be helpful for the PR description to include more context on this model and to note that it was already built and tested internally. Otherwise no issues with how it's being moved to open. I left a few nit comments that would be nice to have but are not blocking.
```python
class CoCaMultimodalDecoder(nn.Module):
    """
    Multimodal decoder with cross-attention for CoCa model.
    Based on the implementation in open_clip: https://tinyurl.com/mn35vdmd
```
It would be useful to have a brief description of the architecture and use cases for this model here. Is the CoCa decoder a variant of a more general architecture?
It is pretty much just a generic transformer decoder with cross-attention and optional output projection. I can add more detail here though
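To make the pattern concrete, here is a minimal sketch of one block of such a decoder: self-attention over text tokens, cross-attention into image tokens, then an MLP, with pre-LayerNorm residuals. This is an illustration of the generic architecture being discussed, not the torchmultimodal implementation; the class name and hyperparameters are assumptions.

```python
import torch
from torch import nn

class CrossAttentionDecoderLayer(nn.Module):
    """One generic multimodal decoder block: self-attention over the text
    sequence, cross-attention into image features, then an MLP (pre-LN)."""

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.ln3 = nn.LayerNorm(dim)

    def forward(self, text: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        # text: (batch, text_len, dim); image: (batch, img_len, dim)
        t = self.ln1(text)
        x = text + self.self_attn(t, t, t, need_weights=False)[0]
        x = x + self.cross_attn(self.ln2(x), image, image, need_weights=False)[0]
        return x + self.mlp(self.ln3(x))
```

A full decoder would stack several of these layers and optionally project the output to the vocabulary for captioning.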
```python
class CoCaTextEmbeddings(nn.Module):
    """
    Text embeddings for CoCa model. Includes token embeddings, positional embeddings,
```
Same comment as for the multimodal decoder, though probably not as necessary here.
```python
class AttentionPooler(nn.Module):
    """
    Attention pooling layer: pools inputs to sequence length n_queries by performing
    cross-attention with learned query embeddings. Based on the CoCa implementation
```
Did CoCa invent this or just use it? It would be better to point to the original paper's abstract if this is meant as a general-purpose layer.
Good point. I think it is originally from the set transformer paper, will add a link.
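For context, the layer under discussion can be sketched in a few lines: a set of learned query embeddings cross-attends over the input sequence, so any-length input is pooled down to a fixed `n_queries` tokens. This is an illustrative sketch of the general idea (as in the Set Transformer's pooling-by-attention), not the torchmultimodal code; names and defaults are assumptions.

```python
import torch
from torch import nn

class AttentionPooler(nn.Module):
    """Pool (batch, seq_len, dim) to (batch, n_queries, dim) by letting
    learned query embeddings cross-attend over the input sequence."""

    def __init__(self, input_dim: int, n_queries: int, n_heads: int = 2):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, input_dim) * 0.02)
        self.attn = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
        self.ln = nn.LayerNorm(input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the learned queries across the batch dimension.
        q = self.query.unsqueeze(0).expand(x.size(0), -1, -1)
        xn = self.ln(x)
        out, _ = self.attn(q, xn, xn)
        return out
```

In CoCa this is used, e.g., to produce a single contrastive image embedding and a longer sequence of captioning tokens from the same visual features.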
```python
if isinstance(vision_encoder_outs, TransformerOutput):
    image_embeddings = vision_encoder_outs.last_hidden_state
elif isinstance(vision_encoder_outs, tuple):
    vision_encoder_outs = vision_encoder_outs[0]
```
This seems very risky. In general isinstance is very risky, but here especially since the user doesn't know without looking at the code that it's assumed that output 0 is the only thing being used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah fair point. This is because a common return type from transformer encoders is a tuple (last_hidden_state, all_hidden_states). Maybe we could add an assert here? Something like
```diff
+assert len(vision_encoder_outs) == 2 and isinstance(vision_encoder_outs[0], Tensor) and isinstance(vision_encoder_outs[1], List)
 vision_encoder_outs = vision_encoder_outs[0]
```
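One way to make the tuple convention explicit is to factor the unpacking into a small helper that fails loudly on unexpected shapes. This is a sketch of that idea: `unpack_encoder_output` is a hypothetical name, and the `TransformerOutput` here is a minimal stand-in for the torchmultimodal dataclass so the example is self-contained.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

@dataclass
class TransformerOutput:
    # Stand-in for the torchmultimodal dataclass of the same name.
    last_hidden_state: Any
    hidden_states: Optional[List[Any]] = None

def unpack_encoder_output(
    outs: Union[TransformerOutput, Tuple[Any, ...], Any]
) -> Any:
    """Normalize an encoder's return value to its last hidden state."""
    if isinstance(outs, TransformerOutput):
        return outs.last_hidden_state
    if isinstance(outs, tuple):
        # Common encoder convention: (last_hidden_state, all_hidden_states).
        if len(outs) != 2:
            raise TypeError(
                f"Expected (last_hidden_state, all_hidden_states), "
                f"got a {len(outs)}-tuple"
            )
        return outs[0]
    return outs  # assume it is already the hidden-state tensor
```

Raising a `TypeError` (rather than `assert`, which can be stripped with `python -O`) keeps the contract enforced in production code.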
@pbontrager thanks for the review. Agree it's probably worth adding more context; actually this is why I created #507, to demonstrate parity with the open_clip implementation on ImageNet zero-shot (though I probably should have mentioned it on this PR 😅).
Codecov Report: all modified and coverable lines are covered by tests ✅

```
@@              Coverage Diff               @@
##    gh/ebsmothers/22/base    #506   +/- ##
=============================================
  Coverage            ?      75.57%
=============================================
  Files               ?         234
  Lines               ?       16110
  Branches            ?           0
=============================================
  Hits                ?       12175
  Misses              ?        3935
  Partials            ?           0
```

☔ View full report in Codecov by Sentry.
@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@ebsmothers merged this pull request in 6433f8f.
Stack from ghstack (oldest at bottom):
Differential Revision: D51332627