From 4aad5ffd3aca5fc4bbe1bebac974fec12bf1dbb9 Mon Sep 17 00:00:00 2001
From: vaibhavad
Date: Wed, 22 May 2024 19:21:29 +0000
Subject: [PATCH] support for transformers 4.41.0

---
 llm2vec/models/bidirectional_llama.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llm2vec/models/bidirectional_llama.py b/llm2vec/models/bidirectional_llama.py
index c6eea5d..927855c 100644
--- a/llm2vec/models/bidirectional_llama.py
+++ b/llm2vec/models/bidirectional_llama.py
@@ -107,7 +107,7 @@ def __init__(self, config: LlamaConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None, output_attentions=False):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -179,6 +179,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             causal_mask = AttentionMaskConverter._unmask_unattended(
                 causal_mask, min_dtype
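
Note (not part of the patch): a minimal sketch of how the new output_attentions
keyword is expected to be threaded through from the model's forward pass. The
variable names other than _update_causal_mask are illustrative assumptions, not
copied from llm2vec; the guard mirrors the upstream transformers 4.41.0 change,
where the SDPA _unmask_unattended workaround is skipped when attention weights
are requested (SDPA falls back to eager attention in that case and needs the
unmodified expanded mask).

    # Hypothetical call site, assuming the patched signature above.
    # _unmask_unattended un-masks fully masked rows so the SDPA kernel does not
    # produce NaNs; with output_attentions=True that workaround is bypassed.
    causal_mask = self._update_causal_mask(
        attention_mask,                       # 2D padding mask or None
        inputs_embeds,                        # provides dtype/device for the 4D mask
        cache_position,                       # token positions relative to the KV cache
        past_seen_tokens=past_seen_tokens,
        output_attentions=output_attentions,  # new keyword expected by transformers 4.41.0
    )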