Autocast adapter weights if fp16/bf16 (#1706)
As discussed internally, we want to automatically cast the adapter weights to
float32 when the base model is loaded in float16. Float16 is not conducive to
stable training and raises errors when used with AMP.

Previously, we had to recommend that users manually cast the adapter weights
if they loaded the base model in float16, because PEFT would choose the same
dtype for the adapter as for the base weights. Forgetting this cast is a
common source of errors, so we chose to automate it.

If this causes trouble, users can prevent the behavior by passing
autocast_adapter_dtype=False to get_peft_model,
PeftModel.from_pretrained, or PeftModel.load_adapter.
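
For illustration, opting out would look roughly like this (the model and LoRA
settings below are placeholders, not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Base model loaded in float16; by default, PEFT now upcasts the adapter
# weights to float32 for stable training.
base_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
lora_config = LoraConfig(task_type="CAUSAL_LM")

# Passing autocast_adapter_dtype=False restores the old behavior and keeps
# the adapter weights in float16.
peft_model = get_peft_model(base_model, lora_config, autocast_adapter_dtype=False)
```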

This PR should be reviewed carefully, as it has the potential to break
existing code if something important was missed. We also need to add a note
about this change in behavior to the upcoming release notes.
BenjaminBossan authored May 16, 2024
1 parent 2535036 commit ae1ae20
Showing 5 changed files with 344 additions and 35 deletions.
6 changes: 6 additions & 0 deletions docs/source/developer_guides/troubleshooting.md
@@ -69,6 +69,12 @@ trainer = Trainer(model=peft_model, fp16=True, ...)
trainer.train()
```

<Tip>

Starting from PEFT v0.11.0, PEFT automatically promotes the dtype of adapter weights from `torch.float16` and `torch.bfloat16` to `torch.float32` where appropriate. To _prevent_ this behavior, you can pass `autocast_adapter_dtype=False` to [`~get_peft_model`], to [`~PeftModel.from_pretrained`], and to [`~PeftModel.load_adapter`].

</Tip>
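
For example, to load a trained adapter while keeping its weights in the base model's half-precision dtype, pass the flag at load time (the model and adapter ids below are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16)
# autocast_adapter_dtype=False keeps the adapter weights in float16/bfloat16 instead of upcasting them to float32
peft_model = PeftModel.from_pretrained(base_model, "path/to/my-adapter", autocast_adapter_dtype=False)
```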

## Bad results from a loaded PEFT model

There can be several reasons for getting a poor result from a loaded PEFT model; they are listed below. If you're still unable to troubleshoot the problem, see if anyone else had a similar [issue](https://github.com/huggingface/peft/issues) on GitHub, and if you can't find any, open a new issue.
12 changes: 10 additions & 2 deletions src/peft/mapping.py
@@ -122,6 +122,7 @@ def get_peft_model(
peft_config: PeftConfig,
adapter_name: str = "default",
mixed: bool = False,
autocast_adapter_dtype: bool = True,
revision: Optional[str] = None,
) -> PeftModel | PeftMixedModel:
"""
@@ -136,6 +137,10 @@ def get_peft_model(
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
mixed (`bool`, `optional`, defaults to `False`):
Whether to allow mixing different (compatible) adapter types.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 or bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
revision (`str`, `optional`, defaults to `main`):
The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
the base model
@@ -154,14 +159,17 @@ def get_peft_model(
peft_config.revision = revision

if mixed:
# note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
return PeftModel(model, peft_config, adapter_name=adapter_name)
return PeftModel(model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)

if peft_config.is_prompt_learning:
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
)


def inject_adapter_in_model(
107 changes: 89 additions & 18 deletions src/peft/peft_model.py
@@ -102,6 +102,10 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft.
peft_config ([`PeftConfig`]): The configuration of the Peft model.
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
**Attributes**:
- **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft.
@@ -118,7 +122,13 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
in the base model if using [`PromptLearningConfig`].
"""

def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None:
def __init__(
self,
model: PreTrainedModel,
peft_config: PeftConfig,
adapter_name: str = "default",
autocast_adapter_dtype: bool = True,
) -> None:
super().__init__()
self.modules_to_save = None
self.active_adapter = adapter_name
@@ -138,6 +148,11 @@ def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name
self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
self.set_additional_trainable_modules(peft_config, adapter_name)

if hasattr(self.base_model, "_cast_adapter_dtype"):
self.base_model._cast_adapter_dtype(
adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
)

if getattr(model, "is_gradient_checkpointing", True):
model = self._prepare_model_for_gradient_checkpointing(model)

@@ -335,6 +350,7 @@ def from_pretrained(
adapter_name: str = "default",
is_trainable: bool = False,
config: Optional[PeftConfig] = None,
autocast_adapter_dtype: bool = True,
**kwargs: Any,
) -> PeftModel:
r"""
@@ -361,6 +377,8 @@ def from_pretrained(
The configuration object to use instead of an automatically loaded configuration. This configuration
object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already
loaded before calling `from_pretrained`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Only relevant for specific adapter types.
kwargs: (`optional`):
Additional keyword arguments passed along to the specific PEFT configuration class.
"""
@@ -424,10 +442,15 @@ def from_pretrained(
config.inference_mode = not is_trainable

if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
model = cls(model, config, adapter_name)
model = cls(model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)
else:
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name)
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](
model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
)
model.load_adapter(
model_id, adapter_name, is_trainable=is_trainable, autocast_adapter_dtype=autocast_adapter_dtype, **kwargs
)

return model

def _setup_prompt_encoder(self, adapter_name: str):
@@ -935,6 +958,7 @@ def load_adapter(
adapter_name: str,
is_trainable: bool = False,
torch_device: Optional[str] = None,
autocast_adapter_dtype: bool = True,
**kwargs: Any,
):
"""
@@ -955,6 +979,10 @@ def load_adapter(
used for inference.
torch_device (`str`, *optional*, defaults to None):
The device to load the adapter on. If `None`, the device will be inferred.
autocast_adapter_dtype (`bool`, *optional*, defaults to `True`):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter
weights using float16 and bfloat16 to float32, as this is typically required for stable training, and
only affects select PEFT tuners.
kwargs: (`optional`):
Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub.
"""
@@ -1034,6 +1062,11 @@ def load_adapter(
remove_hook_from_submodules(self.prompt_encoder)
add_hook_to_module(self.get_base_model(), hook)

if hasattr(self.base_model, "_cast_adapter_dtype"):
self.base_model._cast_adapter_dtype(
adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
)

# Set model in evaluation mode to deactivate Dropout modules by default
if not is_trainable:
self.eval()
@@ -1133,6 +1166,11 @@ class PeftModelForSequenceClassification(PeftModel):
Args:
model ([`~transformers.PreTrainedModel`]): Base transformer model.
peft_config ([`PeftConfig`]): Peft config.
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
**Attributes**:
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
@@ -1166,8 +1204,10 @@ class PeftModelForSequenceClassification(PeftModel):
```
"""

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)
def __init__(
self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs
) -> None:
super().__init__(model, peft_config, adapter_name, **kwargs)

classifier_module_names = ["classifier", "score"]
if self.modules_to_save is None:
@@ -1361,7 +1401,11 @@ class PeftModelForCausalLM(PeftModel):
Args:
model ([`~transformers.PreTrainedModel`]): Base transformer model.
peft_config ([`PeftConfig`]): Peft config.
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
Example:
Expand Down Expand Up @@ -1391,8 +1435,10 @@ class PeftModelForCausalLM(PeftModel):
```
"""

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)
def __init__(
self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs
) -> None:
super().__init__(model, peft_config, adapter_name, **kwargs)
self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

def forward(
@@ -1566,7 +1612,11 @@ class PeftModelForSeq2SeqLM(PeftModel):
Args:
model ([`~transformers.PreTrainedModel`]): Base transformer model.
peft_config ([`PeftConfig`]): Peft config.
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
Example:
Expand Down Expand Up @@ -1595,8 +1645,10 @@ class PeftModelForSeq2SeqLM(PeftModel):
```
"""

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)
def __init__(
self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs
) -> None:
super().__init__(model, peft_config, adapter_name, **kwargs)
self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
self.base_model._prepare_encoder_decoder_kwargs_for_generation
@@ -1820,6 +1872,11 @@ class PeftModelForTokenClassification(PeftModel):
Args:
model ([`~transformers.PreTrainedModel`]): Base transformer model.
peft_config ([`PeftConfig`]): Peft config.
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
**Attributes**:
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
@@ -1853,8 +1910,10 @@ class PeftModelForTokenClassification(PeftModel):
```
"""

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)
def __init__(
self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default", **kwargs
) -> None:
super().__init__(model, peft_config, adapter_name, **kwargs)

classifier_module_names = ["classifier", "score"]
if self.modules_to_save is None:
@@ -2032,6 +2091,11 @@ class PeftModelForQuestionAnswering(PeftModel):
Args:
model ([`~transformers.PreTrainedModel`]): Base transformer model.
peft_config ([`PeftConfig`]): Peft config.
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
**Attributes**:
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
@@ -2063,8 +2127,10 @@ class PeftModelForQuestionAnswering(PeftModel):
```
"""

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
super().__init__(model, peft_config, adapter_name)
def __init__(
self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs
) -> None:
super().__init__(model, peft_config, adapter_name, **kwargs)

qa_module_names = ["qa_outputs"]
if self.modules_to_save is None:
@@ -2265,6 +2331,11 @@ class PeftModelForFeatureExtraction(PeftModel):
Args:
model ([`~transformers.PreTrainedModel`]): Base transformer model.
peft_config ([`PeftConfig`]): Peft config.
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
**Attributes**:
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
@@ -2293,8 +2364,8 @@ class PeftModelForFeatureExtraction(PeftModel):
```
"""

def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default"):
super().__init__(model, peft_config, adapter_name)
def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs):
super().__init__(model, peft_config, adapter_name, **kwargs)

def forward(
self,
42 changes: 41 additions & 1 deletion src/peft/tuners/tuners_utils.py
@@ -304,14 +304,52 @@ def _check_new_adapter_config(self, config: PeftConfig) -> None:
"""
pass

def _cast_adapter_dtype(self, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
"""
A helper method to cast the adapter weights to the correct dtype.
Currently, this only upcasts float16 and bfloat16 to float32.
Args:
adapter_name (`str`):
The adapter name.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`.
"""
if not autocast_adapter_dtype:
return

dtypes_to_convert_to_fp32 = {torch.float16, torch.bfloat16}

for module in self.model.modules():
if not isinstance(module, BaseTunerLayer):
continue

for submodule in module.modules():
if not isinstance(submodule, (nn.ModuleDict, nn.ParameterDict)):
continue

if adapter_name not in submodule:
continue

if isinstance(submodule[adapter_name], nn.Parameter):
if submodule[adapter_name].dtype in dtypes_to_convert_to_fp32:
submodule[adapter_name].data = submodule[adapter_name].data.to(torch.float32)
continue

for param in submodule[adapter_name].parameters():
if param.dtype in dtypes_to_convert_to_fp32:
param.data = param.data.to(torch.float32)

def _check_merge_allowed(self):
"""Helper method to check whether the adapter can be merged.
Raise a ValueError if it is not possible to merge the adapter with the given configuration.
"""
pass

def inject_adapter(self, model: nn.Module, adapter_name: str):
def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
r"""
Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the
hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed.
@@ -323,6 +361,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
The model to be tuned.
adapter_name (`str`):
The adapter name.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`.
"""
peft_config = self.peft_config[adapter_name]
# Note: If possible, all checks should be performed *at the start of this method*.
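
As a rough sanity check of the new `_cast_adapter_dtype` behavior, a minimal sketch (the model id and config below are arbitrary choices for illustration) should show the adapter parameters ending up in float32 even though the base model was loaded in float16:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
# autocast_adapter_dtype defaults to True, so the LoRA weights are upcast to float32
peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

adapter_dtypes = {p.dtype for name, p in peft_model.named_parameters() if "lora_" in name}
print(adapter_dtypes)  # expected: {torch.float32}
```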