From 8556945f09cd149cbef69f8394308d8b41dca596 Mon Sep 17 00:00:00 2001
From: Jason Chou
Date: Mon, 28 Aug 2023 12:31:04 -0700
Subject: [PATCH] Fix `text.transformer.embeddings.position_ids` key error
 (#595)

* fix create_model & test

* better fix: explained & strict
---
 src/open_clip/factory.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index ac8596eab..72a4e4d18 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -100,6 +100,10 @@ def load_checkpoint(model, checkpoint_path, strict=True):
     # detect old format and make compatible with new format
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)
+    # Certain text transformers no longer expect position_ids after transformers==4.31
+    position_id_key = 'text.transformer.embeddings.position_ids'
+    if position_id_key in state_dict and not hasattr(model, position_id_key):
+        del state_dict[position_id_key]
     resize_pos_embed(state_dict, model)
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys