
Add CoCa model #506

Closed
wants to merge 5 commits into gh/ebsmothers/22/base from gh/ebsmothers/22/head

Conversation

@ebsmothers (Contributor) commented Nov 1, 2023

[ghstack-poisoned]
@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Nov 1, 2023
@ebsmothers marked this pull request as draft November 1, 2023 16:45
@ebsmothers marked this pull request as ready for review November 2, 2023 03:23
facebook-github-bot pushed a commit that referenced this pull request Nov 9, 2023
…parameter to float (#510)

Summary:
GitHub Actions for #506 show failures in `test_contrastive_loss_with_temperature.py` even though the changes in that PR do not touch any contrastive loss components. Running e.g. `python -m pytest -v tests/modules/losses/test_contrastive_loss_with_temperature.py` on its own, even on that PR, produces no failures. But when we run the full test suite, two of the test cases in `test_contrastive_loss_with_temperature.py` fail.

This is because of how we define the default value of `logit_scale` in `ContrastiveLossWithTemperature`. The default is an `nn.Parameter`, and since default argument values are evaluated when the module is first imported, that parameter is created once, outside of any test class, and shared by every instance that relies on the default. As a result we lose isolation between our test cases.

The fix is to use a float as the default instead. Since this gets cast to an `nn.Parameter` in `__init__` anyway, there is no difference from the user's perspective. But this way the parameter is scoped to an instance of the class instead of being created as a global parameter on import.
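
As an aside, here is a minimal sketch of the pattern being described (the class names and the temperature default below are illustrative, not the actual `ContrastiveLossWithTemperature` signature):

```
import math

import torch
from torch import nn


# Risky: a default nn.Parameter is created once, when the module is imported,
# and shared by every instance that relies on the default value.
class LossWithSharedDefault(nn.Module):
    def __init__(
        self,
        logit_scale: nn.Parameter = nn.Parameter(math.log(1 / 0.07) * torch.ones([])),
    ):
        super().__init__()
        self.logit_scale = logit_scale  # same tensor object across instances


# Safer: accept a float and create the parameter per instance in __init__.
class LossWithFloatDefault(nn.Module):
    def __init__(self, logit_scale: float = math.log(1 / 0.07)):
        super().__init__()
        self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))
```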

Tested on top of #506. Before the change:

```
python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_local_loss - AssertionError: actual: 2.032681941986084, expected: 9.8753
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_loss_with_ce_kwargs - AssertionError: actual: 2.1044366359710693, expected: 10.2524
================================================================== 2 failed, 6 passed, 2 skipped in 3.00s ===================================================================
```

After the change:

```
python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
======================================================================= 8 passed, 2 skipped in 2.87s ========================================================================
```

Pull Request resolved: #510

Reviewed By: kartikayk

Differential Revision: D50974788

Pulled By: ebsmothers

fbshipit-source-id: 6b1c2ed98583a0efd4a41894ef7c151189d51f31
@pbontrager (Contributor) left a comment


This looks good as an implementation of CoCa and as a tested internal model. I think it would be helpful to add more context in the PR description: some background on this model and a note that it was already built and tested internally. Otherwise no issues with how it's being moved to open source. I left a few nit comments that would be nice to have but are not blocking.

```
class CoCaMultimodalDecoder(nn.Module):
    """
    Multimodal decoder with cross-attention for CoCa model.
    Based on the implementation in open_clip: https://tinyurl.com/mn35vdmd
```
pbontrager (Contributor):
It would be useful to have a brief description of the architecture and use cases for this model here. Is the CoCa decoder a variant of a more general architecture?

ebsmothers (Contributor, Author):
It is pretty much just a generic transformer decoder with cross-attention and an optional output projection. I can add more detail here, though.
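
As a rough illustration of that shape, here is a minimal sketch built from stock PyTorch layers (the class name and constructor arguments are hypothetical, not the `CoCaMultimodalDecoder` API):

```
from typing import Optional

import torch
from torch import nn


# Illustrative sketch only: a generic transformer decoder that self-attends over
# text tokens, cross-attends into image embeddings, and optionally projects the
# output. Names and arguments are hypothetical.
class GenericMultimodalDecoder(nn.Module):
    def __init__(
        self, dim: int, n_heads: int, n_layers: int, out_dim: Optional[int] = None
    ):
        super().__init__()
        layer = nn.TransformerDecoderLayer(d_model=dim, nhead=n_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=n_layers)
        self.output_projection = (
            nn.Linear(dim, out_dim) if out_dim is not None else nn.Identity()
        )

    def forward(
        self, text_tokens: torch.Tensor, image_embeddings: torch.Tensor
    ) -> torch.Tensor:
        # tgt: (batch, text_seq_len, dim), memory: (batch, image_seq_len, dim)
        hidden = self.decoder(tgt=text_tokens, memory=image_embeddings)
        return self.output_projection(hidden)
```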


```
class CoCaTextEmbeddings(nn.Module):
    """
    Text embeddings for CoCa model. Includes token embeddings, positional embeddings,
```
pbontrager (Contributor):
Same comment as for the multimodal decoder, though it's probably not as necessary here.

```
class AttentionPooler(nn.Module):
    """
    Attention pooling layer: pools inputs to sequence length n_queries by performing
    cross-attention with learned query embeddings. Based on the CoCa implementation
```
pbontrager (Contributor):
Did CoCa invent this or just use it? It'd be better if we could point to the original abstract if this is meant as a general layer.

ebsmothers (Contributor, Author):
Good point. I think it is originally from the Set Transformer paper; will add a link.
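
For context, here is a minimal sketch of attention pooling with learned queries (an illustration of the idea only; the class name and arguments are hypothetical, not the `AttentionPooler` implementation in this PR):

```
import torch
from torch import nn


# Pools a variable-length sequence down to n_queries tokens by cross-attending
# learned query embeddings over the input. Names are hypothetical.
class LearnedQueryPooler(nn.Module):
    def __init__(self, input_dim: int, n_queries: int, n_heads: int = 8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, input_dim))
        self.attn = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, input_dim) -> (batch, n_queries, input_dim)
        queries = self.query.unsqueeze(0).expand(x.shape[0], -1, -1)
        pooled, _ = self.attn(query=queries, key=x, value=x)
        return pooled
```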

```
if isinstance(vision_encoder_outs, TransformerOutput):
    image_embeddings = vision_encoder_outs.last_hidden_state
elif isinstance(vision_encoder_outs, tuple):
    vision_encoder_outs = vision_encoder_outs[0]
```
pbontrager (Contributor):
This seems very risky. `isinstance` checks are risky in general, but especially here, since without reading the code the user has no way to know that only element 0 of the tuple is used.

ebsmothers (Contributor, Author):
Yeah, fair point. This is because a common return type from transformer encoders is a tuple `(last_hidden_state, all_hidden_states)`. Maybe we could add an assert here? Something like:

Suggested change:

```
-vision_encoder_outs = vision_encoder_outs[0]
+assert len(vision_encoder_outs) == 2 and isinstance(vision_encoder_outs[0], Tensor) and isinstance(vision_encoder_outs[1], List)
+vision_encoder_outs = vision_encoder_outs[0]
```
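
A self-contained sketch of the same check, written as a hypothetical helper (the function name is made up; the tuple convention `(last_hidden_state, all_hidden_states)` is as described above, and the builtin `list` is used in the isinstance check):

```
from typing import List, Tuple, Union

from torch import Tensor


def extract_last_hidden_state(
    vision_encoder_outs: Union[Tensor, Tuple[Tensor, List[Tensor]]]
) -> Tensor:
    """Hypothetical helper: normalize encoder outputs to the last hidden state."""
    if isinstance(vision_encoder_outs, tuple):
        # Expect (last_hidden_state, all_hidden_states); fail loudly otherwise.
        assert (
            len(vision_encoder_outs) == 2
            and isinstance(vision_encoder_outs[0], Tensor)
            and isinstance(vision_encoder_outs[1], list)
        ), "expected a (last_hidden_state, all_hidden_states) tuple"
        return vision_encoder_outs[0]
    return vision_encoder_outs
```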

@ebsmothers (Contributor, Author):
@pbontrager thanks for the review. Agreed, it's probably worth adding more context; this is actually why I created #507, to demonstrate parity with the open_clip implementation on ImageNet zero-shot (though I probably should have mentioned it on this PR 😅).

@codecov-commenter commented Nov 15, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

❗ No coverage uploaded for pull request base (gh/ebsmothers/22/base@eb775d1).

Additional details and impacted files
```
@@                   Coverage Diff                    @@
##             gh/ebsmothers/22/base     #506   +/-   ##
========================================================
  Coverage                         ?   75.57%
========================================================
  Files                            ?      234
  Lines                            ?    16110
  Branches                         ?        0
========================================================
  Hits                             ?    12175
  Misses                           ?     3935
  Partials                         ?        0
```


@ebsmothers (Contributor, Author):
@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot:
@ebsmothers merged this pull request in 6433f8f.

@facebook-github-bot deleted the gh/ebsmothers/22/head branch November 19, 2023 15:23
@ebsmothers mentioned this pull request Dec 1, 2023