Skip to content

Commit

Permalink
Merge branch 'main' into add_mammut
Browse files Browse the repository at this point in the history
  • Loading branch information
gpucce committed Sep 19, 2023
2 parents b980b57 + f692ec9 commit dc8d841
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytest-split==0.8.0
pytest==7.2.0
transformers
timm==0.6.11
timm>=0.9.5
2 changes: 1 addition & 1 deletion requirements-training.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ pandas
braceexpand
huggingface_hub
transformers
timm
timm>=0.9.5
fsspec
4 changes: 4 additions & 0 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif isinstance(checkpoint, torch.jit.ScriptModule):
state_dict = checkpoint.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
Expand Down
3 changes: 2 additions & 1 deletion src/open_clip/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def forward(self, img):
else:
width, height = img.size
scale = self.max_size / float(max(height, width))
new_size = tuple(round(dim * scale) for dim in (height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
if not width == height:
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
Expand Down

0 comments on commit dc8d841

Please sign in to comment.