From 7b8dd2cbaf9cca13ca5b1defa6a321a145eb166c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 21 Oct 2023 14:52:57 -0700
Subject: [PATCH] Merge model_cfg & model_kwargs before passing to model,
 allows SigLIP models to be trained with SigLIP loss via --siglip (avoid dupe
 arg)

---
 src/open_clip/factory.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index ef94b51f8..59fa1e92f 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -240,13 +240,14 @@ def create_model(
         model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
         custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
 
+        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
         if custom_text:
             if "multimodal_cfg" in model_cfg:
-                model = CoCa(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
+                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
             else:
-                model = CustomTextCLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
+                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
         else:
-            model = CLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
+            model = CLIP(**model_cfg, cast_dtype=cast_dtype)
 
         if precision in ("fp16", "bf16"):
             dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
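
Note (not part of the patch): a minimal sketch of the duplicate-keyword problem this change
avoids, and of why dict(model_cfg, **model_kwargs) makes kwargs win over the config. The key
name 'init_logit_bias' is illustrative of an option that can appear both in a model's JSON
config and in model_kwargs (e.g. when --siglip is passed); build() is a stand-in, not the
real CLIP/CustomTextCLIP/CoCa signature.

    def build(embed_dim=512, init_logit_bias=None, cast_dtype=None):
        # stand-in constructor for illustration only
        return dict(embed_dim=embed_dim, init_logit_bias=init_logit_bias, cast_dtype=cast_dtype)

    model_cfg = {'embed_dim': 512, 'init_logit_bias': None}   # from the model's JSON config (illustrative)
    model_kwargs = {'init_logit_bias': -10}                    # e.g. supplied when --siglip is passed (illustrative)

    # Before the patch: unpacking both dicts into the constructor raises on any shared key.
    try:
        build(**model_cfg, **model_kwargs, cast_dtype=None)
    except TypeError as err:
        print(err)  # "build() got multiple values for keyword argument 'init_logit_bias'"

    # After the patch: merge first; later keys (model_kwargs) override earlier ones (model_cfg).
    merged = dict(model_cfg, **model_kwargs)
    print(build(**merged, cast_dtype=None))  # init_logit_bias == -10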