diff --git a/llm2vec/models/bidirectional_llama.py b/llm2vec/models/bidirectional_llama.py
index c6eea5d..927855c 100644
--- a/llm2vec/models/bidirectional_llama.py
+++ b/llm2vec/models/bidirectional_llama.py
@@ -107,7 +107,7 @@ def __init__(self, config: LlamaConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None, output_attentions=False):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -179,6 +179,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             causal_mask = AttentionMaskConverter._unmask_unattended(
                 causal_mask, min_dtype