Commit a96a49f: Llama adapter modules base classes

calpt committed Jan 5, 2024
1 parent 767e868

Showing 1 changed file with 3 additions and 3 deletions.

6 changes: 3 additions & 3 deletions src/adapters/models/llama/modeling_llama.py
@@ -33,7 +33,7 @@
     match_attn_matrices_for_parallel,
 )
 from transformers.cache_utils import Cache
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv
 from transformers.utils import logging
 
 from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin
@@ -42,7 +42,7 @@
 logger = logging.get_logger(__name__)
 
 
-class LlamaAttentionWithAdapters(nn.Module, LlamaAttentionMixin):
+class LlamaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def forward(
@@ -165,7 +165,7 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
-class LlamaDecoderLayerWithAdapters(nn.Module, LlamaDecoderLayerMixin):
+class LlamaDecoderLayerWithAdapters(LlamaDecoderLayerMixin, LlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
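
The commit swaps the placeholder nn.Module base for the original Hugging Face classes while keeping the adapter mixins first in the bases list, so mixin overrides take precedence in Python's method resolution order and LlamaAttention / LlamaDecoderLayer still supply the actual attention and decoder-layer machinery. Below is a minimal sketch of that pattern; LlamaAttentionMixinSketch and its forward wrapper are hypothetical stand-ins, not the adapters library's real LlamaAttentionMixin.

# Minimal sketch of the mixin-first inheritance pattern introduced by this
# commit; the mixin body is a hypothetical stand-in for illustration only.
from transformers.models.llama.modeling_llama import LlamaAttention


class LlamaAttentionMixinSketch:
    """Hypothetical stand-in for adapters' LlamaAttentionMixin."""

    def forward(self, *args, **kwargs):
        # Because the mixin comes first in the MRO, this wrapper runs before
        # the Hugging Face implementation and could add adapter logic around it.
        return super().forward(*args, **kwargs)


class LlamaAttentionWithAdapters(LlamaAttentionMixinSketch, LlamaAttention):
    """Mixin first, HF base class second: mixin overrides win in the MRO,
    while LlamaAttention still provides the __init__, weights, and config
    handling that the previous bare nn.Module base did not."""


if __name__ == "__main__":
    print([cls.__name__ for cls in LlamaAttentionWithAdapters.__mro__])
    # -> ['LlamaAttentionWithAdapters', 'LlamaAttentionMixinSketch',
    #     'LlamaAttention', 'Module', 'object']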
