Skip to content

Commit

Permalink
Not working for sharded checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Oct 28, 2023
1 parent d854bc7 commit e96e0e7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
10 changes: 7 additions & 3 deletions src/adapters/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,9 +905,12 @@ def _load_pretrained_model(
**kwargs,
):
# Keep only the weights that are not part of the base model (i.e. the head weights)
head_state_dict = {
key: value for key, value in state_dict.items() if not key.startswith(cls.base_model_prefix)
}
if state_dict is not None:
head_state_dict = {
key: value for key, value in state_dict.items() if not key.startswith(cls.base_model_prefix)
}
else:
head_state_dict = None
head_name = "default"
loader = PredictionHeadLoader(model, error_on_missing=False, convert_to_flex_head=True)
head_config, new_head_state_dict = loader.convert_static_to_flex_head(head_state_dict, load_as=head_name)
Expand All @@ -919,6 +922,7 @@ def _load_pretrained_model(

model.add_prediction_head_from_config(head_name, head_config, overwrite_ok=True)

if new_head_state_dict is not None:
for k in head_state_dict:
del state_dict[k]
loaded_keys.remove(k)
Expand Down
13 changes: 8 additions & 5 deletions src/adapters/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def convert_static_to_flex_head(self, state_dict, load_as="default"):
Loads a prediction head module from the given state dict, which contains a static head checkpoint.
Args:
state_dict (dict): The static head checkpoint from which to load the head module.
state_dict (dict): The static head checkpoint from which to load the head module. Can be None.
load_as (str, optional): Load the weights with this name. Defaults to "default".
Returns:
Expand All @@ -798,8 +798,11 @@ def convert_static_to_flex_head(self, state_dict, load_as="default"):
return None, None

# Load head weights
new_state_dict = {}
for k, v in state_dict.items():
new_k = conversion_rename_func(k)
new_state_dict[new_k] = v
if state_dict is not None:
new_state_dict = {}
for k, v in state_dict.items():
new_k = conversion_rename_func(k)
new_state_dict[new_k] = v
else:
new_state_dict = None
return head_config, new_state_dict

0 comments on commit e96e0e7

Please sign in to comment.