FIX Make special LoRA inits DeepSpeed compatible (#1887)
Resolves huggingface/accelerate#2886

Possibly resolves #896 (comment).

Some LoRA init methods need to access the base layer weight. Accessing it can
fail or stall in distributed settings because the weight may be sharded across
ranks. For DeepSpeed, the weight is now gathered before it is accessed (a
sketch of the gather pattern is shown below).

Note: Without DeepSpeed, this is a no-op and should therefore have no
downside. We don't have DeepSpeed in our CI, so this change is not covered by
tests.

I also made some small changes to OLoRA init to use
self.get_base_layer() instead of self.base_layer.
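
For context, the pattern that gather_params_ctx provides looks roughly like the
sketch below. This is a hedged illustration, not the actual PEFT implementation:
the helper name gather_weight_for_init is made up, and the only assumed
DeepSpeed API is deepspeed.zero.GatheredParameters, which temporarily
all-gathers a ZeRO-3 partitioned parameter.

# Hypothetical sketch of the gather pattern (not the actual gather_params_ctx code).
from contextlib import contextmanager

@contextmanager
def gather_weight_for_init(param):
    """Yield with `param` fully materialized so an init method can read it."""
    try:
        import deepspeed
    except ImportError:
        # No DeepSpeed installed: the weight is already whole, nothing to do.
        yield
        return
    # modifier_rank=None: the weight is only read inside the context, not modified.
    with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
        yield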
BenjaminBossan authored Jun 26, 2024
1 parent c9b19bb commit 184beaf
Showing 1 changed file with 12 additions and 7 deletions.
src/peft/tuners/lora/layer.py: 12 additions & 7 deletions
@@ -24,7 +24,7 @@
 from transformers.pytorch_utils import Conv1D
 
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
-from peft.utils.integrations import dequantize_module_weight
+from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
 from peft.utils.other import transpose
 
 from .config import LoraConfig
@@ -121,12 +121,16 @@ def update_layer(
         else:
             self.scaling[adapter_name] = lora_alpha / r
 
+        # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
         if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
-            self.pissa_init(adapter_name, init_lora_weights)
+            with gather_params_ctx(self.get_base_layer().weight):
+                self.pissa_init(adapter_name, init_lora_weights)
         elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
-            self.olora_init(adapter_name)
+            with gather_params_ctx(self.get_base_layer().weight):
+                self.olora_init(adapter_name)
         elif init_lora_weights == "loftq":
-            self.loftq_init(adapter_name)
+            with gather_params_ctx(self.get_base_layer().weight):
+                self.loftq_init(adapter_name)
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
         # call this before dora_init
@@ -161,13 +165,14 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):
             nn.init.normal_(self.lora_embedding_B[adapter_name])
 
     def olora_init(self, adapter_name):
-        dtype = self.base_layer.weight.dtype
+        dtype = self.get_base_layer().weight.dtype
         if dtype in [torch.int8, torch.uint8]:
-            weight_tensor = dequantize_module_weight(self.base_layer)
+            weight_tensor = dequantize_module_weight(self.get_base_layer())
         elif dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            weight_tensor = self.base_layer.weight
+            weight_tensor = self.get_base_layer().weight
         else:
             raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")
+
         scale_factor = self.scaling[adapter_name]
         r = self.r[adapter_name]
         weight_tensor = weight_tensor.to(torch.float32)
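
As a usage note, the gathered code path only runs when one of the special inits
is requested. The snippet below is illustrative (model name and hyperparameters
are not from this commit); launched under DeepSpeed ZeRO-3, the OLoRA init now
reads the base weight inside gather_params_ctx.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Any causal LM works here; "facebook/opt-125m" is just a small example model.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(r=8, lora_alpha=16, init_lora_weights="olora")
# update_layer() now wraps olora_init() in gather_params_ctx, so this also works
# when the base weights are partitioned by DeepSpeed ZeRO-3.
model = get_peft_model(base_model, config)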
