diff --git a/requirements-test.txt b/requirements-test.txt
index 5d2e7e147..ae2a0563f 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,4 +1,4 @@
 pytest-split==0.8.0
 pytest==7.2.0
 transformers
-timm==0.6.11
+timm>=0.9.5
diff --git a/requirements-training.txt b/requirements-training.txt
index c44eb61d7..0f4970476 100644
--- a/requirements-training.txt
+++ b/requirements-training.txt
@@ -8,5 +8,5 @@ pandas
 braceexpand
 huggingface_hub
 transformers
-timm
+timm>=0.9.5
 fsspec
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 8eb507576..cfded7b1b 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -90,6 +90,10 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
     checkpoint = torch.load(checkpoint_path, map_location=map_location)
     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
         state_dict = checkpoint['state_dict']
+    elif isinstance(checkpoint, torch.jit.ScriptModule):
+        state_dict = checkpoint.state_dict()
+        for key in ["input_resolution", "context_length", "vocab_size"]:
+            state_dict.pop(key, None)
     else:
         state_dict = checkpoint
     if next(iter(state_dict.items()))[0].startswith('module'):
diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py
index 748884a3c..ab13a21aa 100644
--- a/src/open_clip/transform.py
+++ b/src/open_clip/transform.py
@@ -40,9 +40,10 @@ def forward(self, img):
         else:
             width, height = img.size
         scale = self.max_size / float(max(height, width))
+        new_size = tuple(round(dim * scale) for dim in (height, width))
         if scale != 1.0:
-            new_size = tuple(round(dim * scale) for dim in (height, width))
             img = F.resize(img, new_size, self.interpolation)
+        if not width == height:
             pad_h = self.max_size - new_size[0]
             pad_w = self.max_size - new_size[1]
             img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
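
Both requirements files now pin `timm>=0.9.5`. A quick sanity check for an existing environment, before running the test suite (a minimal sketch, assuming `timm` is importable and `packaging` is installed, as it is in most pip-based setups):

```python
# Verify the installed timm satisfies the new >=0.9.5 floor.
import timm
from packaging.version import Version

assert Version(timm.__version__) >= Version("0.9.5"), f"timm {timm.__version__} is too old"
```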
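The `factory.py` hunk teaches `load_state_dict` to accept TorchScript archives: `torch.load` dispatches such files to `torch.jit.load` and returns a `ScriptModule`, whose `state_dict()` carries scalar metadata entries (`input_resolution`, `context_length`, `vocab_size`) that would otherwise break key matching against the eager model. A self-contained sketch of the patched function; the trailing `module.` prefix strip is reconstructed from the truncated hunk context and should be treated as an assumption:

```python
import torch

def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Training checkpoints wrap the weights in a 'state_dict' entry.
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        # JIT archives (e.g. the original OpenAI CLIP releases) load as
        # ScriptModules; drop the metadata entries stored beside the weights.
        state_dict = checkpoint.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        # Assumed completion of the truncated context: strip a
        # DistributedDataParallel 'module.' prefix if present.
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict
```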
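The `transform.py` hunk fixes `ResizeMaxSize` for rectangular images whose long side already equals `max_size`: previously `new_size` was computed, and padding applied, only inside the `scale != 1.0` branch, so e.g. a 224×160 input with `max_size=224` passed through unpadded. A standalone sketch of the corrected logic, written as a plain function rather than the module's `nn.Module` and assuming PIL input:

```python
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from PIL import Image

def resize_max_size(img, max_size, interpolation=InterpolationMode.BICUBIC, fill=0):
    width, height = img.size
    scale = max_size / float(max(height, width))
    # Computed unconditionally now, so the padding branch below can use it
    # even when no resize happens (scale == 1.0).
    new_size = tuple(round(dim * scale) for dim in (height, width))
    if scale != 1.0:
        img = F.resize(img, new_size, interpolation)
    if not width == height:
        # Center-pad the short side out to max_size.
        pad_h = max_size - new_size[0]
        pad_w = max_size - new_size[1]
        img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=fill)
    return img

# The previously broken case: long side already at max_size, so scale == 1.0,
# but the 160-pixel side still needs padding to reach 224.
out = resize_max_size(Image.new('RGB', (224, 160)), max_size=224)
assert out.size == (224, 224)
```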