Skip to content

Commit

Permalink
FEAT Add OLoRA initialization strategy to LoRA (#1828)
Browse files Browse the repository at this point in the history
  • Loading branch information
tokenizer-decode committed Jun 12, 2024
1 parent 8843a76 commit 2f5360a
Show file tree
Hide file tree
Showing 9 changed files with 576 additions and 46 deletions.
9 changes: 9 additions & 0 deletions docs/source/developer_guides/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
```
For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning).

### OLoRA
[OLoRA](https://arxiv.org/abs/2406.01775) utilizes QR decomposition to initialize the LoRA adapters. OLoRA translates the base weights of the model by a factor of their QR decompositions, i.e., it mutates the weights before performing any training on them. This approach significantly improves stability, accelerates convergence speed, and ultimately achieves superior performance.

You just need to pass a single additional option to use OLoRA:
```python
from peft import LoraConfig
config = LoraConfig(init_lora_weights="olora", ...)
```
For more advanced usage, please refer to our [documentation](https://github.com/huggingface/peft/tree/main/examples/olora_finetuning).
### LoftQ

#### Standard approach
Expand Down
84 changes: 84 additions & 0 deletions examples/olora_finetuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# OLoRA: Orthonormal Low Rank Adaptation of Large Language Models

## Introduction
[OLoRA](https://arxiv.org/abs/2406.01775) is a novel approach that leverages orthonormal low rank adaptation through QR decomposition. Unlike the default LoRA implementation, OLoRA decomposes original weights into their $\mathbf{Q}$ and $\mathbf{R}$ parts, and then uses the first `rank` rows of $\mathbf{R}$ and the first `rank` columns of $\mathbf{Q}$ to initialize $\mathbf{A}$ and $\mathbf{B}$, respectively. This results in significantly faster convergence, more stable training, and superior performance.

## Quick start
```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")
lora_config = LoraConfig(
init_lora_weights="olora"
)
peft_model = get_peft_model(model, lora_config)
trainer = SFTTrainer(
model=peft_model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("olora-opt-350m")
```

There is no additional change needed to your standard LoRA procedure, except for specifying `init_lora_weights = "olora"` option in your lora configuration.

Additionally you can refer to olora finetuning script.
Run the script simply by running:
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m
```
OLoRA also supports quantization. To use 4-bit quantization try:
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m --quantize
```


## Use the model
You can load and use the model as any other 🤗 PEFT model
```python
from peft import PeftModel
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
olora_model = PeftModel.from_pretrained(model, "olora-opt-350m")
```

## OLoRA and LoRA
OLoRA differs from LoRA in that it mutates the original weights. To utilize multiple adapters simultaneously, you can leverage the `path_initial_model_for_weight_conversion` option. Below is a simple template illustrating how to convert OLoRA to conventional LoRA:
```python
base_model = AutoModel.from_pretrained("facebook/opt-350m")
olora_config = LoraConfig(
...
init_lora_weights = "olora" # Initialize the model with OLoRA
)
olora_model = get_peft_model(base_model, olora_config)
init_path = <path-to-untrained-olora-model>
olora_model.save_pretrained(init_path) # Save the model *before* performing any training

# Train the model
train(olora_model) # Your training loop

#Save the model after training
olora_model.save_pretrained(output_dir, path_initial_model_for_weight_conversion=init_path)
```
After completing training, you can save and convert your OLoRA model to a conventional LoRA model by setting `path_initial_model_for_weight_conversion` to `init_path`, that is the path of your untrained OLoRA model. This conversion enables you to use multiple adapters with your LoRA model.

## Citation
```
@misc{büyükakyüz2024olora,
title={OLoRA: Orthonormal Low-Rank Adaptation of Large Language Models},
author={Kerim Büyükakyüz},
year={2024},
eprint={2406.01775},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
184 changes: 184 additions & 0 deletions examples/olora_finetuning/olora_finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from peft import (
LoraConfig,
get_peft_model,
)


def train(
base_model: str = "path/to/model",
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "olora",
batch_size: int = 16,
num_epochs: int = 1,
learning_rate: float = 3e-4,
cutoff_len: int = 256,
val_set_size: int = 16,
quantize: bool = False,
eval_step: int = 100,
save_step: int = 100,
device_map: str = "auto",
lora_r: int = 32,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = None,
init_lora_weights="olora",
):
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=device_map,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
if quantize
else None,
torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

def tokenize(prompt, add_eos_token=True):
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)

result["labels"] = result["input_ids"].copy()

return result

def generate_and_tokenize_prompt(example):
full_prompt = generate_prompt(example)
tokenized_full_prompt = tokenize(full_prompt)
return tokenized_full_prompt

config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
init_lora_weights=init_lora_weights,
)
model = get_peft_model(model, config)

data = load_dataset(data_path)

train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=batch_size,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=100,
optim="adamw_torch",
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=eval_step,
save_steps=save_step,
output_dir=output_dir,
save_total_limit=3,
load_best_model_at_end=True,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
trainer.train()
model.save_pretrained(output_dir)


def generate_prompt(example):
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{example["instruction"]}
### Response:
{example["output"]}"""


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--base_model", type=str, default="path/to/model")
parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned")
parser.add_argument("--output_dir", type=str, default="olora")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--cutoff_len", type=int, default=256)
parser.add_argument("--val_set_size", type=int, default=16)
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--eval_step", type=int, default=100)
parser.add_argument("--save_step", type=int, default=100)
parser.add_argument("--device_map", type=str, default="auto")
parser.add_argument("--lora_r", type=int, default=32)
parser.add_argument("--lora_alpha", type=int, default=16)
parser.add_argument("--lora_dropout", type=float, default=0.05)
parser.add_argument("--lora_target_modules", type=str, default=None)
parser.add_argument("--init_lora_weights", type=str, default="olora")

args = parser.parse_args()

train(
base_model=args.base_model,
data_path=args.data_path,
output_dir=args.output_dir,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
cutoff_len=args.cutoff_len,
val_set_size=args.val_set_size,
quantize=args.quantize,
eval_step=args.eval_step,
save_step=args.save_step,
device_map=args.device_map,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
)
Loading

0 comments on commit 2f5360a

Please sign in to comment.