From 1e2258d7f7046c5390d6e50959dc4eb32b765f9e Mon Sep 17 00:00:00 2001
From: kallewoof
Date: Tue, 2 Jul 2024 19:17:45 +0900
Subject: [PATCH] ENH Ephemeral GPU offload support for DoRA (#1857)

Adds the concept of ephemeral GPU offloading, i.e. copying the data involved
in a compute-intensive operation onto the GPU just before the operation is
performed, and moving the result back to CPU memory afterwards.

This PR adds support in the DoRA initialization code, but the approach can be
applied in a number of places: whenever an operation is far more expensive to
run on the CPU than the data is to transfer, ephemeral transfers add a fairly
small VRAM overhead (depending on the size of the model/adapter) while
speeding up certain operations by orders of magnitude.

For example, a Llama3-8B DoRA adapter with r=64 would add an overhead of
2 x (64 x 4096 x 2 + 4096 x 4096) bytes (assuming fp16), i.e. 33 MB or so.
A Llama3-70B adapter with r=32 would have 2 x (32 x 8192 x 2 + 8192 x 8192)
bytes = 130 MB.

By making use of ephemeral GPU offloading, more efficient juggling of data
between GPU and CPU becomes possible: instead of always loading as much as we
can onto the GPU and then enduring the CPU slowness for whatever does not fit,
we intentionally leave a (modest) chunk of VRAM free for optimizations like
these, and the end result is a much (MUCH) faster experience.
---
 docs/source/developer_guides/lora.md |  16 +++
 .../load_with_dora.py                | 103 ++++++++++++++++++
 src/peft/__init__.py                 |   1 +
 src/peft/config.py                   |  10 +-
 src/peft/peft_model.py               |  18 +++
 src/peft/tuners/__init__.py          |   2 +-
 src/peft/tuners/lora/__init__.py     |  14 ++-
 src/peft/tuners/lora/config.py       |  39 +++++++
 src/peft/tuners/lora/dora.py         |   6 +-
 src/peft/tuners/lora/layer.py        |  15 ++-
 src/peft/tuners/lora/model.py        |   1 +
 tests/test_common_gpu.py             |  81 ++++++++++++++
 tests/test_config.py                 |  27 ++++-
 13 files changed, 324 insertions(+), 9 deletions(-)
 create mode 100644 examples/ephemeral_gpu_offloading/load_with_dora.py

diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md
index f73da6e0a4..824cddcf86 100644
--- a/docs/source/developer_guides/lora.md
+++ b/docs/source/developer_guides/lora.md
@@ -122,6 +122,22 @@ from peft import LoraConfig
 config = LoraConfig(use_dora=True, ...)
 ```
 
+If parts of the model or the DoRA adapter are offloaded to CPU, you can get a significant speedup at the cost of some temporary (ephemeral) VRAM overhead by using `ephemeral_gpu_offload=True` in `config.runtime_config`.
+
+```py
+from peft import LoraConfig, LoraRuntimeConfig
+
+config = LoraConfig(use_dora=True, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=True), ...)
+```
+
+A `PeftModel` with a DoRA adapter can also be loaded with the `ephemeral_gpu_offload=True` flag, both via the `from_pretrained` method and the `load_adapter` method.
+
+```py
+from peft import PeftModel
+
+model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offload=True)
+```
+
 #### Caveats
 
 - DoRA only supports linear and Conv2d layers at the moment.
diff --git a/examples/ephemeral_gpu_offloading/load_with_dora.py b/examples/ephemeral_gpu_offloading/load_with_dora.py
new file mode 100644
index 0000000000..8429ef38b8
--- /dev/null
+++ b/examples/ephemeral_gpu_offloading/load_with_dora.py
@@ -0,0 +1,103 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example script demonstrating the difference in load time for a model with a DoRA adapter when using ephemeral GPU offloading vs doing it purely on the CPU.
+
+Example outputs:
+$ python load_with_dora.py
+--- Loading model ---
+Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.03s/it]
+--- Loading PeftModel ---
+--- Done ---
+Model loading time: 4.83s
+PeftModel loading time: 28.14s
+Use ephemeral GPU offloading: False
+
+(Note: if this was the first time you ran the script, or if your cache was cleared, the times shown above are invalid, due to the time taken to download the model and DoRA files. Just re-run the script in this case.)
+
+$ python load_with_dora.py --ephemeral_gpu_offload
+--- Loading model ---
+Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.11it/s]
+--- Loading PeftModel ---
+--- Done ---
+Model loading time: 4.28s
+PeftModel loading time: 16.59s
+Use ephemeral GPU offloading: True
+
+(Note: if this was the first time you ran the script, or if your cache was cleared, the times shown above are invalid, due to the time taken to download the model and DoRA files. Just re-run the script in this case.)
+"""
+
+import argparse
+import time
+
+from huggingface_hub import snapshot_download
+from transformers import AutoModelForCausalLM
+
+from peft import PeftModel
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Load a model with DoRA using ephemeral GPU offloading")
+    parser.add_argument("--model", type=str, default="NousResearch/Hermes-2-Pro-Mistral-7B", help="Model to load")
+    parser.add_argument(
+        "--dora",
+        type=str,
+        default="peft-internal-testing/DoRA-Hermes-2-Pro-Mistral-7B",
+        help="DoRA to use",
+    )
+    parser.add_argument("--ephemeral_gpu_offload", action="store_true", help="Use ephemeral GPU offloading")
+    parser.add_argument(
+        "--merge_model_path", type=str, help="Merge the model with the DoRA model and save to the given path"
+    )
+    args = parser.parse_args()
+
+    peft_model_kwargs = {
+        "ephemeral_gpu_offload": args.ephemeral_gpu_offload,
+        "max_memory": {"cpu": "256GiB"},
+        "device_map": {"": "cpu"},
+    }
+
+    # Predownload
+    try:
+        snapshot_download(repo_id=args.model)
+    except Exception as e:
+        print(f"Failed to download model: {e}")
+        # We continue anyway as this might be e.g. a local directory or something
+    try:
+        snapshot_download(repo_id=args.dora)
+    except Exception as e:
+        print(f"Failed to download DoRA: {e}")
+        # We continue anyway as this might be e.g. 
a local directory or something + + start = time.perf_counter() + print("--- Loading model ---") + model = AutoModelForCausalLM.from_pretrained(args.model) + model_time = time.perf_counter() - start + print("--- Loading PeftModel ---") + peft_model = PeftModel.from_pretrained(model, args.dora, **peft_model_kwargs) + print("--- Done ---") + peft_model_time = time.perf_counter() - start + + print(f"Model loading time: {model_time:.2f}s") + print(f"PeftModel loading time: {peft_model_time:.2f}s") + print(f"Use ephemeral GPU offloading: {args.ephemeral_gpu_offload}") + + if args.merge_model_path is not None: + merged_model = peft_model.merge_and_unload(progressbar=True) + merged_model.save_pretrained(args.merge_model_path) + + +if __name__ == "__main__": + main() diff --git a/src/peft/__init__.py b/src/peft/__init__.py index e39e13b303..0372c48b28 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -51,6 +51,7 @@ AdaptionPromptConfig, AdaptionPromptModel, LoraConfig, + LoraRuntimeConfig, LoftQConfig, LoraModel, LoHaConfig, diff --git a/src/peft/config.py b/src/peft/config.py index 480d536bdf..9cdcb08e9a 100644 --- a/src/peft/config.py +++ b/src/peft/config.py @@ -14,6 +14,7 @@ import inspect import json import os +import warnings from dataclasses import asdict, dataclass, field from typing import Dict, Optional, Union @@ -63,7 +64,7 @@ def save_pretrained(self, save_directory: str, **kwargs) -> None: os.makedirs(save_directory, exist_ok=True) auto_mapping_dict = kwargs.pop("auto_mapping_dict", None) - output_dict = asdict(self) + output_dict = self.to_dict() # converting set type to list for key, value in output_dict.items(): if isinstance(value, set): @@ -162,6 +163,13 @@ def from_json_file(cls, path_json_file: str, **kwargs): with open(path_json_file) as file: json_object = json.load(file) + # Sanity check that config does not contain a runtime_config + if "runtime_config" in json_object: + warnings.warn( + "The configuration file contains a `runtime_config` key. This is ignored. Runtime configurations are only valid at runtime." + ) + del json_object["runtime_config"] + return json_object @classmethod diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index c92a91be79..370f871524 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -374,6 +374,7 @@ def from_pretrained( is_trainable: bool = False, config: Optional[PeftConfig] = None, autocast_adapter_dtype: bool = True, + ephemeral_gpu_offload: bool = False, **kwargs: Any, ) -> PeftModel: r""" @@ -402,6 +403,12 @@ def from_pretrained( loaded before calling `from_pretrained`. autocast_adapter_dtype (`bool`, *optional*): Whether to autocast the adapter dtype. Defaults to `True`. Only relevant for specific adapter types. + ephemeral_gpu_offload (`bool`, *optional*): + Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`. This is + useful when parts of the model and/or components (such as adapters) are kept in CPU memory until they + are needed. Rather than perform expensive operations on small data, the data is transferred to the GPU + on-demand, the operation(s) performed, and the results moved back to CPU memory. This brings a slight + momentary VRAM overhead but gives orders of magnitude speedup in certain cases. torch_device (`str`, *optional*, defaults to None): The device to load the adapter on. If `None`, the device will be inferred. 
kwargs: (`optional`): @@ -426,6 +433,13 @@ def from_pretrained( else: raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") + # Runtime configuration, if supported + if hasattr(config, "runtime_config"): + config.runtime_config.ephemeral_gpu_offload = ephemeral_gpu_offload + else: + if ephemeral_gpu_offload: + warnings.warn("Ephemeral GPU offloading is not supported for this model. Ignoring.") + if hasattr(model, "hf_device_map"): weight_map = dict(named_module_tensors(model, recurse=True)) @@ -984,6 +998,7 @@ def load_adapter( is_trainable: bool = False, torch_device: Optional[str] = None, autocast_adapter_dtype: bool = True, + ephemeral_gpu_offload: bool = False, **kwargs: Any, ): """ @@ -1008,6 +1023,8 @@ def load_adapter( Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect select PEFT tuners. + ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`): + Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`. kwargs: (`optional`): Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub. """ @@ -1026,6 +1043,7 @@ def load_adapter( ) ].from_pretrained( model_id, + ephemeral_gpu_offload=ephemeral_gpu_offload, **hf_hub_download_kwargs, ) if peft_config.is_prompt_learning and is_trainable: diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index a8b6b75837..c5beb67493 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -18,7 +18,7 @@ # limitations under the License. from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel -from .lora import LoraConfig, LoraModel, LoftQConfig +from .lora import LoraConfig, LoraModel, LoftQConfig, LoraRuntimeConfig from .loha import LoHaConfig, LoHaModel from .lokr import LoKrConfig, LoKrModel from .ia3 import IA3Config, IA3Model diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py index 2a0bce2a5f..7339e85ae8 100644 --- a/src/peft/tuners/lora/__init__.py +++ b/src/peft/tuners/lora/__init__.py @@ -14,13 +14,23 @@ from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available -from .config import LoftQConfig, LoraConfig +from .config import LoftQConfig, LoraConfig, LoraRuntimeConfig from .gptq import QuantLinear from .layer import Conv2d, Embedding, Linear, LoraLayer from .model import LoraModel -__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"] +__all__ = [ + "LoraConfig", + "LoraRuntimeConfig", + "LoftQConfig", + "Conv2d", + "Embedding", + "LoraLayer", + "Linear", + "LoraModel", + "QuantLinear", +] def __getattr__(name): diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index a94165100c..3e250b928c 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -23,6 +23,32 @@ from peft.utils import PeftType +@dataclass +class LoraRuntimeConfig: + """ + This is the sub-configuration class to store the runtime configurations for the model. + + Args: + ephemeral_gpu_offload (`bool`): + Whether to use ephemeral GPU offloading for models partially kept in CPU memory. + """ + + ephemeral_gpu_offload: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use ephemeral GPU offloading for models partially kept in CPU memory. 
Ephemeral GPU offloading result in " + "the data involved in intense operations being momentarily copied over to the GPU, and the results copied " + "back to CPU. There is a momentary VRAM overhead, but operations are generally orders of magnitude faster " + "compared to performing them on the CPU. This is useful when parts of the model and/or components (such " + "as adapters) are kept in CPU memory until they are needed. Rather than perform expensive operations on " + "small data, the data is transferred to the GPU on-demand, the operation(s) performed, and the results " + "moved back to CPU memory. Currently only affects DoRA initialization." + ) + }, + ) + + @dataclass class LoftQConfig: """ @@ -122,6 +148,8 @@ class LoraConfig(PeftConfig): Build a new stack of layers by stacking the original model layers according to the ranges specified. This allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will all have separate LoRA adapters attached to them. + runtime_config (`LoraRuntimeConfig`): + Runtime configurations (which are not saved or restored). """ r: int = field(default=8, metadata={"help": "Lora attention dimension"}) @@ -282,6 +310,17 @@ class LoraConfig(PeftConfig): ) }, ) + runtime_config: LoraRuntimeConfig = field( + default_factory=LoraRuntimeConfig, metadata={"help": "Runtime configurations"} + ) + + def to_dict(self): + """ + Returns the configuration for your adapter model as a dictionary. Removes runtime configurations. + """ + rv = super().to_dict() + rv.pop("runtime_config") + return rv def __post_init__(self): self.peft_type = PeftType.LORA diff --git a/src/peft/tuners/lora/dora.py b/src/peft/tuners/lora/dora.py index 6cab9bc96a..859c294f10 100644 --- a/src/peft/tuners/lora/dora.py +++ b/src/peft/tuners/lora/dora.py @@ -34,7 +34,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype) return weight_norm - def update_layer(self, *, base_layer, lora_A, lora_B, scaling) -> None: + def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None: # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2 dtype_is_fp16 = lora_A.dtype == torch.float16 if dtype_is_fp16: @@ -56,8 +56,10 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling) -> None: if dtype_is_fp16: lora_weight = lora_weight.half() - weight_norm = self.get_weight_norm(weight, lora_weight, scaling) + weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling) + if place_on_cpu: + weight_norm = weight_norm.to("cpu") self.weight = nn.Parameter(weight_norm, requires_grad=True) def forward(self, x, *, lora_A, lora_B, scaling, base_layer): diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 704ce0c43a..6cc8d9a15b 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -37,7 +37,7 @@ class LoraLayer(BaseTunerLayer): # All names of other parameters that may contain adapter-related parameters other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout") - def __init__(self, base_layer: nn.Module, **kwargs) -> None: + def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None: self.base_layer = base_layer self.r = {} self.lora_alpha = {} @@ -54,6 +54,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: self.use_dora: dict[str, bool] = {} self.lora_magnitude_vector = torch.nn.ModuleDict() # 
for DoRA self._caches: dict[str, Any] = {} + self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload self.kwargs = kwargs base_layer = self.get_base_layer() @@ -251,8 +252,18 @@ def dora_init(self, adapter_name: str) -> None: dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(self, "fan_in_fan_out", False)) lora_A = self.lora_A[adapter_name].weight lora_B = self.lora_B[adapter_name].weight + place_on_cpu = self.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu") + if self.ephemeral_gpu_offload: + if lora_A.device.type == "cuda": + lora_B = lora_B.to(lora_A.device) + else: + if lora_B.device.type != "cuda": + lora_B = lora_B.to("cuda") + lora_A = lora_A.to(lora_B.device) scaling = self.scaling[adapter_name] - dora_layer.update_layer(base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling) + dora_layer.update_layer( + base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling, place_on_cpu=place_on_cpu + ) self.lora_magnitude_vector[adapter_name] = dora_layer def _cache_store(self, key: str, value: Any) -> None: diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index f3be8d95d1..692dda7b9a 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -196,6 +196,7 @@ def _create_and_replace( "init_lora_weights": lora_config.init_lora_weights, "use_rslora": lora_config.use_rslora, "use_dora": lora_config.use_dora, + "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload, "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), } diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index b490bd7de9..6039d7d850 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -49,6 +49,7 @@ prepare_model_for_kbit_training, ) from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.lora.config import LoraRuntimeConfig from .testing_utils import require_bitsandbytes, require_torch_gpu, require_torch_multi_gpu @@ -1083,6 +1084,86 @@ def test_8bit_dora_merging(self): assert torch.allclose(out_dora, out_unmerged, atol=atol, rtol=rtol) assert torch.allclose(out_dora, out_unloaded, atol=atol, rtol=rtol) + @require_torch_gpu + @pytest.mark.single_gpu_tests + def test_dora_ephemeral_gpu_offload(self): + torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torch_dtype=torch.float32, + ).eval() + + config = LoraConfig( + r=128, + init_lora_weights=False, + use_dora=True, + runtime_config=LoraRuntimeConfig( + ephemeral_gpu_offload=True + ), # we enable this, but only to verify that it's gone later + ) + peft_model = get_peft_model(model, config).eval() + # Check that ephemeral GPU offloading is present + assert peft_model.peft_config["default"].runtime_config.ephemeral_gpu_offload + + # Save to disk + with tempfile.TemporaryDirectory() as tmp_dir: + peft_model.save_pretrained(tmp_dir) + + # Load from disk 100% on CPU without ephemeral GPU offloading + peft_model_cpu = PeftModel.from_pretrained( + model, + tmp_dir, + device_map={"": "cpu"}, + ).eval() + + # Check that ephemeral GPU offloading is absent + assert not peft_model_cpu.peft_config["default"].runtime_config.ephemeral_gpu_offload + + # Load again, with ephemeral GPU offloading enabled + peft_model_ego = PeftModel.from_pretrained( + model, + tmp_dir, + device_map={"": "cpu"}, + ephemeral_gpu_offload=True, + ).eval() + + random_input = 
torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device) + with torch.inference_mode(): + out_peft_model_cpu = F.softmax(peft_model_cpu(random_input).logits, dim=-1) + out_peft_model_ego = F.softmax(peft_model_ego(random_input).logits, dim=-1) + + # The results should be the same + assert torch.allclose(out_peft_model_cpu, out_peft_model_ego) + + @require_torch_gpu + @require_torch_multi_gpu + @pytest.mark.multi_gpu_tests + def test_dora_ephemeral_gpu_offload_multigpu(self): + torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torch_dtype=torch.float32, + ).eval() + + config = LoraConfig( + r=16, # too small and the time difference is too small + init_lora_weights=False, + use_dora=True, + runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=True), + ) + peft_model = get_peft_model(model, config).eval() + + layer = peft_model.base_model.model.model.decoder.layers[0].self_attn.v_proj + lora_A, lora_B = layer.lora_A, layer.lora_B + + possible_combinations = ["cpu", "cuda", "cuda:0", "cuda:1"] + for device_A in possible_combinations: + la = lora_A.to(device_A) + for device_B in possible_combinations: + lb = lora_B.to(device_B) + layer.lora_A, layer.lora_B = la, lb + layer.dora_init(layer.active_adapter[0]) # should not raise an error + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a CUDA GPU") @pytest.mark.single_gpu_tests diff --git a/tests/test_config.py b/tests/test_config.py index a3dfa9d182..93ff3b2b40 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import json import os import pickle import tempfile @@ -122,7 +123,16 @@ def test_from_json_file(self, config_class): with tempfile.TemporaryDirectory() as tmp_dirname: config.save_pretrained(tmp_dirname) - config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json")) + config_path = os.path.join(tmp_dirname, "adapter_config.json") + config_from_json = config_class.from_json_file(config_path) + assert config.to_dict() == config_from_json + + # Also test with a runtime_config entry -- they should be ignored, even if they + # were accidentally saved to disk + config_from_json["runtime_config"] = {"ephemeral_gpu_offload": True} + json.dump(config_from_json, open(config_path, "w")) + + config_from_json = config_class.from_json_file(config_path) assert config.to_dict() == config_from_json @parameterized.expand(ALL_CONFIG_CLASSES) @@ -152,6 +162,21 @@ def test_from_pretrained_cache_dir_remote(self): PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname) assert "models--ybelkada--test-st-lora" in os.listdir(tmp_dirname) + @parameterized.expand(ALL_CONFIG_CLASSES) + def test_save_pretrained_with_runtime_config(self, config_class): + r""" + Test if the config correctly removes runtime config when saving + """ + with tempfile.TemporaryDirectory() as tmp_dirname: + for model_name, revision in PEFT_MODELS_TO_TEST: + cfg = config_class.from_pretrained(model_name, revision=revision) + # NOTE: cfg is always a LoraConfig here, because the configuration of the loaded model was a LoRA. + # Hence we can expect a runtime_config to exist regardless of config_class. 
+ cfg.runtime_config.ephemeral_gpu_offload = True + cfg.save_pretrained(tmp_dirname) + cfg = config_class.from_pretrained(tmp_dirname) + assert not cfg.runtime_config.ephemeral_gpu_offload + @parameterized.expand(ALL_CONFIG_CLASSES) def test_set_attributes(self, config_class): # manually set attributes and check if they are correctly written
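To make the data movement and the overhead figures in this patch concrete, the following standalone sketch (not part of the diff) mirrors the ephemeral offload pattern that `dora.py` and `layer.py` apply during DoRA initialization, and reproduces the back-of-the-envelope VRAM estimate from the PR description. The helper names (`estimate_dora_overhead_bytes`, `ephemeral_weight_norm`) and tensor shapes are illustrative only, and the weight-norm computation is simplified (no `fan_in_fan_out` transpose or fp16 round-trip).

```python
import torch


def estimate_dora_overhead_bytes(r: int, hidden: int, dtype_bytes: int = 2) -> int:
    """Rough per-layer VRAM overhead of ephemeral DoRA init, following the PR's estimate.

    Counts lora_A (r x hidden) and lora_B (hidden x r) plus the base weight
    (hidden x hidden), multiplied by the element size (2 bytes for fp16).
    """
    return dtype_bytes * (r * hidden * 2 + hidden * hidden)


def ephemeral_weight_norm(weight_cpu, lora_A, lora_B, scaling, place_on_cpu=True):
    """Copy the base weight next to the LoRA weights, compute, then move the result back."""
    weight = weight_cpu.to(lora_A.device)  # ephemeral copy onto the GPU
    lora_weight = lora_B @ lora_A  # (out, r) @ (r, in) -> (out, in)
    weight_norm = torch.linalg.norm(weight + scaling * lora_weight, dim=1)
    if place_on_cpu:  # mirrors place_on_cpu in DoraLinearLayer.update_layer
        weight_norm = weight_norm.to("cpu")
    return weight_norm


if __name__ == "__main__":
    # The two figures quoted in the PR description:
    print(f"{estimate_dora_overhead_bytes(64, 4096) / 2**20:.0f} MiB")  # r=64, hidden=4096 -> ~33 MiB ("33 MB or so")
    print(f"{estimate_dora_overhead_bytes(32, 8192) / 2**20:.0f} MiB")  # r=32, hidden=8192 -> ~129 MiB (the "130 MB" figure)

    if torch.cuda.is_available():
        out_features, in_features, r = 4096, 4096, 64
        weight = torch.randn(out_features, in_features)  # stays in CPU memory until needed
        lora_A = torch.randn(r, in_features, device="cuda")
        lora_B = torch.randn(out_features, r, device="cuda")
        norm = ephemeral_weight_norm(weight, lora_A, lora_B, scaling=2.0)
        print(norm.shape, norm.device)  # torch.Size([4096]) cpu
```

In the patch itself, `place_on_cpu` is set when ephemeral GPU offloading is enabled and at least one of the LoRA matrices lives on CPU, so the small weight-norm result is parked back in CPU memory after the GPU round trip.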