Skip to content

Commit

Permalink
Merge branch 'main' into dev/test-refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
TimoImhof committed Jan 8, 2025
2 parents be69f0a + 303c34b commit f1b1136
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 91 deletions.
8 changes: 6 additions & 2 deletions src/adapters/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,17 @@ def __init__(self, *stack_layers: List[Union[AdapterCompositionBlock, str]]):


class Fuse(AdapterCompositionBlock):
def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]]):
def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]], name: Optional[str] = None):
super().__init__(*fuse_stacks)
self._name = name

# TODO-V2 pull this up to all block classes?
@property
def name(self):
return ",".join([c if isinstance(c, str) else c.last() for c in self.children])
if self._name:
return self._name
else:
return ",".join([c if isinstance(c, str) else c.last() for c in self.children])


class Split(AdapterCompositionBlock):
Expand Down
30 changes: 22 additions & 8 deletions src/adapters/configuration/model_adapters_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import logging
from collections.abc import Collection, Mapping
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from .. import __version__
from ..composition import AdapterCompositionBlock
Expand All @@ -27,6 +27,7 @@ def __init__(self, **kwargs):

self.fusions: Mapping[str, str] = kwargs.pop("fusions", {})
self.fusion_config_map = kwargs.pop("fusion_config_map", {})
self.fusion_name_map = kwargs.pop("fusion_name_map", {})

# TODO-V2 Save this with config?
self.active_setup: Optional[AdapterCompositionBlock] = None
Expand Down Expand Up @@ -131,7 +132,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None):
self.adapters[adapter_name] = config_name
logger.info(f"Adding adapter '{adapter_name}'.")

def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]:
def get_fusion(self, fusion_name: Union[str, List[str]]) -> Tuple[Optional[dict], Optional[list]]:
"""
Gets the config dictionary for a given AdapterFusion.
Expand All @@ -140,6 +141,7 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]:
Returns:
Optional[dict]: The AdapterFusion configuration.
Optional[list]: The names of the adapters to fuse.
"""
if isinstance(fusion_name, list):
fusion_name = ",".join(fusion_name)
Expand All @@ -149,20 +151,31 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]:
config = self.fusion_config_map.get(config_name, None)
else:
config = ADAPTERFUSION_CONFIG_MAP.get(config_name, None)

if fusion_name in self.fusion_name_map:
adapter_names = self.fusion_name_map[fusion_name]
else:
adapter_names = fusion_name.split(",")

return config, adapter_names
else:
config = None
return config
return None, None

def add_fusion(self, fusion_name: Union[str, List[str]], config: Optional[Union[str, dict]] = None):
def add_fusion(
self, adapter_names: List[str], config: Optional[Union[str, dict]] = None, fusion_name: Optional[str] = None
):
"""
Adds a new AdapterFusion.
Args:
fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse.
adapter_names (List[str]): The names of the adapters to fuse.
config (Optional[Union[str, dict]], optional): AdapterFusion config. Defaults to None.
fusion_name (Optional[str], optional): The name of the AdapterFusion. If not specified, will default to comma-separated adapter names.
"""
if isinstance(fusion_name, list):
fusion_name = ",".join(fusion_name)
if fusion_name is None:
fusion_name = ",".join(adapter_names)
else:
self.fusion_name_map[fusion_name] = adapter_names
if fusion_name in self.fusions:
raise ValueError(f"An AdapterFusion with the name '{fusion_name}' has already been added.")
if config is None:
Expand Down Expand Up @@ -218,6 +231,7 @@ def to_dict(self):
output_dict["fusion_config_map"][k] = v.to_dict()
else:
output_dict["fusion_config_map"][k] = copy.deepcopy(v)
output_dict["fusion_name_map"] = copy.deepcopy(self.fusion_name_map)
return output_dict

def __eq__(self, other):
Expand Down
12 changes: 9 additions & 3 deletions src/adapters/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def save_to_state_dict(self, name: str):
if name not in self.model.adapters_config.fusions:
raise ValueError(f"No AdapterFusion with name '{name}' available.")

adapter_fusion_config = self.model.adapters_config.get_fusion(name)
adapter_fusion_config, _ = self.model.adapters_config.get_fusion(name)

