diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py index 5fedf0d982..13ed9f458e 100644 --- a/src/adapters/heads/base.py +++ b/src/adapters/heads/base.py @@ -905,9 +905,12 @@ def _load_pretrained_model( **kwargs, ): # Filter only weights not part of base model - head_state_dict = { - key: value for key, value in state_dict.items() if not key.startswith(cls.base_model_prefix) - } + if state_dict is not None: + head_state_dict = { + key: value for key, value in state_dict.items() if not key.startswith(cls.base_model_prefix) + } + else: + head_state_dict = None head_name = "default" loader = PredictionHeadLoader(model, error_on_missing=False, convert_to_flex_head=True) head_config, new_head_state_dict = loader.convert_static_to_flex_head(head_state_dict, load_as=head_name) @@ -919,6 +922,7 @@ def _load_pretrained_model( model.add_prediction_head_from_config(head_name, head_config, overwrite_ok=True) + if new_head_state_dict is not None: for k in head_state_dict: del state_dict[k] loaded_keys.remove(k) diff --git a/src/adapters/loading.py b/src/adapters/loading.py index 63cb266673..25ad7e4fc1 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -772,7 +772,7 @@ def convert_static_to_flex_head(self, state_dict, load_as="default"): Loads a prediction head module from the given state dict, which contains a static head checkpoint. Args: - state_dict (dict): The static head checkpoint from which to load the head module. + state_dict (dict): The static head checkpoint from which to load the head module. Can be None. load_as (str, optional): Load the weights with this name. Defaults to None. Returns: @@ -798,8 +798,11 @@ def convert_static_to_flex_head(self, state_dict, load_as="default"): return None, None # Load head weights - new_state_dict = {} - for k, v in state_dict.items(): - new_k = conversion_rename_func(k) - new_state_dict[new_k] = v + if state_dict is not None: + new_state_dict = {} + for k, v in state_dict.items(): + new_k = conversion_rename_func(k) + new_state_dict[new_k] = v + else: + new_state_dict = None return head_config, new_state_dict