diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 0372c48b28..d7387442e6 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -82,6 +82,8 @@ LNTuningModel, VeraConfig, VeraModel, + XLoraConfig, + XLoraModel, ) from .utils import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 77b6b66eea..26f7f579f8 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -19,6 +19,8 @@ import torch +from peft.tuners.xlora.model import XLoraModel + from .config import PeftConfig from .mixed_model import PeftMixedModel from .peft_model import ( @@ -56,6 +58,7 @@ PromptTuningConfig, VeraConfig, VeraModel, + XLoraConfig, ) from .tuners.tuners_utils import BaseTuner as _BaseTuner from .utils import _prepare_prompt_learning_config @@ -90,6 +93,7 @@ "POLY": PolyConfig, "LN_TUNING": LNTuningConfig, "VERA": VeraConfig, + "XLORA": XLoraConfig, } PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = { @@ -103,6 +107,7 @@ "POLY": PolyModel, "LN_TUNING": LNTuningModel, "VERA": VeraModel, + "XLORA": XLoraModel, } diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index 28c7ab5683..0a1779d3fd 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -23,8 +23,6 @@ from torch import nn from transformers.utils import PushToHubMixin -from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES - from .config import PeftConfig from .peft_model import PeftModel from .tuners import ( @@ -36,6 +34,7 @@ MixedModel, OFTModel, ) +from .tuners.mixed import COMPATIBLE_TUNER_TYPES from .utils import PeftType, _set_adapter, _set_trainable diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 370f871524..4dd2b4ee79 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -29,7 +29,7 @@ from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory, named_module_tensors -from huggingface_hub import ModelCard, ModelCardData, hf_hub_download +from huggingface_hub import HfFileSystem, ModelCard, ModelCardData, hf_hub_download from safetensors import safe_open from safetensors.torch import save_file as safe_save_file from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -55,6 +55,8 @@ PromptEmbedding, PromptEncoder, VeraModel, + XLoraConfig, + XLoraModel, ) from .tuners.tuners_utils import BaseTuner, BaseTunerLayer from .utils import ( @@ -91,6 +93,7 @@ PeftType.POLY: PolyModel, PeftType.LN_TUNING: LNTuningModel, PeftType.VERA: VeraModel, + PeftType.XLORA: XLoraModel, } @@ -479,6 +482,30 @@ def from_pretrained( raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") else: config.inference_mode = not is_trainable + if isinstance(getattr(model, "base_model", None), XLoraModel): + if not isinstance(config, XLoraConfig): + raise TypeError(f"Expected 'XLoraConfig', got '{type(config)}' instead.") + if "adapters" in kwargs: + config.adapters = kwargs["adapters"] + else: + # If the path is on HF hub, then we get the adapter names to create a subfolders list which tells + # `load_adapter` where the adapters are. 
+ if not os.path.exists(model_id): + s = HfFileSystem() + + # The names of the adapters which must be in folders + adapter_names = [ + file["name"][len(model_id) + 1 :] for file in s.ls(model_id) if file["type"] == "directory" + ] + # Prepare a dict of adapter paths, which really just point to the hf id; we will use the subfolders + adapter_paths = {} + for adapter_name in adapter_names: + adapter_paths[adapter_name] = os.path.join(model_id, model_id) + config.adapters = adapter_paths + config._subfolders = adapter_names + else: + if "adapters" not in kwargs: + raise ValueError("If model_id is a local path, then `adapters` must be passed in kwargs.") if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys(): model = cls(model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype) @@ -486,6 +513,7 @@ def from_pretrained( model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type]( model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype ) + model.load_adapter( model_id, adapter_name, is_trainable=is_trainable, autocast_adapter_dtype=autocast_adapter_dtype, **kwargs ) diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index c5beb67493..06db148470 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -33,3 +33,4 @@ from .poly import PolyConfig, PolyModel from .ln_tuning import LNTuningConfig, LNTuningModel from .vera import VeraConfig, VeraModel +from .xlora import XLoraConfig, XLoraModel diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 3a1af15bc2..97e5cfbe90 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -30,6 +30,8 @@ from transformers.pytorch_utils import Conv1D from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND +from peft.utils.constants import DUMMY_TARGET_MODULES +from peft.utils.peft_types import PeftType from ..config import PeftConfig from ..utils import ModulesToSaveWrapper, _get_submodules @@ -141,7 +143,12 @@ class BaseTuner(nn.Module, ABC): double-check that the `config.target_modules` were specified correctly. """ - def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], adapter_name: str) -> None: + def __init__( + self, + model, + peft_config: Union[PeftConfig, dict[str, PeftConfig]], + adapter_name: str, + ) -> None: super().__init__() self.model = model @@ -164,7 +171,8 @@ def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], self.active_adapter: str | list[str] = adapter_name self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name) - self.inject_adapter(self.model, adapter_name) + if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA: + self.inject_adapter(self.model, adapter_name) # Copy the peft_config in the injected model. 
self.model.peft_config = self.peft_config @@ -389,6 +397,11 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d is_target_modules_in_base_model = False key_list = [key for key, _ in model.named_modules()] + if getattr(peft_config, "target_modules", None) == DUMMY_TARGET_MODULES: + # dummy adapter, we allow not matching any module + key_list = [] + is_target_modules_in_base_model = True + # update peft_config.target_modules if required peft_config = _maybe_include_all_linear_layers(peft_config, model) @@ -417,7 +430,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d parent, target, target_name = _get_submodules(model, key) self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) - if not is_target_modules_in_base_model: + # Handle X-LoRA case. + if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"): raise ValueError( f"Target modules {peft_config.target_modules} not found in the base model. " f"Please check the target modules and try again." @@ -776,6 +790,8 @@ def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module) Helper function to update `target_modules` to all linear/Conv1D layers if provided as 'all-linear'. Adapted from the QLoRA repository: https://github.com/artidoro/qlora/blob/main/qlora.py """ + if not hasattr(peft_config, "target_modules"): + return peft_config # if `target_modules` is a string, convert to lower case and check if it matches "all-linear" if not ( diff --git a/src/peft/tuners/xlora/__init__.py b/src/peft/tuners/xlora/__init__.py new file mode 100644 index 0000000000..df41e1e611 --- /dev/null +++ b/src/peft/tuners/xlora/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import XLoraConfig +from .model import XLoraModel + + +__all__ = ["XLoraConfig", "XLoraModel"] diff --git a/src/peft/tuners/xlora/classifier.py b/src/peft/tuners/xlora/classifier.py new file mode 100644 index 0000000000..dffebba08a --- /dev/null +++ b/src/peft/tuners/xlora/classifier.py @@ -0,0 +1,195 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import builtins +from typing import Optional, Union + +import torch +import torch.nn as nn + +from .config import XLoraConfig + + +Number = Union[builtins.int, builtins.float, builtins.bool] + + +class TemperatureScaledSoftmax(nn.Module): + def __init__(self, temperature=1.0): + super().__init__() + self.temperature = temperature + self.softmax = nn.Softmax(dim=-1) + + def forward(self, logits): + # Scale logits by the temperature + scaled_logits = logits / self.temperature + # Apply softmax to the scaled logits + return self.softmax(scaled_logits) + + +class XLoraClassifier(nn.Module): + """ + A classifier to select LoRA layers for XLora. + """ + + def __init__( + self, + model: nn.Module, # PeftModel + config: XLoraConfig, + n_classes: int, + n_layers: int, + device: torch.device, + ): + """ + Construct an X-LoRA classifier from a model, config and some metadata. Note that n_layers is the number of LoRA + adapter layers, not the number of model layers. + """ + super().__init__() + + self.n_classes = n_classes + self.n_layers = n_layers + self.config = config + self.log_scalings = [] + self.softmax = TemperatureScaledSoftmax(temperature=self.config.softmax_temperature) + self.override_scaling_pass_value: Number = config.scaling_pass_value + + self.scalings_logging = False + + dtype = next(model.parameters()).dtype + add_dropout = config.xlora_dropout_p > 0.0 + + layers = [] + if self.config.xlora_depth == 1: + if config.layerwise_scalings: # bias=False if we have just one layer + last = nn.Linear(config.hidden_size, n_classes * n_layers, bias=True).to(device).to(dtype) + else: + last = nn.Linear(config.hidden_size, n_classes, bias=True).to(device).to(dtype) + else: + if self.config.xlora_depth <= 0: + raise ValueError("X-LoRA depth must be strictly positive.") + + layers.append(nn.Linear(config.hidden_size, config.xlora_size, bias=True).to(device).to(dtype)) + + layers.append(nn.ReLU()) + if add_dropout: + layers.append(nn.Dropout(p=config.xlora_dropout_p)) + + for _ in range(config.xlora_depth - 2): + layers.append(nn.Linear(config.xlora_size, config.xlora_size, bias=True).to(device).to(dtype)) + + layers.append(nn.ReLU()) + if add_dropout: + layers.append(nn.Dropout(p=config.xlora_dropout_p)) + + if config.layerwise_scalings: + last = nn.Linear(config.xlora_size, n_classes * n_layers, bias=True).to(device).to(dtype) + else: + last = nn.Linear(config.xlora_size, n_classes, bias=True).to(device).to(dtype) + self.layers = nn.Sequential(*layers, last) + + def make_dummy_scalings( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Make some dummy scalings for the scalings pass (the one to get the logits for the X-LoRA classifier). These are + of shape (batch_size, seq_len, n_layers, n_classes) and filled with the override scalings pass value. Note that + n_layers is the number of LoRA adapter layers, not the number of model layers. 
+ """ + if input_ids is not None: + batch_size = input_ids.shape[0] + device = input_ids.device + seq_len = input_ids.shape[1] + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + seq_len = inputs_embeds.shape[1] + + return torch.full( # type: ignore + (batch_size, seq_len, self.n_layers, self.n_classes), + self.override_scaling_pass_value, + ).to(device) + + def forward( + self, + result, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Using the hidden states of the model, predict `n_classes` LoRA alpha values. Returns the scalings. + """ + if input_ids is not None: + batch_size = input_ids.shape[0] + seq_len = input_ids.shape[1] + else: + batch_size = inputs_embeds.shape[0] + seq_len = inputs_embeds.shape[1] + + hidden_states = result.hidden_states # type: ignore + + hidden_state = hidden_states[-1] # Get the last hidden state + + ### Classifier run + # hidden_state=[batch_size, seq_len, hidden_size] + logits = self.layers.forward(hidden_state) + + ### Repeat to make layerwise scalings + ### If layerwise_scalings=False, then the classifier only outputs logits which are not layer-wise. + ### So, we expand them to the correct shape. + if not self.config.layerwise_scalings: + logits = logits.unsqueeze(2) + logits = logits.expand(-1, -1, self.n_layers, -1) + + ### Classifier run + + scalings = logits.reshape(batch_size, seq_len, self.n_layers, self.n_classes) + # scalings = [batch_size, seq_len, n_layers, n_classes] + + if self.config.enable_softmax: + scalings = self.softmax(scalings) + + if self.scalings_logging: + self.log_scalings.append(scalings) + + return scalings + + def _get_bucketed_scalings(self) -> dict[int, tuple[list[int], list[torch.Tensor]]]: + """ + Returns bucketed scalings, bucketed by seq_len. Each value consists of the positions (the first) and the + associated tensors. The positions are paired with the associated tensors and give the position in the scaling + log. Each scaling is a tensor of shape (batch_size, seq_len, n_layers, n_classes)). + """ + seqlens_map: dict[int, tuple[list[int], list[torch.Tensor]]] = {} + for i, scaling in enumerate(self.log_scalings): + seq_len = scaling.shape[1] + if seq_len not in seqlens_map: + seqlens_map[seq_len] = ([i], [scaling]) + else: + seqlens_map[seq_len][0].append(i) + seqlens_map[seq_len][1].append(scaling) + + return seqlens_map + + def _set_override_scaling_pass_value(self, value: Union[Number, None]): + if value is None: + self.override_scaling_pass_value = 1 / self.n_classes + else: + self.override_scaling_pass_value = value + self.config.scaling_pass_value = self.override_scaling_pass_value diff --git a/src/peft/tuners/xlora/config.py b/src/peft/tuners/xlora/config.py new file mode 100644 index 0000000000..0ca9c5e45d --- /dev/null +++ b/src/peft/tuners/xlora/config.py @@ -0,0 +1,101 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+from typing import Optional
+
+from peft.config import PeftConfig
+from peft.utils.peft_types import PeftType
+
+
+@dataclass
+class XLoraConfig(PeftConfig):
+    r"""
+    This is the configuration class to store the configuration of an `XLoraModel`. When the config is reloaded, the
+    paths in the `adapters` field are disregarded in favor of the saved adapters. As such, only the keys matter during
+    loading.
+
+    Args:
+        hidden_size (`int`):
+            Hidden size of the base model.
+        adapters (`dict`):
+            Mapping of adapter names to the LoRA adapter ids, as per `PeftModel.load_adapter`. *They will be
+            automatically loaded* to be used as LoRA experts. When using `from_pretrained`, pass the new adapters dict
+            as a keyword argument.
+        enable_softmax (`bool`, *optional*, defaults to `True`):
+            Enable softmax application for the X-LoRA classifier.
+        enable_softmax_topk (`bool`, *optional*, defaults to `False`):
+            Enable softmax application for the top-k LoRA adapters. Mutually exclusive with `enable_softmax` and must
+            only be set if `top_k_lora` is set.
+        softmax_temperature (`float`, *optional*, defaults to 1.0):
+            Softmax temperature; lower values yield sharper predictions.
+        layerwise_scalings (`bool`, *optional*, defaults to `False`):
+            If True, generate scalings for each LoRA adapter at each layer. If False, the same scalings are broadcast
+            to every layer.
+        top_k_lora (`int`, *optional*, defaults to None):
+            Sparsely select the top_k LoRA experts instead of the default dense method.
+        xlora_depth (`int`, *optional*, defaults to 1):
+            Depth of the X-LoRA classifier.
+        xlora_size (`int`, *optional*, defaults to 2048):
+            Hidden size of the X-LoRA classifier, irrelevant if `xlora_depth=1`.
+        xlora_dropout_p (`float`, *optional*, defaults to 0.2):
+            Dropout probability of the X-LoRA classifier, irrelevant if `xlora_depth=1`.
+        use_trainable_adapters (`bool`, *optional*, defaults to False):
+            Make the adapters trainable.
+        scaling_pass_value (`float`, *optional*, defaults to 0):
+            Value used to fill the dummy scalings during the scaling pass.
+        global_scaling_weight (`float`, *optional*, defaults to 1):
+            Weight to multiply the output of each LoRA adapter by.
+    """
+
+    hidden_size: int = None  # type: ignore
+    adapters: dict[str, str] = None  # type: ignore
+    enable_softmax: bool = True
+    enable_softmax_topk: bool = False
+    layerwise_scalings: bool = False
+    xlora_depth: int = 1
+    xlora_size: int = 2048
+    xlora_dropout_p: float = 0.2
+    use_trainable_adapters: bool = False
+    softmax_temperature: float = 1.0
+    top_k_lora: Optional[int] = None
+    scaling_pass_value: float = 0.0
+    global_scaling_weight: float = 1.0
+
+    def __post_init__(self):
+        self.peft_type = PeftType.XLORA
+
+        if self.hidden_size is None:
+            warnings.warn(
+                "No value was provided for `hidden_size`. This will be set to 4096 by default; please ensure that this is correct."
+            )
+            self.hidden_size = 4096
+        if self.adapters is None:
+            warnings.warn(
+                "No value was provided for `adapters`. This will be set to an empty dict; please ensure that this is correct."
+            )
+            self.adapters = {}
+
+        if self.enable_softmax_topk and self.top_k_lora is None:
+            warnings.warn("`enable_softmax_topk` is enabled but `top_k_lora` is not set.")
+
+        if self.enable_softmax_topk and self.enable_softmax:
+            warnings.warn(
+                "`enable_softmax_topk` and `enable_softmax` are both enabled. This will result in worse performance."
+ ) + + if self.top_k_lora is not None and self.top_k_lora < 1: + warnings.warn("`top_k_lora` value must be at least 1.") diff --git a/src/peft/tuners/xlora/layer.py b/src/peft/tuners/xlora/layer.py new file mode 100644 index 0000000000..a5035456d4 --- /dev/null +++ b/src/peft/tuners/xlora/layer.py @@ -0,0 +1,223 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from peft.tuners import lora + +from .config import XLoraConfig + + +class XLoraLayer: + """ + A XLoraLayer wraps any LoraLayer and performs the XLora operation on the LoRA adaptors specified. Its primary API + is the forward method, which uses the scalings to execute the XLora algorithm. + """ + + def __init__( + self, + model: nn.Module, # XLoraModel + target: lora.LoraLayer, + target_forward: Callable[..., Any], + layer_number: int, + config: XLoraConfig, + ) -> None: + self.model = model + self.target_forward = target_forward + self.target = target + self.layer_number = layer_number + self.config = config + + """ + Apply the scalings for the adapter. + """ + + @staticmethod + def apply_scalings_to_x(x: torch.Tensor, scalings_layer: torch.Tensor, adapter: int) -> torch.Tensor: + # scalings_layer = [batch_size, seq_len, n_classes] + scalings = scalings_layer[:, :, adapter].unsqueeze(-1) + # scalings_layer = [batch_size, seq_len, 1] + return x * scalings + + """ + Get the scalings for this layer, potentially applying topk and topk+softmax. This is called before + `apply_scalings_to_x` + """ + + def get_maybe_topk_scalings(self, scalings) -> torch.Tensor: + # xlora_scalings = [batch_size, seq_len, n_classes] + xlora_scalings: Tensor = scalings[:, :, self.layer_number, :] # type: ignore + + if self.config.top_k_lora is not None: + _, topk_indices = torch.topk(xlora_scalings, k=self.config.top_k_lora, dim=-1) + + # Mask the topk to True, the rest to False + mask = torch.zeros_like(xlora_scalings, dtype=torch.bool) + mask.scatter_(-1, topk_indices, True) + + xlora_scalings = xlora_scalings * mask.to(xlora_scalings.dtype) + + if self.config.enable_softmax_topk: + nonzero_mask = xlora_scalings != 0 + softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1) + xlora_scalings[nonzero_mask] = softmax_res_nonzero + + return xlora_scalings + + +class XLoraLinearLayer(XLoraLayer): + def __init__( + self, + model: nn.Module, + target: lora.Linear, + target_forward: Callable[..., Any], + layer_number: int, + config: XLoraConfig, + ) -> None: + super().__init__(model, target, target_forward, layer_number, config) + + def forward(self, x: Tensor, *args: Any, scalings: Optional[Tensor] = None, **kwargs: Any) -> Tensor: + """ + This method is designed to be a drop-in-replacement for the LoRA layers' .forward method. To use it, a bound + method must be created (bound to an instance of the XLoraLayer class). 
+ """ + + previous_dtype = x.dtype + if scalings is not None: + xlora_scalings = self.get_maybe_topk_scalings(scalings) + + result = self.target.base_layer(x, *args, **kwargs) + + # Ignore if disabled. We want to make sure this is always run. + if not self.target.merged: + for adapter_n, active_adapter in enumerate(self.target.active_adapters): + # TODO: implement X-LoRA with Lora+Dora layers + if self.target.use_dora[active_adapter]: + raise ValueError("X-LoRA currently does not support LoRA layers with DoRA") + if active_adapter not in self.target.lora_A.keys(): + continue + lora_A = self.target.lora_A[active_adapter] + lora_B = self.target.lora_B[active_adapter] + dropout = self.target.lora_dropout[active_adapter] + scaling = self.target.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) # type: ignore + if scalings is not None: + x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n) + scaling_weight = self.config.global_scaling_weight + else: + x_mod = x + scaling_weight = 1 + result += lora_B(lora_A(dropout(x_mod))) * scaling * scaling_weight + + result = result.to(previous_dtype) + return result + + +class XLoraEmbeddingLayer(XLoraLayer): + def __init__( + self, + model: nn.Module, + target: lora.Embedding, + target_forward: Callable[..., Any], + layer_number: int, + config: XLoraConfig, + ) -> None: + super().__init__(model, target, target_forward, layer_number, config) + + def forward(self, x: Tensor, *args: Any, scalings: Optional[Tensor] = None, **kwargs: Any) -> Tensor: + """ + This method is designed to be a drop-in-replacement for the LoRA layers' .forward method. To use it, a bound + method must be created (bound to an instance of the XLoraLayer class). + """ + + if scalings is not None: + xlora_scalings = self.get_maybe_topk_scalings(scalings) + + result = self.target.base_layer(x, *args, **kwargs) + + # Ignore if disabled. We want to make sure this is always run. + if not self.target.merged: + for adapter_n, active_adapter in enumerate(self.target.active_adapters): + # TODO: implement X-LoRA with Lora+Dora layers + if self.target.use_dora.get(active_adapter, False): + raise ValueError("X-LoRA currently does not support LoRA layers with DoRA") + if active_adapter not in self.target.lora_embedding_A: + continue + embedding_A = self.target.lora_embedding_A[active_adapter].T + embedding_B = self.target.lora_embedding_B[active_adapter].T + scaling = self.target.scaling[active_adapter] + after_A = self.target._embed(x, embedding_A) # type: ignore + if scalings is not None: + after_A_mod = self.apply_scalings_to_x(after_A, xlora_scalings, adapter_n) + scaling_weight = self.config.global_scaling_weight + else: + after_A_mod = after_A + scaling_weight = 1 + result += (after_A_mod @ embedding_B) * scaling * scaling_weight + + return result + + +class XLoraConv2dLayer(XLoraLayer): + def __init__( + self, + model: nn.Module, + target: lora.Conv2d, + target_forward: Callable[..., Any], + layer_number: int, + config: XLoraConfig, + ) -> None: + super().__init__(model, target, target_forward, layer_number, config) + + def forward(self, x: Tensor, *args: Any, scalings: Optional[Tensor] = None, **kwargs: Any) -> Tensor: + """ + This method is designed to be a drop-in-replacement for the LoRA layers' .forward method. To use it, a bound + method must be created (bound to an instance of the XLoraLayer class). 
+ """ + + previous_dtype = x.dtype + + if scalings is not None: + xlora_scalings = self.get_maybe_topk_scalings(scalings) + + result = self.target.base_layer(x, *args, **kwargs) + + # Ignore if disabled. We want to make sure this is always run. + if not self.target.merged: + for adapter_n, active_adapter in enumerate(self.target.active_adapters): + # TODO: implement X-LoRA with Lora+Dora layers + if self.target.use_dora[active_adapter]: + raise ValueError("X-LoRA currently does not support LoRA layers with DoRA") + if active_adapter not in self.target.lora_A.keys(): + continue + lora_A = self.target.lora_A[active_adapter] + lora_B = self.target.lora_B[active_adapter] + dropout = self.target.lora_dropout[active_adapter] + scaling = self.target.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) # type: ignore + if scalings is not None: + x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n) + scaling_weight = self.config.global_scaling_weight + else: + x_mod = x + scaling_weight = 1 + result += lora_B(lora_A(dropout(x_mod))) * scaling * scaling_weight + + result = result.to(previous_dtype) + return result diff --git a/src/peft/tuners/xlora/model.py b/src/peft/tuners/xlora/model.py new file mode 100644 index 0000000000..b71065164c --- /dev/null +++ b/src/peft/tuners/xlora/model.py @@ -0,0 +1,404 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import copy +from contextlib import contextmanager +from functools import partial +from typing import Optional, Union + +import torch +import torch.nn as nn + +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.lora.model import LoraModel +from peft.tuners.tuners_utils import BaseTuner +from peft.utils.constants import DUMMY_TARGET_MODULES + +from .. import lora +from .classifier import XLoraClassifier +from .config import XLoraConfig +from .layer import XLoraConv2dLayer, XLoraEmbeddingLayer, XLoraLinearLayer + + +def convert_layers_to_xlora( + base: nn.Module, # PeftModel + xloramodel: nn.Module, # XLoraModel + config: XLoraConfig, +) -> tuple[int, torch.device | None]: + """ + Returns the number of swapped layers. 
+ """ + total_swapped = 0 + all_layers = [] + + device = None + for module in base.modules(): + # Check the exact type because classes like OPTLearnedPositionalEmbedding inherit from nn.Embedding + if type(module) == lora.Linear: + device = module.lora_A[next(iter(module.lora_A))].weight.device + new_layer = XLoraLinearLayer( + model=xloramodel, + target=module, + target_forward=module.forward, + layer_number=total_swapped, + config=config, + ) + all_layers.append(new_layer) + module.forward = new_layer.forward # type: ignore[method-assign] + total_swapped += 1 + elif type(module) == lora.Embedding: + device = module.lora_embedding_A[next(iter(module.lora_embedding_A))].device + new_layer = XLoraEmbeddingLayer( + model=xloramodel, + target=module, + target_forward=module.forward, + layer_number=total_swapped, + config=config, + ) + all_layers.append(new_layer) + module.forward = new_layer.forward # type: ignore[method-assign] + total_swapped += 1 + elif type(module) == lora.Conv2d: + device = module.lora_A[next(iter(module.lora_A))].weight.device + new_layer = XLoraConv2dLayer( + model=xloramodel, + target=module, + target_forward=module.forward, + layer_number=total_swapped, + config=config, + ) + all_layers.append(new_layer) + module.forward = new_layer.forward # type: ignore[method-assign] + total_swapped += 1 + + return (total_swapped, device) + + +class XLoraModel(BaseTuner): + """ + Creates an X-LoRA (Mixture of LoRA experts), model from a pretrained transformers model. Currently, this X-LoRA + implementation only works with models with a transformer architecture. + + The method is described in detail in https://arxiv.org/abs/2402.07148. + + Args: + model ([`torch.nn.Module`]): The model to be adapted. + config ([`XLoraConfig`]): The configuration of the Lora model. + adapter_name (`str`): The name of the adapter, does not affect the LoRA adapter names. + + Returns: + `torch.nn.Module`: The X-LoRA model. + + Example: + ```py + >>> from transformers import AutoModelForCausalLM, AutoConfig + >>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_int8_training + + >>> model_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + >>> config = XLoraConfig( + ... task_type="CAUSAL_LM", + ... hidden_size=model_config.hidden_size, + ... xlora_depth=4, + ... adapters={ + ... "adapter_1": "./path/to/the/checkpoint/", + ... "adapter_2": "./path/to/the/checkpoint/", + ... "adapter_n": "./path/to/the/checkpoint/", + ... }, + ... ) + + >>> model = AutoModelForCausalLM.from_pretrained( + ... "mistralai/Mistral-7B-Instruct-v0.1", + ... trust_remote_code=True, + ... use_flash_attention_2=False, + ... device_map="cuda:0", + ... torch_dtype=torch.bfloat16, + ... 
) + >>> model = prepare_model_for_int8_training(model) + >>> xlora_model = get_peft_model(model, config) + ``` + """ + + def __init__( + self, + model: nn.Module, + config: Union[dict[str, XLoraConfig], XLoraConfig], + adapter_name: str, + ) -> None: + nn.Module.__init__(self) + + if isinstance(config, dict): + conf = config[adapter_name] + else: + conf = config + + # Create an empty LoraModel + base_lora_config = copy.copy(conf) + base_lora_config.target_modules = DUMMY_TARGET_MODULES + # Imitate a LoraConfig, fields might need to be updated if LoraConfig is updated + base_lora_config.layer_replication = None + base_lora_config.bias = "none" + lora_model = LoraModel(model, base_lora_config, adapter_name) + + self.xlora_config = conf + self.lora_model = lora_model + + peft_config = conf + + if hasattr(model.config, "use_cache") and model.config.use_cache: + raise ValueError("`use_cache` must be False") + + adapters_items = peft_config.adapters.items() + if hasattr(self.xlora_config, "_subfolders"): + adapters_items = zip(peft_config.adapters.items(), self.xlora_config._subfolders) + else: + adapters_items = peft_config.adapters.items() + + if hasattr(self.xlora_config, "_subfolders"): + for (adapter_name, model_id), subfolder in adapters_items: + self.lora_model.load_adapter(model_id, adapter_name, subfolder=subfolder) + else: + for adapter_name, model_id in adapters_items: + self.lora_model.load_adapter(model_id, adapter_name) + + self.lora_model.set_adapter(list(peft_config.adapters.keys())) + + self._maybe_freeze_all_adapters() + + total_swapped, device = convert_layers_to_xlora( + model, + self, + peft_config, + ) + + n_classes = len(peft_config.adapters) + xlora_classifier = XLoraClassifier(model, peft_config, n_classes, total_swapped, device) + + # Setup the model internal state + self.internal_xlora_classifier = xlora_classifier + self.internal_xlora_scalings = None # type: ignore + # Controlled by enable_adapter_layers or disable_adapter_layers + self.disabled = False + + def _maybe_freeze_all_adapters(self): + self.eval() + if not self.xlora_config.use_trainable_adapters: + for name, param in self.named_parameters(): + if "lora_" in name: + param.requires_grad = False + + def generate(self, *args, **kwargs): + res = self.lora_model.generate(*args, **kwargs) # type: ignore + # This is necessary because we use PeftModel.disable_adapter() which reenables the adapters + self._maybe_freeze_all_adapters() + return res + + @contextmanager + def _enable_peft_forward_hooks(self, *generate_args, **generate_kwargs): + def scalings_injection_hook(target, args, kwargs, scalings): + # pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference + kwargs["scalings"] = scalings + return args, kwargs + + handles_to_remove = None + + def pre_forward(module, *args, **kwargs): + nonlocal handles_to_remove + + # =========================== Forward pass with "dummy" scalings ================== + + args_real = args[0] + kwargs_real = args[1] + kwargs_real.update(kwargs) + + dummy_scalings = self.internal_xlora_classifier.make_dummy_scalings(*args_real, **kwargs_real) + + hook_handles = [] + for module in self.modules(): + if isinstance(module, LoraLayer): + pre_forward = partial(scalings_injection_hook, scalings=dummy_scalings) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + with torch.no_grad(): + self.lora_model.disable_adapters() + + try: + scaling_pass_kwargs = kwargs_real.copy() + 
scaling_pass_kwargs["output_hidden_states"] = True + scaling_pass_kwargs["return_dict"] = True + try: + base_output = self.lora_model.model.forward(*args_real, **scaling_pass_kwargs) + finally: + # Clean everything up + for handle in hook_handles: + handle.remove() + finally: + self.lora_model.enable_adapters() + + xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real) + + # =========================== Real forward pass with calculated scalings ================== + + hook_handles = [] + for module in self.modules(): + if isinstance(module, LoraLayer): + pre_forward = partial(scalings_injection_hook, scalings=xlora_scalings) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + handles_to_remove = hook_handles + + if not self.disabled: + forward_handle = self.lora_model.model.register_forward_pre_hook(pre_forward, with_kwargs=True) + + # Run the forward pass: first the scaling pass in the hook, and then with the base model + yield + + if not self.disabled: + # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks. + for handle in handles_to_remove: + handle.remove() + forward_handle.remove() + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "lora_model": # see #1892: prevent infinite recursion if class is not initialized + raise + return getattr(self.lora_model, name) + + @staticmethod + def _prepare_adapter_config(peft_config, _model_config): + # Handle X-LoRA case + return peft_config + + """ + Does nothing. X-LoRA needs adapters to be frozen. + """ + + def _mark_only_adapters_as_trainable(self) -> None: ... + + """ + This enables the X-LoRA adapter. + """ + + def enable_adapter_layers(self) -> None: + self.disabled = False + + """ + This diasables the X-LoRA adapter. + """ + + def disable_adapter_layers(self) -> None: + self.disabled = True + + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + current_key, + ): + # Does nothing because XLoraModel has no target modules + pass + + @staticmethod + def _check_target_module_exists(lora_config, key): + # Does nothing because XLoraModel has no target modules + return False + + def forward(self, *args, **kwargs): + return self.lora_model.model(*args, **kwargs) + + def set_topk_lora(self, value: Optional[int]): + """ + Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense. + This is reflected in the config. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + classifier.config.top_k_lora = value + + def set_global_scaling_weight(self, weight: float): + """ + Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is by default 1. This + is reflected in the config. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + classifier.config.global_scaling_weight = weight + + def set_scaling_pass_value(self, value: float | None): + """ + Set the scaling pass value, the value to set the scalings to during the scaling pass. If the value is None, the + scaling pass value will be 1/n where n is the number of adapters. 
+ """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + classifier._set_override_scaling_pass_value(value) + + def get_global_scaling_weight(self) -> float: + """ + Get the global LoRA weight. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + return classifier.config.global_scaling_weight + + def get_latest_scalings(self) -> Optional[torch.Tensor]: + """ + Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape + (batch_size, seq_len, n_layers, n_classes). + """ + return self.internal_xlora_scalings + + def get_scalings_log(self) -> list[torch.Tensor]: + """ + Returns a shallow (only copying the list itself not the tensors) copy of the list containing the scalings log. + Editing the list does not change the underlying log. The tensors are of shape (batch_size, seq_len, n_layers, + n_classes). The seq_len dim may vary with input dimension. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + return classifier.log_scalings.copy() + + def enable_scalings_logging(self): + """ + Enable scalings logging. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + classifier.scalings_logging = True + + def disable_scalings_logging(self): + """ + Disable scalings logging, without clearing the log. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + classifier.scalings_logging = False + + def clear_scalings_log(self): + """ + Clear the scalings log. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + classifier.log_scalings.clear() + + def get_bucketed_scalings_log(self) -> dict[int, tuple[list[int], list[torch.Tensor]]]: + """ + Returns bucketed scalings, bucketed by seq_len. Each value consists of the positions (the first) and the + associated tensors. The positions are paired with the associated tensors and give the position in the scaling + log. + """ + classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore + return classifier._get_bucketed_scalings() diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 98df496275..40441f5850 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -213,3 +213,4 @@ def starcoder_model_postprocess_past_key_value(past_key_values): EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"] INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear" TOKENIZER_CONFIG_NAME = "tokenizer_config.json" +DUMMY_TARGET_MODULES = "dummy-target-modules" diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 382521d619..3c91351cbe 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -36,6 +36,7 @@ class PeftType(str, enum.Enum): - LOHA - LOKR - OFT + - XLORA - POLY - LN_TUNING """ @@ -55,6 +56,7 @@ class PeftType(str, enum.Enum): POLY = "POLY" LN_TUNING = "LN_TUNING" VERA = "VERA" + XLORA = "XLORA" class TaskType(str, enum.Enum): diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 27d0eb9bfe..f08bb76e05 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -177,6 +177,8 @@ def renamed_dora_weights(k): to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name] to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." 
+ adapter_name] + elif config.peft_type == PeftType.XLORA: + to_return = {k: state_dict[k] for k in state_dict if "internal_xlora_classifier" in k} else: raise ValueError(f"Unknown PEFT type passed: {config.peft_type}") @@ -375,6 +377,8 @@ def renamed_dora_weights(k): elif config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: peft_model_state_dict = state_dict + elif config.peft_type == PeftType.XLORA: + peft_model_state_dict = state_dict else: raise NotImplementedError diff --git a/tests/test_xlora.py b/tests/test_xlora.py new file mode 100644 index 0000000000..8e8fa2ab41 --- /dev/null +++ b/tests/test_xlora.py @@ -0,0 +1,307 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from peft import LoraConfig, PeftType, TaskType, XLoraConfig, get_peft_model +from peft.peft_model import PeftModel +from peft.utils import infer_device + + +class TestXlora: + torch_device = infer_device() + + model_id = "facebook/opt-125m" + num_loras = 4 + + @pytest.fixture(scope="function") + def lora_dir(self, tmp_path_factory): + return tmp_path_factory.mktemp("lora") + + @pytest.fixture(scope="function") + def lora_embedding_dir(self, tmp_path_factory): + return tmp_path_factory.mktemp("lora_embedding") + + @pytest.fixture(scope="function") + def saved_lora_adapters(self, lora_dir): + file_names = [] + for i in range(1, self.num_loras + 1): + torch.manual_seed(i) + lora_config = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False) + model = AutoModelForCausalLM.from_pretrained(self.model_id) + peft_model = get_peft_model(model, lora_config) + file_name = os.path.join(lora_dir, f"checkpoint-{i}") + peft_model.save_pretrained(file_name) + file_names.append(file_name) + return file_names + + @pytest.fixture(scope="function") + def saved_lora_embedding_adapters(self, lora_embedding_dir): + file_names = [] + for i in range(1, self.num_loras + 1): + torch.manual_seed(i) + lora_config = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False, target_modules=["embed_tokens"]) + model = AutoModelForCausalLM.from_pretrained(self.model_id) + peft_model = get_peft_model(model, lora_config) + file_name = os.path.join(lora_embedding_dir, f"checkpoint-{i}") + peft_model.save_pretrained(file_name) + file_names.append(file_name) + return file_names + + @pytest.fixture(scope="function") + def tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True, device_map=self.torch_device) + return tokenizer + + @pytest.fixture(scope="function") + def embedding_model(self, saved_lora_embedding_adapters): + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model.config.use_cache = False + adapters = {str(i): file_name for i, file_name in enumerate(saved_lora_embedding_adapters)} + + peft_config = XLoraConfig( + task_type=TaskType.CAUSAL_LM, + peft_type=PeftType.XLORA, + 
hidden_size=model.config.hidden_size, + xlora_depth=8, + adapters=adapters, + ) + model = get_peft_model(model, peft_config).to(self.torch_device) + return model + + @pytest.fixture(scope="function") + def model(self, saved_lora_adapters): + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model.config.use_cache = False + adapters = {str(i): file_name for i, file_name in enumerate(saved_lora_adapters)} + + peft_config = XLoraConfig( + task_type=TaskType.CAUSAL_LM, + peft_type=PeftType.XLORA, + hidden_size=model.config.hidden_size, + xlora_depth=8, + adapters=adapters, + ) + model = get_peft_model(model, peft_config).to(self.torch_device) + return model + + @pytest.fixture(scope="function") + def model_layerwise(self, saved_lora_adapters): + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model.config.use_cache = False + adapters = {str(i): file_name for i, file_name in enumerate(saved_lora_adapters)} + + peft_config = XLoraConfig( + task_type=TaskType.CAUSAL_LM, + peft_type=PeftType.XLORA, + hidden_size=model.config.hidden_size, + xlora_depth=8, + adapters=adapters, + layerwise_scalings=True, + ) + model = get_peft_model(model, peft_config).to(self.torch_device) + return model + + def test_functional(self, tokenizer, model): + model.enable_scalings_logging() + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + assert torch.isfinite(outputs[: inputs.shape[1] :]).all() + + def test_scalings_logging_methods(self, tokenizer, model): + model.enable_scalings_logging() + + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + assert torch.isfinite(outputs[: inputs.shape[1] :]).all() + + _ = model.get_latest_scalings() + # 32 is the numeber of max scalings. 3 is the number of prompt tokens. + assert 32 + 3 >= len(model.get_scalings_log()) > 0 + + model.disable_scalings_logging() + + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + assert torch.isfinite(outputs[: inputs.shape[1] :]).all() + + assert 32 >= len(model.get_scalings_log()) > 0 + + bucketed = model.get_bucketed_scalings_log() + keys = bucketed.keys() + # One bucket for prompt (seqlen=...) 
and one for the completion (seqlen=1) + assert len(bucketed) == 2 + # One bucket for prompt (which has 1 elem) + assert len(bucketed[max(keys)][0]) == 1 + assert len(bucketed[max(keys)][1]) == 1 + assert bucketed[max(keys)][0][0] == 0 + # One bucket for completions with bucket name 1 + assert len(bucketed[1][0]) > 1 + assert len(bucketed[1][1]) > 1 + assert bucketed[1][0][0] > 0 + + model.clear_scalings_log() + assert len(model.get_scalings_log()) == 0 + + def test_misc_methods(self, tokenizer, model): + model.set_global_scaling_weight(1.5) + assert model.internal_xlora_classifier.config.global_scaling_weight == 1.5 + assert model.get_global_scaling_weight() == 1.5 + + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + assert torch.isfinite(outputs[: inputs.shape[1] :]).all() + + assert str(model) is not None + + def test_save_load_functional(self, tokenizer, model, tmp_path): + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + before_logits = outputs[: inputs.shape[1] :] + assert torch.isfinite(before_logits).all() + + model.save_pretrained(save_directory=tmp_path) + + del model + + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model.config.use_cache = False + model = PeftModel.from_pretrained(model=model, model_id=tmp_path).to(self.torch_device) + + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + after_logits = outputs[: inputs.shape[1] :] + assert torch.isfinite(after_logits).all() + assert torch.equal(after_logits, before_logits) + + def test_save_load_functional_pt(self, tokenizer, model, tmp_path): + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + before_logits = outputs[: inputs.shape[1] :] + assert torch.isfinite(before_logits).all() + + model.save_pretrained(save_directory=tmp_path, safe_serialization=False) + + del model + + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model.config.use_cache = False + model = PeftModel.from_pretrained(model=model, model_id=tmp_path, safe_serialization=False).to( + self.torch_device + ) + + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + after_logits = outputs[: inputs.shape[1] :] + assert torch.isfinite(after_logits).all() + assert torch.equal(after_logits, before_logits), (after_logits, before_logits) + + def test_topk_lora(self, tokenizer, model): + model.set_topk_lora(2) + assert model.internal_xlora_classifier.config.top_k_lora == 2 + + inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt") + outputs = model.generate( + input_ids=inputs.to(self.torch_device), + max_new_tokens=32, + ) + assert torch.isfinite(outputs[: inputs.shape[1] :]).all() + + def test_softmax_topk(self, tokenizer, model): + # Just reach in to set the config + model.internal_xlora_classifier.config.top_k_lora = 2 + model.internal_xlora_classifier.config.enable_softmax = False + model.internal_xlora_classifier.config.enable_softmax_topk = 
True
+
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=32,
+        )
+        assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
+
+    def test_set_override_scaling_pass_value(self, model):
+        # Defaults to 0
+        assert model.internal_xlora_classifier.override_scaling_pass_value == 0.0
+
+        # Set it to 2 and make sure it actually is
+        model.set_scaling_pass_value(2)
+        assert model.internal_xlora_classifier.override_scaling_pass_value == 2
+        assert model.internal_xlora_classifier.config.scaling_pass_value == 2
+
+        # Set it to None and make sure it falls back to 1 / num_loras
+        model.set_scaling_pass_value(None)
+        assert model.internal_xlora_classifier.override_scaling_pass_value == 1 / self.num_loras
+        assert model.internal_xlora_classifier.config.scaling_pass_value == 1 / self.num_loras
+
+    def test_functional_layerwise(self, tokenizer, model_layerwise):
+        model_layerwise.enable_scalings_logging()
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        outputs = model_layerwise.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=32,
+        )
+        assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
+
+    def test_disable_adapter(self, tokenizer, model):
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        with model.disable_adapter():
+            outputs_disabled = model.generate(
+                input_ids=inputs.to(self.torch_device),
+                max_new_tokens=32,
+            )
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=32,
+        )
+        assert torch.isfinite(outputs_disabled[: inputs.shape[1] :]).all()
+        assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
+        assert not torch.equal(outputs, outputs_disabled)
+
+    def test_functional_embedding(self, tokenizer, embedding_model):
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        outputs = embedding_model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=32,
+        )
+        assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
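
Taken together, the new tests exercise a simple end-to-end flow: save a few ordinary LoRA adapters, point an `XLoraConfig` at them, wrap the base model with `get_peft_model`, and generate. A minimal sketch of that flow follows; the adapter paths (`./adapter_a`, `./adapter_b`) are placeholders for LoRA checkpoints previously saved with `save_pretrained`, and `facebook/opt-125m` is simply the small base model used by the tests above.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import TaskType, XLoraConfig, get_peft_model

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.config.use_cache = False  # XLoraModel requires use_cache to be disabled

# Placeholder paths to LoRA checkpoints saved earlier with save_pretrained;
# the dict keys become the names of the LoRA experts.
config = XLoraConfig(
    task_type=TaskType.CAUSAL_LM,
    hidden_size=model.config.hidden_size,
    adapters={"expert_0": "./adapter_a", "expert_1": "./adapter_b"},
    xlora_depth=4,
)
xlora_model = get_peft_model(model, config)

xlora_model.enable_scalings_logging()
input_ids = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
outputs = xlora_model.generate(input_ids=input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# One scalings tensor of shape (batch_size, seq_len, n_layers, n_adapters) per prediction
print(len(xlora_model.get_scalings_log()))
```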