From 64d42df4e60df1c30d6246a30e28792374fef3ab Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 11 Sep 2023 12:50:14 -0700 Subject: [PATCH] Convert JIT model (on state dict load) to sd for pretrained='filename.pt' support for OpenAI .pt files. Fix #622 --- src/open_clip/factory.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 72a4e4d18..e4f0b7632 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -88,6 +88,10 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): checkpoint = torch.load(checkpoint_path, map_location=map_location) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] + elif isinstance(checkpoint, torch.jit.ScriptModule): + state_dict = checkpoint.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) else: state_dict = checkpoint if next(iter(state_dict.items()))[0].startswith('module'):