diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index ac8596eab..72a4e4d18 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -100,6 +100,10 @@ def load_checkpoint(model, checkpoint_path, strict=True): # detect old format and make compatible with new format if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) + # Certain text transformers no longer expect position_ids after transformers==4.31 + position_id_key = 'text.transformer.embeddings.position_ids' + if position_id_key in state_dict and not hasattr(model, position_id_key): + del state_dict[position_id_key] resize_pos_embed(state_dict, model) incompatible_keys = model.load_state_dict(state_dict, strict=strict) return incompatible_keys