Add logit_bias to custom text clip
rwightman committed Sep 15, 2023
1 parent f8babcf commit e48ee1a
Showing 1 changed file with 7 additions and 1 deletion.
src/open_clip/model.py (7 additions, 1 deletion)
@@ -278,6 +278,8 @@ def __init__(
             vision_cfg: CLIPVisionCfg,
             text_cfg: CLIPTextCfg,
             quick_gelu: bool = False,
+            init_logit_scale: float = np.log(1 / 0.07),
+            init_logit_bias: Optional[float] = None,
             cast_dtype: Optional[torch.dtype] = None,
             output_dict: bool = False,
     ):
@@ -287,7 +289,11 @@ def __init__(
         self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
         self.context_length = self.text.context_length
         self.vocab_size = self.text.vocab_size
-        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+        if init_logit_bias is not None:
+            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+        else:
+            self.logit_bias = None
 
     def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
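For context, a logit bias is the extra learnable scalar used by sigmoid-based contrastive losses such as SigLIP (https://arxiv.org/abs/2303.15343), where it shifts every image-text pair's logit before the sigmoid. Below is a minimal sketch, not part of this commit, of how the two parameters would typically be consumed downstream; the helper name sigmoid_logits and the feature tensors are illustrative assumptions.

import torch

def sigmoid_logits(image_features, text_features, logit_scale, logit_bias=None):
    # The model stores the log of the scale, so exponentiate before use.
    # Features are assumed to be L2-normalized already.
    logits = logit_scale.exp() * image_features @ text_features.T
    if logit_bias is not None:
        # Sigmoid-loss variants add a learned scalar shift to every pair.
        logits = logits + logit_bias
    return logits

With the defaults above, init_logit_scale = np.log(1 / 0.07) reproduces the original CLIP temperature initialization and logit_bias stays None, so existing contrastive checkpoints are unaffected; a SigLIP-style setup would instead pass something like init_logit_scale=np.log(10) and init_logit_bias=-10, the initializations used in the SigLIP paper.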
