diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 93161a850..bb787e852 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -278,6 +278,8 @@ def __init__( vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): @@ -287,7 +289,11 @@ def __init__( self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.context_length = self.text.context_length self.vocab_size = self.text.vocab_size - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) + else: + self.logit_bias = None def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991