diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 990a4607b1..c1fc25c7bf 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -33,7 +33,7 @@ match_attn_matrices_for_parallel, ) from transformers.cache_utils import Cache -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin @@ -42,7 +42,7 @@ logger = logging.get_logger(__name__) -class LlamaAttentionWithAdapters(nn.Module, LlamaAttentionMixin): +class LlamaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def forward( @@ -165,7 +165,7 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayerWithAdapters(nn.Module, LlamaDecoderLayerMixin): +class LlamaDecoderLayerWithAdapters(LlamaDecoderLayerMixin, LlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor,