diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 72a4e4d18..e4f0b7632 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -88,6 +88,10 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): checkpoint = torch.load(checkpoint_path, map_location=map_location) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] + elif isinstance(checkpoint, torch.jit.ScriptModule): + state_dict = checkpoint.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) else: state_dict = checkpoint if next(iter(state_dict.items()))[0].startswith('module'):