Change default logit scale in contrastive loss with temperature from parameter to float #510

Closed

ebsmothers
Contributor

GitHub Actions runs for #506 show failures in test_contrastive_loss_with_temperature.py even though the changes in that PR do not touch any contrastive loss components. Running that file on its own, e.g. python -m pytest -v tests/modules/losses/test_contrastive_loss_with_temperature.py, produces no failures, even on that PR. But when we run the full test suite, two of the test cases in test_contrastive_loss_with_temperature.py fail.

This is because of how we define the default value of logit_scale in ContrastiveLossWithTemperature. We set the default to an nn.Parameter, which is created once, the first time the class gets imported. That single parameter lives outside of any test class and is shared by every instance that relies on the default, so an in-place change made through one instance is visible to all the others and we lose isolation of our test cases.
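To illustrate the mechanism, here is a minimal sketch in plain Python. It is not the actual ContrastiveLossWithTemperature code: `Param` is a hypothetical stand-in for a mutable object such as an nn.Parameter, `LossWithSharedDefault` is an invented class name, and the value 2.0 is arbitrary.

```python
class Param:
    """Hypothetical stand-in for a mutable parameter object (e.g. an nn.Parameter)."""

    def __init__(self, value: float) -> None:
        self.value = value


# Created exactly once, when this module is first imported.
SHARED_DEFAULT = Param(2.0)  # arbitrary illustrative value


class LossWithSharedDefault:
    def __init__(self, logit_scale: Param = SHARED_DEFAULT) -> None:
        # Every instance built with the default stores the very same object.
        self.logit_scale = logit_scale


first = LossWithSharedDefault()
second = LossWithSharedDefault()

# Simulate one test (or a training step) updating the parameter in place.
first.logit_scale.value = 0.0

shared = first.logit_scale is second.logit_scale
leaked_value = second.logit_scale.value
```

Here `shared` is True and `leaked_value` reflects the update made through `first`, which is the kind of cross-test leakage that only shows up when several test files run in the same process.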

The fix is to use a float as the default instead. Since this gets cast to an nn.Parameter on init anyway, there will be no difference from the user's perspective. But this way each instance of the class gets its own parameter instead of sharing a global parameter created on import.
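A minimal sketch of the float-default pattern, again in plain Python rather than the real class: `Param` is a hypothetical stand-in for nn.Parameter, `LossWithFloatDefault` is an invented name, and the default value 2.0 is arbitrary.

```python
from typing import Union


class Param:
    """Hypothetical stand-in for a mutable parameter object (e.g. an nn.Parameter)."""

    def __init__(self, value: float) -> None:
        self.value = value


# An immutable float is safe to share as a default.
DEFAULT_LOGIT_SCALE = 2.0  # arbitrary illustrative value


class LossWithFloatDefault:
    def __init__(self, logit_scale: Union[float, Param] = DEFAULT_LOGIT_SCALE) -> None:
        if isinstance(logit_scale, Param):
            # The caller explicitly supplied a parameter object; use it as given.
            self.logit_scale = logit_scale
        else:
            # Wrap the float in a fresh parameter owned by this instance.
            self.logit_scale = Param(float(logit_scale))


first = LossWithFloatDefault()
second = LossWithFloatDefault()

# An in-place update through one instance no longer affects the other.
first.logit_scale.value = 0.0

isolated = first.logit_scale is not second.logit_scale
untouched_value = second.logit_scale.value
```

Each default-constructed instance now owns a distinct parameter, so `second` keeps its initial value regardless of what happens to `first`.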

Tested on top of #506. Before the change:

python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_local_loss - AssertionError: actual: 2.032681941986084, expected: 9.8753
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_loss_with_ce_kwargs - AssertionError: actual: 2.1044366359710693, expected: 10.2524
================================================================== 2 failed, 6 passed, 2 skipped in 3.00s ===================================================================

After the change:

python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
======================================================================= 8 passed, 2 skipped in 2.87s ========================================================================

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 2, 2023
codecov-commenter commented Nov 3, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (a33a8b8) 74.94% compared to head (e525590) 74.94%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #510      +/-   ##
==========================================
- Coverage   74.94%   74.94%   -0.01%     
==========================================
  Files         226      226              
  Lines       15592    15592              
==========================================
- Hits        11686    11685       -1     
- Misses       3906     3907       +1     
Files Coverage Δ
...odules/losses/contrastive_loss_with_temperature.py 88.67% <100.00%> (-1.89%) ⬇️


@kartikayk kartikayk left a comment

Thanks for making the change! I'm glad we're fixing this; the nn.Parameter default value doesn't make sense to me. There are some other issues with this class from my POV: why do we accept nn.Parameters in the init function at all? In my view these should be self-contained within the class. We should just force the user to pass a float and then add these as params inside the class (unless I misunderstand the setup). But that can likely come in a follow-up PR at a later time. If you agree, it may be worth opening an issue and adding it to the log.

@facebook-github-bot
Contributor

@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ebsmothers merged this pull request in e6b92b5.
