Initial composition support for LoRALinear
calpt committed Nov 7, 2023
1 parent 9dfbb7d commit d344cb0
Showing 2 changed files with 105 additions and 27 deletions.
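In short: a LoRALinear layer previously accepted only a single active LoRA/IA3 adapter, and this commit routes Stack, BatchSplit, and Average composition blocks through the shared composition machinery in ComposableAdapterLayerBase. A rough usage sketch of what this enables (adapter names and the checkpoint are placeholders; the high-level calls reflect the adapters library's general API rather than anything added in this diff):

    # Sketch: activating a composition of LoRA adapters after this commit.
    from transformers import BertModel

    import adapters
    from adapters import LoRAConfig
    from adapters.composition import Stack

    model = BertModel.from_pretrained("bert-base-uncased")
    adapters.init(model)  # attach adapter support to the plain Transformers model

    model.add_adapter("task_a", config=LoRAConfig(init_weights="bert"))
    model.add_adapter("task_b", config=LoRAConfig(init_weights="bert"))

    # Previously, LoRA layers raised an error for anything but a single adapter name;
    # now Stack / BatchSplit / Average blocks are dispatched through compose().
    model.set_active_adapters(Stack("task_a", "task_b"))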
116 changes: 90 additions & 26 deletions src/adapters/methods/lora.py
@@ -4,7 +4,7 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 import math
-from typing import Dict, List, Union
+from typing import Dict, List, NamedTuple, Optional, Union

 import torch
 import torch.nn as nn
@@ -13,9 +13,9 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.pytorch_utils import Conv1D

-from ..composition import AdapterCompositionBlock
+from ..composition import AdapterCompositionBlock, Average, BatchSplit, Stack
 from ..configuration import LoRAConfig, ModelAdaptersConfig
-from .adapter_layer_base import AdapterLayerBase
+from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase


 class LoRA(nn.Module):
@@ -75,14 +75,16 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
         """Inverts the composition operation between existing and injected weights."""
         return weights - added * self.scaling

-    def forward(self, hidden_states: torch.Tensor, input_states: torch.Tensor):
-        delta_w = self.lora_dropout(input_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
+    def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
+        if hidden_states is None:
+            hidden_states = layer_input
+        hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
         if self.use_gating:
-            gate = torch.sigmoid(self.gate(input_states))
+            gate = torch.sigmoid(self.gate(layer_input))
             gate = torch.mean(gate, dim=1).unsqueeze(-1)
             hidden_states = hidden_states * gate
         else:
             gate = None
-        hidden_states = self.com(hidden_states, delta_w, scaling=gate)

         return hidden_states, gate
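Note on the change above: LoRA.forward no longer merges its delta into the layer output itself (the removed self.com call). Instead it receives the running hidden states of the composition, None for the first module, which then seeds them from the layer input, and returns the transformed low-rank delta plus an optional gate; the merge into the frozen output now happens once, later, in LoRALinear.forward. A standalone sketch of that chaining with plain tensors, assuming lora_A has shape (rank, in_features) and lora_B has shape (out_features, rank), as the torch.t calls above imply:

    import torch

    def lora_delta(hidden_states, layer_input, lora_A, lora_B):
        # mirrors the new LoRA.forward, minus dropout and gating
        if hidden_states is None:
            hidden_states = layer_input
        return hidden_states @ lora_A.t() @ lora_B.t()

    x = torch.randn(2, 5, 16)                        # batch x seq_len x hidden dim
    A1, B1 = torch.randn(8, 16), torch.randn(16, 8)  # "adapter one", rank 8
    A2, B2 = torch.randn(8, 16), torch.randn(16, 8)  # "adapter two", rank 8

    h = lora_delta(None, x, A1, B1)  # first module in a Stack seeds from the layer input
    h = lora_delta(h, x, A2, B2)     # second module transforms the first module's output
    assert h.shape == x.shape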
@@ -139,14 +141,18 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
         """Inverts the composition operation between existing and injected weights."""
         return weights / (added * self.scaling)

-    def forward(self, hidden_states: torch.Tensor, input_states: torch.Tensor):
-        delta_w = self.lora_B.view(1, 1, -1)
+    def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
+        scaling_vector = self.lora_B.view(1, 1, -1).repeat(layer_input.shape[0], 1, 1)
+        if hidden_states is None:
+            hidden_states = scaling_vector
+        else:
+            hidden_states = hidden_states * scaling_vector
         if self.use_gating:
-            gate = torch.sigmoid(self.gate(input_states))
+            gate = torch.sigmoid(self.gate(layer_input))
             gate = torch.mean(gate, dim=1).unsqueeze(-1)
             hidden_states = hidden_states * gate
         else:
             gate = None
-        hidden_states = self.com(hidden_states, delta_w, scaling=gate)

         return hidden_states, gate
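The IA3 module gets the same treatment: its scaling vector (lora_B) is expanded to the batch size and multiplied into the running hidden states, so chained modules accumulate their scalings by elementwise multiplication, and the removed self.com call is again deferred to LoRALinear.forward. A minimal sketch with illustrative shapes:

    import torch

    def ia3_scale(hidden_states, layer_input, lora_B):
        # mirrors the new forward above, minus gating
        scaling_vector = lora_B.view(1, 1, -1).repeat(layer_input.shape[0], 1, 1)
        return scaling_vector if hidden_states is None else hidden_states * scaling_vector

    x = torch.randn(2, 5, 16)
    h = ia3_scale(None, x, torch.rand(16))  # first module: a (2, 1, 16) scaling
    h = ia3_scale(h, x, torch.rand(16))     # second module: elementwise product of both scalings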
Expand Down Expand Up @@ -252,9 +258,25 @@ def get_adapter(self, adapter_name: str) -> nn.Module:
return None


class LoRALinear(LoRALayer, nn.Linear):
class LoRAState(NamedTuple):
"""Models the input and output states of a LoRA layer.
Args:
layer_input (torch.Tensor): The input states to the adapted layer.
hidden_states (Optional[torch.Tensor]):
The hidden states of the adaptation module. These can be None before passing through the first LoRA/ IA3
module.
layer_output (torch.Tensor): The output states of the original layer without adaptation.
"""

layer_input: torch.Tensor
hidden_states: Optional[torch.Tensor]
layer_output: torch.Tensor


class LoRALinear(LoRALayer, ComposableAdapterLayerBase, nn.Linear):
"""
LoRA implementation for Linear layer.
LoRA implementation for Linear layer. This layer supports composition.
Args:
fan_in_fan_out (bool, optional):
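LoRAState is a plain NamedTuple, so the composition overrides added further down (vslice, pad_and_concat, repeat, mean) can slice, concatenate, and rebuild it without mutation, and compose_single swaps in new hidden states via _replace. A throwaway illustration with arbitrary shapes:

    import torch

    # LoRAState as defined above; nothing has been composed yet, so hidden_states is None.
    state = LoRAState(
        layer_input=torch.randn(2, 5, 16),   # input to the adapted nn.Linear
        hidden_states=None,
        layer_output=torch.randn(2, 5, 16),  # frozen F.linear(x, W, b) result
    )
    state = state._replace(hidden_states=torch.zeros(2, 5, 16))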
@@ -263,6 +285,9 @@ class LoRALinear(LoRALayer, nn.Linear):
     """

+    # TODO: enable parallel composition for LoRA
+    supported_compositions = [Stack, BatchSplit, Average]
+
     def __init__(
         self,
         in_features: int,
@@ -356,29 +381,68 @@ def merge_adapter(self, name: str):
         elif self.merged != name:
             raise ValueError("LoRALayer already has a merged LoRA module. Please reset it first.")

-    def forward(self, x: torch.Tensor):
-        input_states = x
-        # result shape: <batch_size> x <seq_len> x <head_dim>
+    def vslice(self, state: LoRAState, slice_obj: slice) -> LoRAState:
+        return LoRAState(
+            state.layer_input[slice_obj],
+            state.hidden_states[slice_obj] if state.hidden_states is not None else None,
+            state.layer_output[slice_obj],
+        )
+
+    def pad_and_concat(self, states: List[LoRAState]) -> LoRAState:
+        return LoRAState(
+            torch.cat([s.layer_input for s in states], dim=0),
+            torch.cat([s.hidden_states for s in states], dim=0) if states[0].hidden_states is not None else None,
+            torch.cat([s.layer_output for s in states], dim=0),
+        )
+
+    def repeat(self, state: LoRAState, channels: int) -> LoRAState:
+        return LoRAState(
+            state.layer_input.repeat(channels, 1, 1),
+            state.hidden_states.repeat(channels, 1, 1) if state.hidden_states is not None else None,
+            state.layer_output.repeat(channels, 1, 1),
+        )
+
+    def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState:
+        return LoRAState(
+            states[0].layer_input,
+            torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0)
+            if states[0].hidden_states is not None
+            else None,
+            states[0].layer_output,
+        )
+
+    def compose_single(self, adapter_setup: str, state: LoRAState, lvl: int = 0) -> LoRAState:
+        lora = self.loras[adapter_setup]
+        hidden_states, gate = lora(state.hidden_states, state.layer_input)
+        if gate is not None:
+            self._store_gating_score(adapter_setup, gate)
+
+        return state._replace(hidden_states=hidden_states)
+
+    def forward(self, input_states: torch.Tensor):
         weight = torch.transpose(self.weight, -2, -1) if self.fan_in_fan_out else self.weight
-        x = F.linear(x, weight, bias=self.bias)
+        # result shape: <batch_size> x <seq_len> x <head_dim>
+        layer_output = F.linear(input_states, weight, bias=self.bias)

         if not self.merged:
             adapter_setup = self.get_active_setup()
             if adapter_setup is not None:
-                if len(adapter_setup) == 1:
-                    lora = self.loras[adapter_setup[0]]
-                    x, gate = lora(x, input_states)
-                    if gate is not None:
-                        self._store_gating_score(adapter_setup[0], gate)
-                else:
-                    raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")
+                state = LoRAState(input_states, None, layer_output)
+                state = self.compose(adapter_setup, state)
+                _, hidden_states, layer_output = state
+
+                last_lora = self.loras[adapter_setup.last()]
+                layer_output = last_lora.com(
+                    layer_output, hidden_states, scaling=1.0
+                )  # scaling already applied in compose

-        return x
+        return layer_output


 class LoRAMergedLinear(LoRALayer, nn.Linear):
     """
-    LoRA implementation for merged attention layer layer.
+    LoRA implementation for merged attention layer, as used by some model implementations (e.g. GPT-2). This layer
+    currently does not support composition.

     Args:
         fan_in_fan_out (bool, optional):
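The rewritten forward builds a LoRAState(input_states, None, layer_output) and hands it to self.compose(), which is inherited from ComposableAdapterLayerBase and not shown in this diff; it presumably walks the active composition block, using the vslice / pad_and_concat / repeat / mean overrides above and calling compose_single once per adapter. Only afterwards is the accumulated hidden state merged into the frozen layer output, a single time, via the last adapter's com with scaling=1.0, since per-adapter scaling and gating were already applied during composition. For LoRA that final merge is additive, consistent with the com_inv shown earlier (weights - added * scaling); a trivial sketch under that assumption:

    import torch

    layer_output = torch.randn(2, 5, 16)   # frozen layer: F.linear(input_states, W, b)
    hidden_states = torch.randn(2, 5, 16)  # delta accumulated by compose()

    # what last_lora.com(layer_output, hidden_states, scaling=1.0) amounts to for a LoRA module:
    adapted_output = layer_output + 1.0 * hidden_states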
16 changes: 15 additions & 1 deletion tests_adapters/composition/test_adapter_composition.py
@@ -3,7 +3,7 @@
 import torch

 import adapters
-from adapters import PrefixTuningConfig, SeqBnConfig
+from adapters import IA3Config, LoRAConfig, PrefixTuningConfig, SeqBnConfig
 from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
 from tests.test_modeling_common import ids_tensor
 from transformers import BertConfig, BertForSequenceClassification
@@ -234,3 +234,17 @@ class PrefixTuningCompositionTest(AdapterCompositionTest):

     def get_adapter_config(self):
         return PrefixTuningConfig()
+
+
+class LoRACompositionTest(AdapterCompositionTest):
+    unsupported_blocks = [Split, Fuse, Parallel]
+
+    def get_adapter_config(self):
+        return LoRAConfig(init_weights="bert")
+
+
+class IA3CompositionTest(AdapterCompositionTest):
+    unsupported_blocks = [Split, Fuse, Parallel]
+
+    def get_adapter_config(self):
+        return IA3Config()
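Both new test classes reuse the shared AdapterCompositionTest suite and list the blocks LoRA/IA3 still cannot handle in unsupported_blocks (Split, Fuse, Parallel). The kind of setup these tests exercise looks roughly like this; names are placeholders and the composition-block signatures are recalled from the library's composition API rather than taken from this diff:

    from transformers import BertForSequenceClassification

    import adapters
    from adapters import IA3Config
    from adapters.composition import Average, BatchSplit

    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    adapters.init(model)
    model.add_adapter("a", config=IA3Config())
    model.add_adapter("b", config=IA3Config())

    # average the two adapters' contributions ...
    model.set_active_adapters(Average("a", "b"))
    # ... or send the first two rows of each batch through "a" and the next two through "b"
    model.set_active_adapters(BatchSplit("a", "b", batch_sizes=[2, 2]))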