From 2f5360a7da22a236b5ad4c059572fff5321c867c Mon Sep 17 00:00:00 2001 From: Kerim <99087793+tokenizer-decode@users.noreply.github.com> Date: Wed, 12 Jun 2024 18:46:43 +0300 Subject: [PATCH] FEAT Add OLoRA initialization strategy to LoRA (#1828) --- docs/source/developer_guides/lora.md | 9 + examples/olora_finetuning/README.md | 84 ++++++++ examples/olora_finetuning/olora_finetuning.py | 184 ++++++++++++++++++ src/peft/peft_model.py | 69 ++++--- src/peft/tuners/lora/config.py | 23 +-- src/peft/tuners/lora/layer.py | 27 ++- src/peft/tuners/lora/model.py | 17 +- tests/test_gpu_examples.py | 113 ++++++++++- tests/test_initialization.py | 96 +++++++++ 9 files changed, 576 insertions(+), 46 deletions(-) create mode 100644 examples/olora_finetuning/README.md create mode 100644 examples/olora_finetuning/olora_finetuning.py diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md index 036fc2a4ca..f73da6e0a4 100644 --- a/docs/source/developer_guides/lora.md +++ b/docs/source/developer_guides/lora.md @@ -54,6 +54,15 @@ lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...) ``` For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning). +### OLoRA +[OLoRA](https://arxiv.org/abs/2406.01775) utilizes QR decomposition to initialize the LoRA adapters. OLoRA translates the base weights of the model by a factor of their QR decompositions, i.e., it mutates the weights before performing any training on them. This approach significantly improves stability, accelerates convergence speed, and ultimately achieves superior performance. + +You just need to pass a single additional option to use OLoRA: +```python +from peft import LoraConfig +config = LoraConfig(init_lora_weights="olora", ...) +``` +For more advanced usage, please refer to our [documentation](https://github.com/huggingface/peft/tree/main/examples/olora_finetuning). ### LoftQ #### Standard approach diff --git a/examples/olora_finetuning/README.md b/examples/olora_finetuning/README.md new file mode 100644 index 0000000000..fd6e5c3e0c --- /dev/null +++ b/examples/olora_finetuning/README.md @@ -0,0 +1,84 @@ +# OLoRA: Orthonormal Low Rank Adaptation of Large Language Models + +## Introduction +[OLoRA](https://arxiv.org/abs/2406.01775) is a novel approach that leverages orthonormal low rank adaptation through QR decomposition. Unlike the default LoRA implementation, OLoRA decomposes original weights into their $\mathbf{Q}$ and $\mathbf{R}$ parts, and then uses the first `rank` rows of $\mathbf{R}$ and the first `rank` columns of $\mathbf{Q}$ to initialize $\mathbf{A}$ and $\mathbf{B}$, respectively. This results in significantly faster convergence, more stable training, and superior performance. 
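As a minimal illustration of the initialization described above (plain PyTorch only, not the PEFT internals; the shapes, `r`, and `scale` values are arbitrary assumptions), the QR-based setup and the weight mutation can be sketched as:
```python
# Illustrative sketch of OLoRA-style initialization for a single linear weight W
# of shape (out_features, in_features). Not PEFT's internal code.
import torch

torch.manual_seed(0)
out_features, in_features, r, scale = 64, 32, 8, 1.0  # assumed illustrative values

W = torch.randn(out_features, in_features)

# QR-decompose the base weight: W = Q @ R, with Q having orthonormal columns.
Q, R = torch.linalg.qr(W)

# LoRA factors: A takes the first r rows of R, B takes the first r columns of Q.
A = R[:r].clone()      # shape (r, in_features)
B = Q[:, :r].clone()   # shape (out_features, r)

# Mutate the base weight so that W_new + scale * B @ A equals W at initialization.
W_new = W - scale * (B @ A)

# The adapted layer therefore starts out equivalent to the original layer.
assert torch.allclose(W_new + scale * (B @ A), W, atol=1e-5)
```
Because the mutated base weight plus the scaled product of the new factors reproduces the original weight, the adapted layer is a no-op before any training, which is what the initialization tests later in this patch rely on.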
+
+## Quick start
+```python
+import torch
+from peft import LoraConfig, get_peft_model
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import SFTTrainer
+from datasets import load_dataset
+
+model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+dataset = load_dataset("imdb", split="train[:1%]")
+lora_config = LoraConfig(
+    init_lora_weights="olora"
+)
+peft_model = get_peft_model(model, lora_config)
+trainer = SFTTrainer(
+    model=peft_model,
+    train_dataset=dataset,
+    dataset_text_field="text",
+    max_seq_length=512,
+    tokenizer=tokenizer,
+)
+trainer.train()
+peft_model.save_pretrained("olora-opt-350m")
+```
+
+There is no additional change needed to your standard LoRA procedure, except for specifying the `init_lora_weights="olora"` option in your LoRA configuration.
+
+Additionally, you can refer to the OLoRA finetuning script. Run it with:
+```bash
+python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m
+```
+OLoRA also supports quantization. To use 4-bit quantization, try:
+```bash
+python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m --quantize
+```
+
+
+## Use the model
+You can load and use the model like any other 🤗 PEFT model:
+```python
+from peft import PeftModel
+model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+olora_model = PeftModel.from_pretrained(model, "olora-opt-350m")
+```
+
+## OLoRA and LoRA
+OLoRA differs from LoRA in that it mutates the original weights. To use multiple adapters simultaneously, you can leverage the `path_initial_model_for_weight_conversion` option. Below is a simple template illustrating how to convert OLoRA to conventional LoRA:
+```python
+base_model = AutoModel.from_pretrained("facebook/opt-350m")
+olora_config = LoraConfig(
+    ...
+    init_lora_weights="olora"  # Initialize the model with OLoRA
+)
+olora_model = get_peft_model(base_model, olora_config)
+init_path = <path-to-untrained-olora-model>
+olora_model.save_pretrained(init_path)  # Save the model *before* performing any training
+
+# Train the model
+train(olora_model)  # Your training loop
+
+# Save the model after training
+olora_model.save_pretrained(output_dir, path_initial_model_for_weight_conversion=init_path)
+```
+After completing training, you can save and convert your OLoRA model to a conventional LoRA model by setting `path_initial_model_for_weight_conversion` to `init_path`, which is the path of your untrained OLoRA model. This conversion enables you to use multiple adapters with your LoRA model.
+
+## Citation
+```
+@misc{büyükakyüz2024olora,
+      title={OLoRA: Orthonormal Low-Rank Adaptation of Large Language Models},
+      author={Kerim Büyükakyüz},
+      year={2024},
+      eprint={2406.01775},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+```
\ No newline at end of file
diff --git a/examples/olora_finetuning/olora_finetuning.py b/examples/olora_finetuning/olora_finetuning.py
new file mode 100644
index 0000000000..13e4e4f666
--- /dev/null
+++ b/examples/olora_finetuning/olora_finetuning.py
@@ -0,0 +1,184 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +import torch +import transformers +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +from peft import ( + LoraConfig, + get_peft_model, +) + + +def train( + base_model: str = "path/to/model", + data_path: str = "yahma/alpaca-cleaned", + output_dir: str = "olora", + batch_size: int = 16, + num_epochs: int = 1, + learning_rate: float = 3e-4, + cutoff_len: int = 256, + val_set_size: int = 16, + quantize: bool = False, + eval_step: int = 100, + save_step: int = 100, + device_map: str = "auto", + lora_r: int = 32, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_target_modules: List[str] = None, + init_lora_weights="olora", +): + model = AutoModelForCausalLM.from_pretrained( + base_model, + device_map=device_map, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + if quantize + else None, + torch_dtype=torch.float16, + ) + + tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) + + def tokenize(prompt, add_eos_token=True): + result = tokenizer( + prompt, + truncation=True, + max_length=cutoff_len, + padding=False, + return_tensors=None, + ) + if ( + result["input_ids"][-1] != tokenizer.eos_token_id + and len(result["input_ids"]) < cutoff_len + and add_eos_token + ): + result["input_ids"].append(tokenizer.eos_token_id) + result["attention_mask"].append(1) + + result["labels"] = result["input_ids"].copy() + + return result + + def generate_and_tokenize_prompt(example): + full_prompt = generate_prompt(example) + tokenized_full_prompt = tokenize(full_prompt) + return tokenized_full_prompt + + config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=lora_target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM", + init_lora_weights=init_lora_weights, + ) + model = get_peft_model(model, config) + + data = load_dataset(data_path) + + train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42) + train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt) + val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt) + + trainer = transformers.Trainer( + model=model, + train_dataset=train_data, + eval_dataset=val_data, + args=transformers.TrainingArguments( + per_device_train_batch_size=batch_size, + warmup_steps=100, + num_train_epochs=num_epochs, + learning_rate=learning_rate, + fp16=True, + logging_steps=100, + optim="adamw_torch", + evaluation_strategy="steps", + save_strategy="steps", + eval_steps=eval_step, + save_steps=save_step, + output_dir=output_dir, + save_total_limit=3, + load_best_model_at_end=True, + ), + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True + ), + ) + trainer.train() + model.save_pretrained(output_dir) + + +def generate_prompt(example): + return f"""Below is an instruction that describes a task. 
Write a response that appropriately completes the request. + ### Instruction: + {example["instruction"]} + ### Response: + {example["output"]}""" + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--base_model", type=str, default="path/to/model") + parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned") + parser.add_argument("--output_dir", type=str, default="olora") + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--num_epochs", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=3e-4) + parser.add_argument("--cutoff_len", type=int, default=256) + parser.add_argument("--val_set_size", type=int, default=16) + parser.add_argument("--quantize", action="store_true") + parser.add_argument("--eval_step", type=int, default=100) + parser.add_argument("--save_step", type=int, default=100) + parser.add_argument("--device_map", type=str, default="auto") + parser.add_argument("--lora_r", type=int, default=32) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--lora_dropout", type=float, default=0.05) + parser.add_argument("--lora_target_modules", type=str, default=None) + parser.add_argument("--init_lora_weights", type=str, default="olora") + + args = parser.parse_args() + + train( + base_model=args.base_model, + data_path=args.data_path, + output_dir=args.output_dir, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + learning_rate=args.learning_rate, + cutoff_len=args.cutoff_len, + val_set_size=args.val_set_size, + quantize=args.quantize, + eval_step=args.eval_step, + save_step=args.save_step, + device_map=args.device_map, + lora_r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + lora_target_modules=args.lora_target_modules, + init_lora_weights=args.init_lora_weights, + ) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index becad7280d..c92a91be79 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -193,6 +193,7 @@ def save_pretrained( save_embedding_layers: Union[str, bool] = "auto", is_main_process: bool = True, convert_pissa_to_lora: Optional[str] = None, + path_initial_model_for_weight_conversion: Optional[str] = None, **kwargs: Any, ) -> None: r""" @@ -215,13 +216,15 @@ def save_pretrained( is_main_process (`bool`, *optional*): Whether the process calling this is the main process or not. Will default to `True`. Will not save the checkpoint if not on the main process, which is important for multi device setups (e.g. DDP). - convert_pissa_to_lora (`str`): - The path to the initialized PiSSA adapter, which is obtained after initializing the model with PiSSA - and before performing any training. When `convert_pissa_to_lora` is not None, the difference in PISSA - before and after fine-tuning is calculated. This difference can be represented as the parameters of a - of a standard LoRA adapter. Using this converted adapter does not require changes to the base model, - thus conveniently allowing the use of multiple PISSA and LoRA adapters, and the activation or - deactivation of any adapters. + convert_pissa_to_lora (`str, *optional*`): + Deprecated. Use `path_initial_model_for_weight_conversion` instead. + path_initial_model_for_weight_conversion (`str, *optional*`): + The path to the initialized adapter, which is obtained after initializing the model with PiSSA or OLoRA + and before performing any training. 
When `path_initial_model_for_weight_conversion` is not None, the + difference in adapter before and after fine-tuning is calculated. This difference can be represented as + the parameters of a standard LoRA adapter. Using this converted adapter does not require changes to the + base model, thus conveniently allowing the use of multiple PiSSA or OLoRA adapters with LoRA adapters, + and the activation or deactivation of any adapters. kwargs (additional keyword arguments, *optional*): Additional keyword arguments passed along to the `push_to_hub` method. """ @@ -239,20 +242,36 @@ def save_pretrained( f"You passed an invalid `selected_adapters` arguments, current supported adapter names are" f" {list(self.peft_config.keys())} - got {selected_adapters}." ) + # TODO: remove deprecated parameter in PEFT v0.14.0 + if convert_pissa_to_lora is not None: + warnings.warn( + "`convert_pissa_to_lora` is deprecated and will be removed in a future version. " + "Use `path_initial_model_for_weight_conversion` instead." + ) + path_initial_model_for_weight_conversion = convert_pissa_to_lora - def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kwargs): - if not str(peft_config.init_lora_weights).startswith("pissa"): - warnings.warn("`convert_pissa_to_lora` only works for converting a PiSSA adapter to a LoRA adapter") - initial_adapter = os.path.basename(convert_pissa_to_lora) + def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs): + if not any( + str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in ["pissa", "olora", "true"] + ): + warnings.warn( + "`path_initial_model_for_weight_conversion` only works for converting a PiSSA or OLoRA adapter to a LoRA adapter" + ) + initial_adapter = os.path.basename(path_initial_model_for_weight_conversion) self.load_adapter( - os.path.dirname(convert_pissa_to_lora), subfolder=initial_adapter, adapter_name=initial_adapter + os.path.dirname(path_initial_model_for_weight_conversion), + subfolder=initial_adapter, + adapter_name=initial_adapter, ) - if str(self.peft_config[initial_adapter].init_lora_weights).startswith("pissa"): + if any( + str(self.peft_config[initial_adapter].init_lora_weights).lower().startswith(prefix) + for prefix in ["pissa", "olora"] + ): raise ValueError( - "The `init_lora_weights` parameter of the initial PiSSA adapter should be set to `True`. " - "Otherwise, `self.load_adapter` will subtract the principal singular value and vector again based on the residual model." + "The `init_lora_weights` parameter of the initial adapter should be set to `True`. " + "Otherwise, `self.load_adapter` will subtract the decomposed values again based on the residual model." ) - output_state_dict = self.base_model.subtract_pissa_init(output_state_dict, initial_adapter, kwargs) + output_state_dict = self.base_model.subtract_mutated_init(output_state_dict, initial_adapter, kwargs) self.delete_adapter(adapter_name) return output_state_dict @@ -294,9 +313,11 @@ def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kw # not supported in safetensors. 
for shared_tensor_name in names[1:]: output_state_dict[shared_tensor_name] = output_state_dict[shared_tensor_name].clone() - if convert_pissa_to_lora is not None: - output_state_dict = save_pissa_as_lora( - peft_config, convert_pissa_to_lora, output_state_dict, kwargs + if path_initial_model_for_weight_conversion is not None: + peft_config.init_lora_weights = True + peft_config.save_pretrained(path_initial_model_for_weight_conversion) + output_state_dict = save_mutated_as_lora( + peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs ) safe_save_file( output_state_dict, @@ -304,9 +325,11 @@ def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kw metadata={"format": "pt"}, ) elif is_main_process: - if convert_pissa_to_lora is not None: - output_state_dict = save_pissa_as_lora( - peft_config, convert_pissa_to_lora, output_state_dict, kwargs + if path_initial_model_for_weight_conversion is not None: + peft_config.init_lora_weights = True + peft_config.save_pretrained(path_initial_model_for_weight_conversion) + output_state_dict = save_mutated_as_lora( + peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs ) torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) @@ -335,7 +358,7 @@ def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kw auto_mapping_dict = None if is_main_process: - if convert_pissa_to_lora is not None: + if path_initial_model_for_weight_conversion is not None: peft_config.init_lora_weights = True peft_config.r *= 2 peft_config.lora_alpha *= 2 diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 3317c5b753..6e2b7b7452 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -73,19 +73,19 @@ class LoraConfig(PeftConfig): Otherwise, it will use the original default value of `lora_alpha/r`. modules_to_save (`List[str]`): List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. - init_lora_weights (`bool` | `Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"]`): + init_lora_weights (`bool` | `Literal["gaussian", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]`): How to initialize the weights of the adapter layers. Passing True (default) results in the default initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to - completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing - 'pissa' results in the initialization of PiSSA, which converge more rapidly than LoRA and ultimately - achieve superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to - further enhancements. Passing 'pissa_niter_[number of iters]' initiates Fast-SVD-based PiSSA - initialization, where [number of iters] indicates the number of subspace iterations to perform FSVD, and - must be a nonnegative integer. When the [number of iters] is set to 16, it can complete the initialization - of a 7b model within seconds, and the training effect is approximately equivalent to using SVD. For more - information, see Principal Singular values and Singular vectors - Adaptation. + completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Pass + `'olora'` to use OLoRA initialization. 
Passing 'pissa' results in the initialization of PiSSA, which + converge more rapidly than LoRA and ultimately achieve superior performance. Moreover, PiSSA reduces the + quantization error compared to QLoRA, leading to further enhancements. Passing 'pissa_niter_[number of + iters]' initiates Fast-SVD-based PiSSA initialization, where [number of iters] indicates the number of + subspace iterations to perform FSVD, and must be a nonnegative integer. When the [number of iters] is set + to 16, it can complete the initialization of a 7b model within seconds, and the training effect is + approximately equivalent to using SVD. For more information, see Principal Singular values and Singular vectors Adaptation. layers_to_transform (`Union[List[int], int]`): The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices that are specified in this list. If a single integer is passed, it will apply the transformations on the @@ -163,7 +163,7 @@ class LoraConfig(PeftConfig): "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." }, ) - init_lora_weights: bool | Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"] = field( + init_lora_weights: bool | Literal["gaussian", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"] = field( default=True, metadata={ "help": ( @@ -171,6 +171,7 @@ class LoraConfig(PeftConfig): "initialization from the reference implementation from Microsoft. Passing 'gaussian' results " "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization " "to False leads to completely random initialization and is discouraged." + "Passing 'olora' results in OLoRA initialization." "Passing 'pissa' results in PiSSA initialization." "Passing 'pissa_niter_[number of iters]' initiates Fast-SVD-based PiSSA initialization, " "where [number of iters] indicates the number of subspace iterations to perform fsvd, and must be a nonnegative integer." 
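To make the `init_lora_weights` options listed in the help text above concrete, here is a small sketch (the `target_modules` names are placeholder assumptions, not tied to any particular model):
```python
# Hedged sketch: the documented init_lora_weights choices side by side.
from peft import LoraConfig

modules = ["q_proj", "v_proj"]  # assumed placeholder target modules

default_cfg = LoraConfig(init_lora_weights=True, target_modules=modules)        # reference (Microsoft) init
gaussian_cfg = LoraConfig(init_lora_weights="gaussian", target_modules=modules)  # Gaussian, scaled by rank
olora_cfg = LoraConfig(init_lora_weights="olora", target_modules=modules)        # QR-based OLoRA init
pissa_cfg = LoraConfig(init_lora_weights="pissa_niter_16", target_modules=modules)  # fast-SVD PiSSA init
loftq_cfg = LoraConfig(init_lora_weights="loftq", target_modules=modules)        # LoftQ init
```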
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 2dc8b2e68a..c72ca65b06 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -24,6 +24,7 @@ from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.integrations import dequantize_module_weight from peft.utils.other import transpose from .config import LoraConfig @@ -115,11 +116,12 @@ def update_layer( if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): self.pissa_init(adapter_name, init_lora_weights) + elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora": + self.olora_init(adapter_name) elif init_lora_weights == "loftq": self.loftq_init(adapter_name) elif init_lora_weights: self.reset_lora_parameters(adapter_name, init_lora_weights) - # call this before dora_init self._move_adapter_to_device_of_base_layer(adapter_name) @@ -150,6 +152,28 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights): nn.init.zeros_(self.lora_embedding_A[adapter_name]) nn.init.normal_(self.lora_embedding_B[adapter_name]) + def olora_init(self, adapter_name): + dtype = self.base_layer.weight.dtype + if dtype in [torch.int8, torch.uint8]: + weight_tensor = dequantize_module_weight(self.base_layer) + elif dtype in [torch.float32, torch.float16, torch.bfloat16]: + weight_tensor = self.base_layer.weight + else: + raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.") + scale_factor = self.scaling[adapter_name] + r = self.r[adapter_name] + weight_tensor = weight_tensor.to(torch.float32) + Q, R = torch.linalg.qr(weight_tensor.data) + + Qr, Rr = Q[:, :r], R[:r] + + self.lora_A[adapter_name].weight.data = Rr.contiguous() + self.lora_B[adapter_name].weight.data = Qr.contiguous() + + weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight + weight_tensor = weight_tensor.to(dtype) + self.get_base_layer().weight.data = weight_tensor + def pissa_init(self, adapter_name, init_lora_weights): weight = self.get_base_layer().weight dtype = weight.dtype @@ -159,7 +183,6 @@ def pissa_init(self, adapter_name, init_lora_weights): "Subsequently, re-quantize the residual model to help minimize quantization errors." ) weight = weight.to(torch.float32) - if init_lora_weights == "pissa": # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 1087415607..6b54bcb1ea 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -846,25 +846,24 @@ def unload(self) -> torch.nn.Module: """ return self._unload_and_optionally_merge(merge=False) - def subtract_pissa_init( - self, output_state_dict: dict[str, torch.Tensor], adapter_name: str = "pissa_init", kwargs=None - ): + def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None): """ - This function can calculate the updates of the PiSSA by comparing the parameters of the PiSSA adapter in - `output_state_dict` with the initial values of PiSSA in `adapter_name`, thus converting PiSSA to LoRA. 
+        This function can calculate the updates of the [PiSSA | OLoRA] by comparing the parameters of the [PiSSA |
+        OLoRA] adapter in `output_state_dict` with the initial values of [PiSSA | OLoRA] in `adapter_name`, thus
+        converting [PiSSA | OLoRA] to LoRA.
         """
         for name, param in self.model.named_parameters():
             if (
                 param.data.dtype != torch.float32
                 and param.data.dtype != torch.float16
                 and param.data.dtype != torch.bfloat16
-            ):
+            ) and adapter_name.startswith("pissa"):
                 warnings.warn(
                     r"Note that Quant(W_res) + AB != Quant(W) + \Delta(AB); "
                     "the converted LoRA, when combined with W or Quant(W), may introduce a certain gap in the fine-tuned model. "
                     "Therefore, we recommend directly using the Quant(W_res) in conjunction with the PiSSA adapter. "
                 )
-        pissa_init_state_dict = get_peft_model_state_dict(
+        mutated_init_state_dict = get_peft_model_state_dict(
             self,
             state_dict=kwargs.get("state_dict", None),
             adapter_name=adapter_name,
@@ -876,11 +875,11 @@ def subtract_pissa_init(
             ## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'.
             if "lora_A" in name:
                 tensors_lora[name] = torch.cat(
-                    [output_state_dict[name], pissa_init_state_dict[".".join(name.split(".")[1:])]], dim=0
+                    [output_state_dict[name], mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=0
                 )
             elif "lora_B" in name:
                 tensors_lora[name] = torch.cat(
-                    [output_state_dict[name], -pissa_init_state_dict[".".join(name.split(".")[1:])]], dim=1
+                    [output_state_dict[name], -mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=1
                 )
         return tensors_lora
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index ccf7378dbb..df544b606c 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -1656,7 +1656,7 @@ def forward(self, x):
         assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
         # save the model with conversion
-        peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_pissa_to_lora=tmp_path / "init-model")
+        peft_model.save_pretrained(tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model")
         model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
         output_converted = model_converted(data)[0]
@@ -1671,6 +1671,117 @@ def forward(self, x):
         assert not torch.allclose(output_finetuned_pissa, output_converted, atol=tol, rtol=tol)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+class TestOLoRA:
+    r"""
+    Tests for OLoRA to ensure that it reduces the quantization error compared to normal LoRA quantization.
+    """
+
+    # The error factor indicates by how much the quantization error should be decreased when using OLoRA compared to
+    # quantization without OLoRA. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
+    # conservative value to prevent flakiness, in practice most gains are > 1.5
+    error_factor = 1.2
+
+    def quantize_model(self, model, num_bits=4, device="cuda"):
+        # Quantize the `weight.data` of the linear layer in the model to `num_bits` and store it with full precision.
+ quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and "lm_head" not in name: + quantized_weight, max_abs, shape = quantizer.quantize_block(module.weight.data.to(device)) + module.weight.data = quantizer.dequantize_block(quantized_weight, max_abs, shape) + return model + + def nuclear_norm(self, base_model, quantized_model): + # Calculate the nuclear norm (sum of singular values) of the error matrices between the `quantized_model` and the `base_model`. + error_list = [] + for name, module in base_model.named_modules(): + if isinstance(module, torch.nn.Linear) and "lm_head" not in name: + quant_module = quantized_model.get_submodule(name) + error_list.append(torch.linalg.svdvals(module.weight.data - quant_module.weight.data).sum()) + return torch.Tensor(error_list).sum() + + def get_errors( + self, + tmp_path, + bits=4, + device="cuda", + model_id="hf-internal-testing/tiny-random-BloomForCausalLM", + ): + # Comparing the quantized LoRA model to the base model, vs the OLoRA quantized model to the base model. + # We expect the OLoRA quantized model to have less error than the normal LoRA quantized model. + + cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM + base_model = cls.from_pretrained(model_id).eval().to(device) + task_type = TaskType.SEQ_2_SEQ_LM if base_model.config.is_encoder_decoder else TaskType.CAUSAL_LM + + # logits from the normal quantized LoRA model + target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"] + lora_config = LoraConfig(task_type=task_type, target_modules=target_modules) + + qlora_model = self.quantize_model(cls.from_pretrained(model_id).eval().to(device), bits, device) + qlora_model = get_peft_model( + qlora_model, + lora_config, + ) + qlora_model = qlora_model.merge_and_unload() + qlora_error = self.nuclear_norm(base_model, qlora_model) + del qlora_model + gc.collect() + torch.cuda.empty_cache() + + # logits from quantized LoRA model using OLoRA + lora_config = LoraConfig( + task_type=task_type, + init_lora_weights="olora", + target_modules=target_modules, + ) + olora_model = cls.from_pretrained(model_id).eval().to(device) + olora_model = get_peft_model(olora_model, lora_config) + + # save LoRA weights, they should be initialized such that they minimize the quantization error + olora_model.base_model.peft_config["default"].init_lora_weights = True + olora_model.save_pretrained(tmp_path / "olora_model") + + olora_model = olora_model.unload() + olora_model.save_pretrained(tmp_path / "residual_model") + + del olora_model + gc.collect() + torch.cuda.empty_cache() + + # now load quantized model and apply OLoRA-initialized weights on top + qolora_model = self.quantize_model( + cls.from_pretrained(tmp_path / "residual_model").eval().to(device), bits, device + ) + qolora_model = PeftModel.from_pretrained(qolora_model, tmp_path / "olora_model") + qolora_model = qolora_model.merge_and_unload() + qolora_error = self.nuclear_norm(base_model, qolora_model) + del qolora_model + gc.collect() + torch.cuda.empty_cache() + + assert qlora_error > 0.0 + assert qolora_error > 0.0 + + # next, check that OLoRA quantization errors are smaller than LoRA errors by a certain margin + assert qolora_error < (qlora_error / self.error_factor) + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + def test_bloomz_olora_4bit(self, device, tmp_path): + # In this test, we compare the logits of 
the base model, the quantized LoRA model, and the quantized model + # using OLoRA. When quantizing, we expect a certain level of error. However, we expect the OLoRA quantized + # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the + # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training. + # We still apply LoRA for the test for consistency. + + self.get_errors(bits=4, device=device, tmp_path=tmp_path) + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + def test_bloomz_olora_8bit(self, device, tmp_path): + # Same test as test_bloomz_olora_4bit but with 8 bits. + self.get_errors(bits=8, device=device, tmp_path=tmp_path) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU") class TestLoftQ: r""" diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 589e046ad6..5958bb3f62 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -266,6 +266,15 @@ def test_lora_pissa_linear_init_default(self, data): peft_model = get_peft_model(deepcopy(model), config) assert torch.allclose(output, peft_model(data)[0], atol=1e-06) + def test_lora_olora_linear_init_default(self, data): + model = self.get_model() + output = model(data)[0] + + # Both OLoRA and olora should work + config = LoraConfig(init_lora_weights="OLoRA", target_modules=["linear"]) + peft_model = get_peft_model(deepcopy(model), config) + assert torch.allclose(output, peft_model(data)[0], atol=1e-06) + def test_lora_pissa_conversion_same_output_after_loading(self, data, tmp_path): model = self.get_model() output_base = model(data)[0] @@ -300,11 +309,98 @@ def test_lora_pissa_conversion_same_output_after_loading(self, data, tmp_path): ) # save the model with conversion + peft_model.save_pretrained( + tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model" + ) + model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted") + output_converted = model_converted(data)[0] + + assert torch.allclose(output_pissa, output_converted, atol=tol, rtol=tol) + # rank should be double of what it was initially + assert model_converted.peft_config["default"].r == 16 + assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16 + # base model weights should be the same as the initial model + assert torch.allclose( + model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol + ) + + # TODO: remove test for deprecated arg in PEFT v0.14.0 + def test_lora_pissa_conversion_same_output_after_loading_with_deprecated_arg(self, data, tmp_path): + model = self.get_model() + config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8) + peft_model = get_peft_model(deepcopy(model), config) + peft_model.peft_config["default"].init_lora_weights = True + peft_model.save_pretrained(tmp_path / "init-model") + peft_model.peft_config["default"].init_lora_weights = "pissa" + + tol = 1e-06 + peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0 + output_pissa = peft_model(data)[0] + peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_pissa_to_lora=tmp_path / "init-model") model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted") output_converted = model_converted(data)[0] assert torch.allclose(output_pissa, output_converted, atol=tol, rtol=tol) + assert 
model_converted.peft_config["default"].r == 16 + assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16 + assert torch.allclose( + model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol + ) + + # TODO: remove test for deprecated warning in PEFT v0.14.0 + def test_lora_pissa_conversion_deprecated_warning(self, data, tmp_path): + model = self.get_model() + config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8) + peft_model = get_peft_model(deepcopy(model), config) + peft_model.peft_config["default"].init_lora_weights = True + peft_model.save_pretrained(tmp_path / "init-model") + warning_message = "`convert_pissa_to_lora` is deprecated and will be removed in a future version. Use `path_initial_model_for_weight_conversion` instead." + # Test the warning + with pytest.warns(UserWarning, match=warning_message): + peft_model.save_pretrained( + tmp_path / "pissa-model-converted", convert_pissa_to_lora=tmp_path / "init-model" + ) + + def test_olora_conversion_same_output_after_loading(self, data, tmp_path): + model = self.get_model() + output_base = model(data)[0] + + config = LoraConfig(init_lora_weights="olora", target_modules=["linear"], r=8) + peft_model = get_peft_model(deepcopy(model), config) + # save the initial model + peft_model.save_pretrained(tmp_path / "init-model") + + # modify the weights, or else the adapter performs an identity transformation + peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0 + output_olora = peft_model(data)[0] + + # sanity check + tol = 1e-06 + assert not torch.allclose(output_base, output_olora, atol=tol, rtol=tol) + + # save the model normally + peft_model.save_pretrained(tmp_path / "olora-model") + model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model") + output_loaded = model_loaded(data)[0] + + assert torch.allclose(output_olora, output_loaded, atol=tol, rtol=tol) + # sanity check: ranks should still be 8 as initially + assert model_loaded.peft_config["default"].r == 8 + assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8 + # sanity check: the base model weights were indeed changed + assert not torch.allclose( + model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol + ) + + # save the model with conversion + peft_model.save_pretrained( + tmp_path / "olora-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model" + ) + model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model-converted") + output_converted = model_converted(data)[0] + + assert torch.allclose(output_olora, output_converted, atol=tol, rtol=tol) # rank should be double of what it was initially assert model_converted.peft_config["default"].r == 16 assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
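The doubled rank asserted above (`r == 16` after converting an `r == 8` adapter) follows from the concatenation identity noted in `subtract_mutated_init` (`\Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T`). Here is a small standalone check of that identity, with arbitrary illustrative shapes (not part of the test suite):
```python
# Sketch: the update relative to the mutated base weights can be expressed as a single
# LoRA adapter of rank 2r, which is why conversion doubles r and lora_alpha.
import torch

torch.manual_seed(0)
out_f, in_f, r = 20, 10, 8  # assumed illustrative dimensions

A, A0 = torch.randn(r, in_f), torch.randn(r, in_f)    # lora_A weights (trained / initial)
B, B0 = torch.randn(out_f, r), torch.randn(out_f, r)  # lora_B weights (trained / initial)

delta = B @ A - B0 @ A0                               # learned update on top of the mutated base

A_merged = torch.cat([A, A0], dim=0)                  # rank dimension grows from r to 2r
B_merged = torch.cat([B, -B0], dim=1)                 # initial factor enters with a minus sign

assert torch.allclose(delta, B_merged @ A_merged, atol=1e-5)
```
Stacking the trained and initial factors along the rank dimension is what lets a mutated-weight (OLoRA or PiSSA) adapter be re-expressed as a plain LoRA adapter on top of the unmodified base weights, so it can be mixed freely with other adapters.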