ENH Ephemeral GPU offload support for DoRA (#1857)

Adds the concept of ephemeral GPU offloading, i.e. data used by
compute-intensive operations is copied onto the GPU just before the
operation is performed, after which the result is moved back to CPU
memory.

This PR adds support in the DoRA initialization code, but the approach
can be applied in a number of places: when the time to perform an
operation on CPU-resident data heavily dominates the cost of moving that
data to the GPU, ephemeral transfers add a fairly small VRAM overhead
(depending on the size of the model/adapter) while speeding up certain
operations by orders of magnitude.

For example, a Llama3-8B DoRA adapter with r=64 adds an overhead of
2 x (64 x 4096 x 2 + 4096 x 4096) bytes (assuming fp16), i.e. about
33 MB. A Llama3-70B adapter with r=32 adds 2 x (32 x 8192 x 2 +
8192 x 8192) bytes = 130 MB.
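
A back-of-the-envelope sketch of the figures above (an illustrative
helper, not part of this PR):

```python
# Per-layer cost of ephemerally moving lora_A (r x d), lora_B (d x r) and the
# base weight (d x d) to the GPU, in the given dtype (2 bytes/param for fp16).
def ephemeral_overhead_bytes(d: int, r: int, bytes_per_param: int = 2) -> int:
    return bytes_per_param * (r * d * 2 + d * d)


print(f"{ephemeral_overhead_bytes(4096, 64) / 2**20:.0f} MiB")  # Llama3-8B,  r=64 -> ~33 MiB
print(f"{ephemeral_overhead_bytes(8192, 32) / 2**20:.0f} MiB")  # Llama3-70B, r=32 -> ~129 MiB
```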

By making use of ephemeral GPU offloading, more efficient juggling of
data between GPU and CPU becomes possible: instead of always loading as
much as we can onto the GPU and then enduring the CPU slowness for
whatever happens not to fit, we intentionally leave a (modest) chunk of
VRAM free for optimizations like these, and the end result is a much
(MUCH) faster experience.

kallewoof committed Jul 2, 2024
1 parent 1e5227f commit 1e2258d
Showing 13 changed files with 324 additions and 9 deletions.
16 changes: 16 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -122,6 +122,22 @@ from peft import LoraConfig
config = LoraConfig(use_dora=True, ...)
```

If parts of the model or the DoRA adapter are offloaded to the CPU, you can get a significant speedup, at the cost of some temporary (ephemeral) VRAM overhead, by setting `ephemeral_gpu_offload=True` in `config.runtime_config`.

```py
from peft import LoraConfig, LoraRuntimeConfig

config = LoraConfig(use_dora=True, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=True), ...)
```

A `PeftModel` with a DoRA adapter can also be loaded with the `ephemeral_gpu_offload=True` flag using the `from_pretrained` method as well as the `load_adapter` method.

```py
from peft import PeftModel

model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offload=True)
```
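
The `load_adapter` route takes the same flag; a minimal sketch (the second adapter id and adapter name are hypothetical):

```py
model.load_adapter(peft_model_id_2, adapter_name="dora_2", ephemeral_gpu_offload=True)
```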

#### Caveats

- DoRA only supports linear and Conv2d layers at the moment.
103 changes: 103 additions & 0 deletions examples/ephemeral_gpu_offloading/load_with_dora.py
@@ -0,0 +1,103 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example script demonstrating the time difference when loading a model with a DoRA adapter using ephemeral GPU offloading vs. doing it purely on the CPU.
Example outputs:
$ python load_with_dora.py
--- Loading model ---
Loading checkpoint shards: 100%|██████████████████████████████| 4/4 [00:04<00:00, 1.03s/it]
--- Loading PeftModel ---
--- Done ---
Model loading time: 4.83s
PeftModel loading time: 28.14s
Use ephemeral GPU offloading: False
(Note: if this was the first time you ran the script, or if your cache was cleared, the times shown above are invalid, due to the time taken to download the model and DoRA files. Just re-run the script in this case.)
$ python load_with_dora.py --ephemeral_gpu_offload
--- Loading model ---
Loading checkpoint shards: 100%|██████████████████████████████| 4/4 [00:03<00:00, 1.11it/s]
--- Loading PeftModel ---
--- Done ---
Model loading time: 4.28s
PeftModel loading time: 16.59s
Use ephemeral GPU offloading: True
(Note: if this was the first time you ran the script, or if your cache was cleared, the times shown above are invalid, due to the time taken to download the model and DoRA files. Just re-run the script in this case.)
"""

import argparse
import time

from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM

from peft import PeftModel


def main():
    parser = argparse.ArgumentParser(description="Load a model with DoRA using ephemeral GPU offloading")
    parser.add_argument("--model", type=str, default="NousResearch/Hermes-2-Pro-Mistral-7B", help="Model to load")
    parser.add_argument(
        "--dora",
        type=str,
        default="peft-internal-testing/DoRA-Hermes-2-Pro-Mistral-7B",
        help="DoRA to use",
    )
    parser.add_argument("--ephemeral_gpu_offload", action="store_true", help="Use ephemeral GPU offloading")
    parser.add_argument(
        "--merge_model_path", type=str, help="Merge the model with the DoRA model and save to the given path"
    )
    args = parser.parse_args()

    peft_model_kwargs = {
        "ephemeral_gpu_offload": args.ephemeral_gpu_offload,
        "max_memory": {"cpu": "256GiB"},
        "device_map": {"": "cpu"},
    }

    # Predownload
    try:
        snapshot_download(repo_id=args.model)
    except Exception as e:
        print(f"Failed to download model: {e}")
        # We continue anyway as this might be e.g. a local directory or something
    try:
        snapshot_download(repo_id=args.dora)
    except Exception as e:
        print(f"Failed to download DoRA: {e}")
        # We continue anyway as this might be e.g. a local directory or something

    start = time.perf_counter()
    print("--- Loading model ---")
    model = AutoModelForCausalLM.from_pretrained(args.model)
    model_time = time.perf_counter() - start
    print("--- Loading PeftModel ---")
    peft_model = PeftModel.from_pretrained(model, args.dora, **peft_model_kwargs)
    print("--- Done ---")
    peft_model_time = time.perf_counter() - start

    print(f"Model loading time: {model_time:.2f}s")
    print(f"PeftModel loading time: {peft_model_time:.2f}s")
    print(f"Use ephemeral GPU offloading: {args.ephemeral_gpu_offload}")

    if args.merge_model_path is not None:
        merged_model = peft_model.merge_and_unload(progressbar=True)
        merged_model.save_pretrained(args.merge_model_path)


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions src/peft/__init__.py
@@ -51,6 +51,7 @@
AdaptionPromptConfig,
AdaptionPromptModel,
LoraConfig,
LoraRuntimeConfig,
LoftQConfig,
LoraModel,
LoHaConfig,
10 changes: 9 additions & 1 deletion src/peft/config.py
@@ -14,6 +14,7 @@
import inspect
import json
import os
import warnings
from dataclasses import asdict, dataclass, field
from typing import Dict, Optional, Union

@@ -63,7 +64,7 @@ def save_pretrained(self, save_directory: str, **kwargs) -> None:
os.makedirs(save_directory, exist_ok=True)
auto_mapping_dict = kwargs.pop("auto_mapping_dict", None)

output_dict = asdict(self)
output_dict = self.to_dict()
# converting set type to list
for key, value in output_dict.items():
if isinstance(value, set):
@@ -162,6 +163,13 @@ def from_json_file(cls, path_json_file: str, **kwargs):
with open(path_json_file) as file:
json_object = json.load(file)

# Sanity check that config does not contain a runtime_config
if "runtime_config" in json_object:
warnings.warn(
"The configuration file contains a `runtime_config` key. This is ignored. Runtime configurations are only valid at runtime."
)
del json_object["runtime_config"]

return json_object

@classmethod
18 changes: 18 additions & 0 deletions src/peft/peft_model.py
@@ -374,6 +374,7 @@ def from_pretrained(
is_trainable: bool = False,
config: Optional[PeftConfig] = None,
autocast_adapter_dtype: bool = True,
ephemeral_gpu_offload: bool = False,
**kwargs: Any,
) -> PeftModel:
r"""
@@ -402,6 +403,12 @@ def from_pretrained(
loaded before calling `from_pretrained`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Only relevant for specific adapter types.
ephemeral_gpu_offload (`bool`, *optional*):
Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`. This is
useful when parts of the model and/or components (such as adapters) are kept in CPU memory until they
are needed. Rather than performing expensive operations on small amounts of data on the CPU, the data is
transferred to the GPU on demand, the operation(s) performed, and the results moved back to CPU memory. This
incurs a slight, momentary VRAM overhead but gives an orders-of-magnitude speedup in certain cases.
torch_device (`str`, *optional*, defaults to None):
The device to load the adapter on. If `None`, the device will be inferred.
kwargs: (`optional`):
@@ -426,6 +433,13 @@
else:
raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")

# Runtime configuration, if supported
if hasattr(config, "runtime_config"):
config.runtime_config.ephemeral_gpu_offload = ephemeral_gpu_offload
else:
if ephemeral_gpu_offload:
warnings.warn("Ephemeral GPU offloading is not supported for this model. Ignoring.")

if hasattr(model, "hf_device_map"):
weight_map = dict(named_module_tensors(model, recurse=True))

@@ -984,6 +998,7 @@ def load_adapter(
is_trainable: bool = False,
torch_device: Optional[str] = None,
autocast_adapter_dtype: bool = True,
ephemeral_gpu_offload: bool = False,
**kwargs: Any,
):
"""
@@ -1008,6 +1023,8 @@
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter
weights using float16 and bfloat16 to float32, as this is typically required for stable training, and
only affect select PEFT tuners.
ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`):
Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`.
kwargs: (`optional`):
Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub.
"""
@@ -1026,6 +1043,7 @@
)
].from_pretrained(
model_id,
ephemeral_gpu_offload=ephemeral_gpu_offload,
**hf_hub_download_kwargs,
)
if peft_config.is_prompt_learning and is_trainable:
2 changes: 1 addition & 1 deletion src/peft/tuners/__init__.py
@@ -18,7 +18,7 @@
# limitations under the License.

from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
from .lora import LoraConfig, LoraModel, LoftQConfig
from .lora import LoraConfig, LoraModel, LoftQConfig, LoraRuntimeConfig
from .loha import LoHaConfig, LoHaModel
from .lokr import LoKrConfig, LoKrModel
from .ia3 import IA3Config, IA3Model
14 changes: 12 additions & 2 deletions src/peft/tuners/lora/__init__.py
@@ -14,13 +14,23 @@

from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available

from .config import LoftQConfig, LoraConfig
from .config import LoftQConfig, LoraConfig, LoraRuntimeConfig
from .gptq import QuantLinear
from .layer import Conv2d, Embedding, Linear, LoraLayer
from .model import LoraModel


__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"]
__all__ = [
"LoraConfig",
"LoraRuntimeConfig",
"LoftQConfig",
"Conv2d",
"Embedding",
"LoraLayer",
"Linear",
"LoraModel",
"QuantLinear",
]


def __getattr__(name):
39 changes: 39 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -23,6 +23,32 @@
from peft.utils import PeftType


@dataclass
class LoraRuntimeConfig:
    """
    This is the sub-configuration class to store the runtime configurations for the model.

    Args:
        ephemeral_gpu_offload (`bool`):
            Whether to use ephemeral GPU offloading for models partially kept in CPU memory.
    """

    ephemeral_gpu_offload: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use ephemeral GPU offloading for models partially kept in CPU memory. Ephemeral GPU "
                "offloading results in the data involved in intense operations being momentarily copied over to the "
                "GPU, and the results copied back to CPU. There is a momentary VRAM overhead, but operations are "
                "generally orders of magnitude faster compared to performing them on the CPU. This is useful when "
                "parts of the model and/or components (such as adapters) are kept in CPU memory until they are "
                "needed. Rather than performing expensive operations on small amounts of data on the CPU, the data "
                "is transferred to the GPU on demand, the operation(s) performed, and the results moved back to CPU "
                "memory. Currently only affects DoRA initialization."
            )
        },
    )


@dataclass
class LoftQConfig:
"""
@@ -122,6 +148,8 @@ class LoraConfig(PeftConfig):
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
all have separate LoRA adapters attached to them.
runtime_config (`LoraRuntimeConfig`):
Runtime configurations (which are not saved or restored).
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -282,6 +310,17 @@ class LoraConfig(PeftConfig):
)
},
)
runtime_config: LoraRuntimeConfig = field(
default_factory=LoraRuntimeConfig, metadata={"help": "Runtime configurations"}
)

def to_dict(self):
"""
Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
"""
rv = super().to_dict()
rv.pop("runtime_config")
return rv

def __post_init__(self):
self.peft_type = PeftType.LORA
6 changes: 4 additions & 2 deletions src/peft/tuners/lora/dora.py
@@ -34,7 +34,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm

def update_layer(self, *, base_layer, lora_A, lora_B, scaling) -> None:
def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
# temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
dtype_is_fp16 = lora_A.dtype == torch.float16
if dtype_is_fp16:
@@ -56,8 +56,10 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling) -> None:

if dtype_is_fp16:
lora_weight = lora_weight.half()
weight_norm = self.get_weight_norm(weight, lora_weight, scaling)
weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling)

if place_on_cpu:
weight_norm = weight_norm.to("cpu")
self.weight = nn.Parameter(weight_norm, requires_grad=True)

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
15 changes: 13 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -37,7 +37,7 @@ class LoraLayer(BaseTunerLayer):
# All names of other parameters that may contain adapter-related parameters
other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")

def __init__(self, base_layer: nn.Module, **kwargs) -> None:
def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
self.base_layer = base_layer
self.r = {}
self.lora_alpha = {}
@@ -54,6 +54,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.use_dora: dict[str, bool] = {}
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
self._caches: dict[str, Any] = {}
self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -251,8 +252,18 @@ def dora_init(self, adapter_name: str) -> None:
dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(self, "fan_in_fan_out", False))
lora_A = self.lora_A[adapter_name].weight
lora_B = self.lora_B[adapter_name].weight
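# With ephemeral GPU offload enabled, make sure lora_A and lora_B end up on a common CUDA device for the
# weight-norm computation below; if either of them started out on the CPU, the resulting magnitude vector is
# moved back to the CPU afterwards (place_on_cpu).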
place_on_cpu = self.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu")
if self.ephemeral_gpu_offload:
if lora_A.device.type == "cuda":
lora_B = lora_B.to(lora_A.device)
else:
if lora_B.device.type != "cuda":
lora_B = lora_B.to("cuda")
lora_A = lora_A.to(lora_B.device)
scaling = self.scaling[adapter_name]
dora_layer.update_layer(base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling)
dora_layer.update_layer(
base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling, place_on_cpu=place_on_cpu
)
self.lora_magnitude_vector[adapter_name] = dora_layer

def _cache_store(self, key: str, value: Any) -> None:
1 change: 1 addition & 0 deletions src/peft/tuners/lora/model.py
@@ -196,6 +196,7 @@ def _create_and_replace(
"init_lora_weights": lora_config.init_lora_weights,
"use_rslora": lora_config.use_rslora,
"use_dora": lora_config.use_dora,
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
}