config_dict = build_full_config(
adapter_fusion_config,
Expand Down Expand Up @@ -676,13 +676,14 @@ def save(self, save_directory: str, name: str, meta_dict=None):
else:
assert isdir(save_directory), "Saving path should be a directory where the head can be saved."

adapter_fusion_config = self.model.adapters_config.get_fusion(name)
adapter_fusion_config, adapter_names = self.model.adapters_config.get_fusion(name)

# Save the adapter fusion configuration
config_dict = build_full_config(
adapter_fusion_config,
self.model.config,
name=name,
adapter_names=adapter_names,
model_name=self.model.model_name,
model_class=self.model.__class__.__name__,
)
Expand Down Expand Up @@ -746,9 +747,14 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs):
config = self.weights_helper.load_weights_config(save_directory)

adapter_fusion_name = load_as or config["name"]
adapter_names = config.get("adapter_names", adapter_fusion_name)
if adapter_fusion_name not in self.model.adapters_config.fusions:
self.model.add_adapter_fusion(
adapter_fusion_name, config["config"], overwrite_ok=True, set_active=kwargs.pop("set_active", True)
adapter_names,
config["config"],
name=adapter_fusion_name,
overwrite_ok=True,
set_active=kwargs.pop("set_active", True),
)
else:
logger.warning("Overwriting existing adapter fusion module '{}'".format(adapter_fusion_name))
Expand Down
8 changes: 4 additions & 4 deletions src/adapters/methods/bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,17 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:

def add_fusion_layer(self, adapter_names: Union[List, str]):
"""See BertModel.add_fusion_layer"""
adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",")
fusion_name = ",".join(adapter_names) if isinstance(adapter_names, list) else adapter_names
fusion_config, adapter_names = self.adapters_config.get_fusion(fusion_name)
if self.adapters_config.common_config_value(adapter_names, self.location_key):
fusion_config = self.adapters_config.get_fusion(adapter_names)
dropout_prob = fusion_config.dropout_prob or getattr(self.model_config, "attention_probs_dropout_prob", 0)
fusion = BertFusion(
fusion_config,
self.model_config.hidden_size,
dropout_prob,
)
fusion.train(self.training) # make sure training mode is consistent
self.adapter_fusion_layer[",".join(adapter_names)] = fusion
self.adapter_fusion_layer[fusion_name] = fusion

def delete_fusion_layer(self, adapter_names: Union[List, str]):
adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names)
Expand Down Expand Up @@ -223,7 +223,7 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0
context = ForwardContext.get_context()

