diff --git a/test/models/flava/test_flava.py b/test/models/flava/test_flava.py index 10df6e23c..0485e133d 100644 --- a/test/models/flava/test_flava.py +++ b/test/models/flava/test_flava.py @@ -83,7 +83,7 @@ def test_forward_pretraining(self): sum( value if value is not None else 0 for value in output.losses.values() ).item(), - 21.5150, + 21.4029, places=4, ) @@ -107,7 +107,7 @@ def test_forward_pretraining(self): sum( value if value is not None else 0 for value in output.losses.values() ).item(), - 8.9674, + 8.6285, places=4, ) @@ -132,7 +132,7 @@ def test_forward_pretraining(self): sum( value if value is not None else 0 for value in output.losses.values() ).item(), - 10.0305, + 11.0002, places=4, ) diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py index c62d4f5ba..78214c361 100644 --- a/torchmultimodal/modules/losses/flava.py +++ b/torchmultimodal/modules/losses/flava.py @@ -308,7 +308,10 @@ def __init__( **kwargs: Any, ): super().__init__() - + self.itm_loss = ITMLoss( + hidden_size=hidden_size, + ignore_index=ignore_index, + ) self.contrastive_loss = FLAVAGlobalContrastiveLoss( logit_scale=logit_scale, image_embedding_size=hidden_size, @@ -348,10 +351,6 @@ def __init__( ), } ) - self.itm_loss = ITMLoss( - hidden_size=hidden_size, - ignore_index=ignore_index, - ) self.mim_weight = mim_weight self.mlm_weight = mlm_weight