diff --git a/torchmultimodal/modules/losses/contrastive_loss_with_temperature.py b/torchmultimodal/modules/losses/contrastive_loss_with_temperature.py index 58a3be447..155445a0d 100644 --- a/torchmultimodal/modules/losses/contrastive_loss_with_temperature.py +++ b/torchmultimodal/modules/losses/contrastive_loss_with_temperature.py @@ -115,7 +115,7 @@ def contrastive_loss_with_temperature( ) -DEFAULT_LOGIT_SCALE = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) +DEFAULT_LOGIT_SCALE = math.log(1 / 0.07) class ContrastiveLossWithTemperature(nn.Module):