# config of _last_ fused adapter is significant
fusion_config = self.adapters_config.get_fusion(adapter_setup.name)
fusion_config, _ = self.adapters_config.get_fusion(adapter_setup.name)
last = adapter_setup.last()
last_adapter = self.adapters[last]
hidden_states, query, residual = last_adapter.pre_forward(
Expand Down
27 changes: 16 additions & 11 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ def add_adapter_fusion(
self,
adapter_names: Union[Fuse, list, str],
config=None,
name: str = None,
overwrite_ok: bool = False,
set_active: bool = False,
):
Expand All @@ -655,29 +656,33 @@ def add_adapter_fusion(
- a string identifying a pre-defined adapter fusion configuration
- a dictionary representing the adapter fusion configuration
- the path to a file containing the adapter fusion configuration
name (str, optional):
Name of the AdapterFusion layer. If not specified, the name is generated automatically from the fused adapter names.
overwrite_ok (bool, optional):
Overwrite an AdapterFusion layer with the same name if it exists. By default (False), an exception is
thrown.
set_active (bool, optional):
Activate the added AdapterFusion. By default (False), the AdapterFusion is added but not activated.
"""
if isinstance(adapter_names, Fuse):
if name is None:
name = adapter_names.name
adapter_names = adapter_names.children
elif isinstance(adapter_names, str):
adapter_names = adapter_names.split(",")
if name is None:
name = ",".join(adapter_names)

if isinstance(config, dict):
config = AdapterFusionConfig.from_dict(config) # ensure config is ok and up-to-date
# In case adapter already exists and we allow overwriting, explicitly delete the existing one first
if overwrite_ok and self.adapters_config.get_fusion(adapter_names) is not None:
self.delete_adapter_fusion(adapter_names)
self.adapters_config.add_fusion(adapter_names, config=config)
self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names))
self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names))
if overwrite_ok and self.adapters_config.get_fusion(name)[0] is not None:
self.delete_adapter_fusion(name)
self.adapters_config.add_fusion(adapter_names, config=config, fusion_name=name)
self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(name))
self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(name))
if set_active:
if not isinstance(adapter_names, list):
adapter_names = adapter_names.split(",")
self.set_active_adapters(Fuse(*adapter_names))
self.set_active_adapters(Fuse(*adapter_names, name=name))

def delete_adapter(self, adapter_name: str):
"""
Expand Down Expand Up @@ -710,7 +715,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]):
adapter_names (Union[Fuse, list, str]): AdapterFusion layer to delete.
"""
if isinstance(adapter_names, Fuse):
adapter_fusion_name = ",".join(adapter_names.children)
adapter_fusion_name = adapter_names.name
elif isinstance(adapter_names, list):
adapter_fusion_name = ",".join(adapter_names)
elif isinstance(adapter_names, str):
Expand Down Expand Up @@ -776,7 +781,7 @@ def save_adapter_fusion(
ValueError: If the given AdapterFusion name is invalid.
"""
if isinstance(adapter_names, Fuse):
adapter_fusion_name = ",".join(adapter_names.children)
adapter_fusion_name = adapter_names.name
elif isinstance(adapter_names, list):
adapter_fusion_name = ",".join(adapter_names)
elif isinstance(adapter_names, str):
Expand Down Expand Up @@ -1094,7 +1099,7 @@ def save_all_adapter_fusions(
"""
os.makedirs(save_directory, exist_ok=True)
for name in self.adapters_config.fusions:
adapter_fusion_config = self.adapters_config.get_fusion(name)
adapter_fusion_config, _ = self.adapters_config.get_fusion(name)
h = get_adapter_config_hash(adapter_fusion_config)
save_path = join(save_directory, name)
if meta_dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,86 @@ def test_output_adapter_fusion_attentions(self):
self.assertEqual(len(per_layer_scores), 1)
for k, v in per_layer_scores.items():
self.assertEqual(self.input_shape[0], v.shape[0], k)

def test_add_adapter_fusion_custom_name(self):
config_name = "seq_bn"
model = self.get_model()
model.eval()

name1 = f"{config_name}-1"
name2 = f"{config_name}-2"
model.add_adapter(name1, config=config_name)
model.add_adapter(name2, config=config_name)

# adapter is correctly added to config
self.assertTrue(name1 in model.adapters_config)
self.assertTrue(name2 in model.adapters_config)

# add fusion with default name
model.add_adapter_fusion([name1, name2])
model.to(torch_device)

# check forward pass
input_data = self.get_input_samples(config=model.config)
model.set_active_adapters(Fuse(name1, name2))
fusion_default_ref_output = model(**input_data)

# add fusion with custom name
model.add_adapter_fusion([name1, name2], name="custom_name_fusion")
model.to(torch_device)

self.assertIn(f"{name1},{name2}", model.adapters_config.fusions)
self.assertIn("custom_name_fusion", model.adapters_config.fusions)
self.assertIn("custom_name_fusion", model.adapters_config.fusion_name_map)

# check forward pass
model.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion"))
fusion_custom_output = model(**input_data)
model.set_active_adapters(Fuse(name1, name2))
fusion_default_output = model(**input_data)
model.set_active_adapters(None)
base_output = model(**input_data)

self.assertFalse(torch.equal(fusion_default_ref_output[0], base_output[0]))
self.assertTrue(torch.equal(fusion_default_ref_output[0], fusion_default_output[0]))
self.assertFalse(torch.equal(fusion_custom_output[0], fusion_default_output[0]))
self.assertFalse(torch.equal(fusion_custom_output[0], base_output[0]))

# delete only the custom fusion
model.delete_adapter_fusion(Fuse(name1, name2, name="custom_name_fusion"))
# model.delete_adapter_fusion("custom_name_fusion")

self.assertIn(f"{name1},{name2}", model.adapters_config.fusions)
self.assertNotIn("custom_name_fusion", model.adapters_config.fusions)

def test_load_adapter_fusion_custom_name(self):
model1 = self.get_model()
model1.eval()

name1 = "name1"
name2 = "name2"
model1.add_adapter(name1)
model1.add_adapter(name2)

model2 = copy.deepcopy(model1)
model2.eval()

model1.add_adapter_fusion([name1, name2], name="custom_name_fusion")
model1.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion"))

with tempfile.TemporaryDirectory() as temp_dir:
model1.save_adapter_fusion(temp_dir, "custom_name_fusion")
# also tests that set_active works
model2.load_adapter_fusion(temp_dir, set_active=True)

# check if adapter was correctly loaded
self.assertEqual(model1.adapters_config.fusions.keys(), model2.adapters_config.fusions.keys())

# check equal output
in_data = self.get_input_samples(config=model1.config)
model1.to(torch_device)
model2.to(torch_device)
output1 = model1(**in_data)
output2 = model2(**in_data)
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.equal(output1[0], output2[0]))
63 changes: 0 additions & 63 deletions utils/rename_script.py

This file was deleted.

0 comments on commit f1b1136

Please sign in to comment.