diff --git a/.gitignore b/.gitignore
index 6873ea6..22b9bb3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 build/
 dist/
-*.egg-info
\ No newline at end of file
+*.egg-info
+**/__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
index 3ef46bc..cabfa02 100644
--- a/README.md
+++ b/README.md
@@ -29,57 +29,38 @@ pip install -e .
 LLM2Vec class is a wrapper on top of HuggingFace models to support sequence encoding and pooling operations. The steps below showcase an example on how to use the library.
 ### Preparing the model
-Here, we first initialize the model and apply MNTP-trained LoRA weights on top. After merging the model with MNTP weights, we can
-- either load the unsupervised-trained LoRA weights (trained with SimCSE objective and wiki corpus)
-- or we can load the model with supervised-trained LoRA weights (trained with contrastive learning and public E5 data).
+Initializing an LLM2Vec model from a pretrained LLM is straightforward. The `from_pretrained` method of LLM2Vec takes a base model identifier/path and an optional PEFT model identifier/path. All HuggingFace model loading arguments can be passed to the `from_pretrained` method (make sure the `llm2vec` package version is `>=0.1.3`).
+
+Here, we first initialize the Mistral MNTP base model and load the unsupervised-trained LoRA weights (trained with SimCSE objective and wiki corpus).
 ```python
 import torch
-from transformers import AutoTokenizer, AutoModel, AutoConfig
-from peft import PeftModel
-
+from llm2vec import LLM2Vec
-# Loading base MNTP model, along with custom code that enables bidirectional connections in decoder-only LLMs
-tokenizer = AutoTokenizer.from_pretrained(
-    "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp"
-)
-config = AutoConfig.from_pretrained(
-    "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", trust_remote_code=True
-)
-model = AutoModel.from_pretrained(
+l2v = LLM2Vec.from_pretrained(
     "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
-    trust_remote_code=True,
-    config=config,
-    torch_dtype=torch.bfloat16,
+    peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse",
     device_map="cuda" if torch.cuda.is_available() else "cpu",
+    torch_dtype=torch.bfloat16,
 )
-model = PeftModel.from_pretrained(
-    model,
-    "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
-)
-model = model.merge_and_unload() # This can take several minutes on cpu
-
-# Loading unsupervised-trained LoRA weights. This loads the trained LoRA weights on top of MNTP model. Hence the final weights are -- Base model + MNTP (LoRA) + SimCSE (LoRA).
-model = PeftModel.from_pretrained(
-    model, "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse"
-)
-
-# Or loading supervised-trained LoRA weights
-model = PeftModel.from_pretrained(
-    model, "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised"
-)
-
 ```
-### Applying `LLM2Vec` wrapper
-Then, we define our LLM2Vec encoder model as follows:
+We can also load the model with supervised-trained LoRA weights (trained with contrastive learning and public E5 data) by changing the `peft_model_name_or_path`.
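+Because any remaining keyword arguments are forwarded unchanged to the underlying HuggingFace `from_pretrained` call, standard loading options can be combined with the adapter path. The snippet below is a minimal sketch of this pass-through; the `attn_implementation` argument is shown purely as an illustration of a regular HuggingFace loading option and is not specific to LLM2Vec.
+
+```python
+import torch
+from llm2vec import LLM2Vec
+
+l2v = LLM2Vec.from_pretrained(
+    "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
+    peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse",
+    device_map="cuda" if torch.cuda.is_available() else "cpu",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="sdpa",  # forwarded as-is to the HuggingFace loader (illustrative)
+)
+```
+
+The supervised variant below is loaded the same way: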
 ```python
+import torch
 from llm2vec import LLM2Vec
-l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)
+l2v = LLM2Vec.from_pretrained(
+    "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
+    peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
+    device_map="cuda" if torch.cuda.is_available() else "cpu",
+    torch_dtype=torch.bfloat16,
+)
 ```
+By default, the LLM2Vec model uses the `mean` pooling strategy. You can change the pooling strategy by passing the `pooling_mode` argument to the `from_pretrained` method. Similarly, you can change the maximum sequence length by passing the `max_length` argument (default is 512).
+
 ### Inference
 This model now returns the text embedding for any input in the form of `[[instruction1, text1], [instruction2, text2]]` or `[text1, text2]`. While training, we provide instructions for both sentences in symmetric tasks, and only for queries in asymmetric tasks.
diff --git a/llm2vec/__init__.py b/llm2vec/__init__.py
index d637d2a..3c631c4 100644
--- a/llm2vec/__init__.py
+++ b/llm2vec/__init__.py
@@ -1 +1 @@
-from .llm2vec import LLM2Vec
\ No newline at end of file
+from .llm2vec import LLM2Vec
diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py
index 412d150..d665270 100644
--- a/llm2vec/llm2vec.py
+++ b/llm2vec/llm2vec.py
@@ -1,15 +1,31 @@
+import json
 import logging
+import os
 from functools import partial
+from typing import Dict, List, Optional, Union
+
 import numpy as np
 import torch
-from torch import nn, Tensor, device
 import torch.multiprocessing as mp
-from typing import Dict, List, Union, Optional
+from peft import PeftModel
+from torch import Tensor, device, nn
 from tqdm.autonotebook import trange
-from transformers import AutoModel, AutoTokenizer, LlamaConfig, MistralConfig
+from transformers import (
+    AutoModel,
+    AutoConfig,
+    AutoTokenizer,
+    LlamaConfig,
+    MistralConfig,
+)
+
+from .models import (
+    MistralBiModel,
+    LlamaBiModel,
+)
 logger = logging.getLogger(__name__)
+
 def batch_to_device(batch, target_device: device):
     """
     send a pytorch batch to a device (CPU/GPU)
@@ -28,7 +44,7 @@ def __init__(
         pooling_mode: str = "mean",
         max_length: int = 512,
         doc_max_length: int = 400,
-        skip_instruction: bool = True
+        skip_instruction: bool = True,
     ):
         super().__init__()
         self.model = model
@@ -38,14 +54,83 @@ def __init__(
         self.max_length = max_length
         self.doc_max_length = doc_max_length
+    @classmethod
+    def _get_model_class(cls, config_class_name):
+        if config_class_name == "MistralConfig":
+            return MistralBiModel
+        elif config_class_name == "LlamaConfig":
+            return LlamaBiModel
+        else:
+            raise ValueError(f"{config_class_name} is not supported yet.")
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        base_model_name_or_path,
+        peft_model_name_or_path=None,
+        **kwargs,
+    ):
+        # pop out encoder args
+        keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
+        encoder_args = {
+            key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None
+        }
+
+        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        config = AutoConfig.from_pretrained(base_model_name_or_path)
+        config_class_name = config.__class__.__name__
+
+        model_class = cls._get_model_class(config_class_name)
+        model = model_class.from_pretrained(base_model_name_or_path, **kwargs)
+
+        # For special case where config.json and adapter weights are in the same directory
+        if hasattr(model, "peft_config"):
+            model = 
PeftModel.from_pretrained( + model, + base_model_name_or_path, + ) + model = model.merge_and_unload() + + if peft_model_name_or_path is not None: + model = PeftModel.from_pretrained( + model, + peft_model_name_or_path, + ) + + config = {} + config_addr = ( + peft_model_name_or_path + if peft_model_name_or_path is not None + else base_model_name_or_path + ) + if os.path.exists(f"{config_addr}/llm2vec_config.json"): + with open(f"{config_addr}/llm2vec_config.json", "r") as fIn: + llm2vec_config = json.load(fIn) + config.update(llm2vec_config) + + for key, value in encoder_args.items(): + config[key] = value + + return cls(model=model, tokenizer=tokenizer, **config) + def prepare_for_tokenization(self, text): def _is_instruct(name): - return ("chat" in name.lower()) or ("instruct" in name.lower()) or ("sharegpt" in name.lower()) - + return ( + ("chat" in name.lower()) + or ("instruct" in name.lower()) + or ("sharegpt" in name.lower()) + ) + if _is_instruct(self.model.config._name_or_path): - text = '[INST] ' + text.strip() + ' [/INST]' - if (isinstance(self.model.config, LlamaConfig) or isinstance(self.model.config, MistralConfig)) and self.pooling_mode == "eos_token": - text = text.strip() + ' ' + text = "[INST] " + text.strip() + " [/INST]" + if ( + isinstance(self.model.config, LlamaConfig) + or isinstance(self.model.config, MistralConfig) + ) and self.pooling_mode == "eos_token": + text = text.strip() + " " return text def tokenize(self, texts): @@ -55,27 +140,47 @@ def tokenize(self, texts): t = text.split("!@#$%^&*()") texts_2.append(t[1]) original_texts.append("".join(t)) - - original = self.tokenizer(original_texts, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length) + + original = self.tokenizer( + original_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + ) embed_mask = None for t_i, t in enumerate(texts_2): - ids = self.tokenizer([t], return_tensors='pt', padding=True, truncation=True, max_length=self.max_length, add_special_tokens=False) + ids = self.tokenizer( + [t], + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) if embed_mask is None: e_m = torch.zeros_like(original["attention_mask"][t_i]) if len(ids["input_ids"][0]) > 0: - e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0])) + e_m[-len(ids["input_ids"][0]) :] = torch.ones( + len(ids["input_ids"][0]) + ) embed_mask = e_m.unsqueeze(0) else: e_m = torch.zeros_like(original["attention_mask"][t_i]) if len(ids["input_ids"][0]) > 0: - e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0])) + e_m[-len(ids["input_ids"][0]) :] = torch.ones( + len(ids["input_ids"][0]) + ) embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) original["embed_mask"] = embed_mask return original def _skip_instruction(self, sentence_feature): - assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape + assert ( + sentence_feature["attention_mask"].shape + == sentence_feature["embed_mask"].shape + ) sentence_feature["attention_mask"] = sentence_feature["embed_mask"] def forward(self, sentence_feature: Dict[str, Tensor]): @@ -84,63 +189,93 @@ def forward(self, sentence_feature: Dict[str, Tensor]): embed_mask = sentence_feature.pop("embed_mask") reps = self.model(**sentence_feature) sentence_feature["embed_mask"] = embed_mask - + return self.get_pooling(sentence_feature, reps.last_hidden_state) - + def get_pooling(self, features, 
last_hidden_states): # All models padded from left - assert self.tokenizer.padding_side == 'left', "Pooling modes are implemented for padding from left." + assert ( + self.tokenizer.padding_side == "left" + ), "Pooling modes are implemented for padding from left." if self.skip_instruction: self._skip_instruction(features) seq_lengths = features["attention_mask"].sum(dim=-1) if self.pooling_mode == "mean": - return torch.stack([last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)], dim=0) + return torch.stack( + [ + last_hidden_states[i, -length:, :].mean(dim=0) + for i, length in enumerate(seq_lengths) + ], + dim=0, + ) elif self.pooling_mode == "weighted_mean": bs, l, _ = last_hidden_states.shape complete_weights = torch.zeros(bs, l, device=last_hidden_states.device) for i, seq_l in enumerate(seq_lengths): if seq_l > 0: complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 - complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9) + complete_weights[i] /= torch.clamp( + complete_weights[i].sum(), min=1e-9 + ) return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": return last_hidden_states[:, -1] elif self.pooling_mode == "bos_token": - return last_hidden_states[features["input_ids"]==self.tokenizer.bos_token_id] + return last_hidden_states[ + features["input_ids"] == self.tokenizer.bos_token_id + ] else: raise ValueError(f"{self.pooling_mode} is not implemented yet.") - def _convert_to_str(self, instruction, text): - tokenized_q = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length, add_special_tokens=False) + def _convert_to_str(self, instruction, text): + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) tokenized_q_length = len(tokenized_q["input_ids"][0]) while tokenized_q_length > self.doc_max_length: reduction_ratio = self.doc_max_length / tokenized_q_length reduced_length = int(len(text.split()) * reduction_ratio) text = " ".join(text.split()[:reduced_length]) - tokenized_q = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length, add_special_tokens=False) + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) tokenized_q_length = len(tokenized_q["input_ids"][0]) - + return f"{instruction.strip()} !@#$%^&*(){text}" - def encode(self, sentences: Union[str, List[str]], - batch_size: int = 32, - show_progress_bar: bool = True, - convert_to_numpy: bool = False, - convert_to_tensor: bool = False, + def encode( + self, + sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = True, + convert_to_numpy: bool = False, + convert_to_tensor: bool = False, ): - if isinstance(sentences[0],str) and isinstance(sentences[-1],int): + if isinstance(sentences[0], str) and isinstance(sentences[-1], int): sentences = [sentences] # required for MEDI version of MTEB - if isinstance(sentences[0],str): + if isinstance(sentences[0], str): sentences = [[""] + [sentence] for sentence in sentences] - + concatenated_input_texts = [] for sentence in sentences: assert isinstance(sentence[0], str) assert isinstance(sentence[1], str) - concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1])) + 
concatenated_input_texts.append( + self._convert_to_str(sentence[0], sentence[1]) + ) sentences = concatenated_input_texts - + self.eval() show_progress_bar = True @@ -154,24 +289,39 @@ def encode(self, sentences: Union[str, List[str]], if torch.cuda.device_count() <= 1: device = "cuda" if torch.cuda.is_available() else "cpu" self.to(device) - for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): - sentences_batch = sentences_sorted[start_index:start_index + batch_size] - embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy) + for start_index in trange( + 0, + len(sentences), + batch_size, + desc="Batches", + disable=not show_progress_bar, + ): + sentences_batch = sentences_sorted[ + start_index : start_index + batch_size + ] + embeddings = self._encode( + sentences_batch, device=device, convert_to_numpy=convert_to_numpy + ) all_embeddings.append(embeddings) else: num_proc = torch.cuda.device_count() cuda_compatible_multiprocess = mp.get_context("spawn") with cuda_compatible_multiprocess.Pool(num_proc) as p: - sentences_batches = [sentences_sorted[start_index:start_index + batch_size] - for start_index in trange(0, len(sentences), batch_size)] - for result in p.map(partial(self._encode, - device=None, - convert_to_numpy=convert_to_numpy, - multiprocessing=True), - sentences_batches): + sentences_batches = [ + sentences_sorted[start_index : start_index + batch_size] + for start_index in trange(0, len(sentences), batch_size) + ] + for result in p.map( + partial( + self._encode, + device=None, + convert_to_numpy=convert_to_numpy, + multiprocessing=True, + ), + sentences_batches, + ): all_embeddings.append(result) - all_embeddings = torch.cat(all_embeddings, dim=0) all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] all_embeddings = all_embeddings.to(torch.float32) @@ -179,19 +329,38 @@ def encode(self, sentences: Union[str, List[str]], all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) return all_embeddings - def save(self, output_path): + def save(self, output_path, merge_before_save=False, save_config=True): + if merge_before_save and isinstance(self.model, PeftModel): + self.model = self.model.merge_and_unload() + # Fixes the issue of saving - https://huggingface.co/McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse/discussions/1 + if hasattr(self.model, "_hf_peft_config_loaded"): + self.model._hf_peft_config_loaded = False + self.model.save_pretrained(output_path) self.tokenizer.save_pretrained(output_path) - def _encode(self, sentences_batch, device, convert_to_numpy, multiprocessing=False): + llm2vec_config = { + "pooling_mode": self.pooling_mode, + "max_length": self.max_length, + "doc_max_length": self.doc_max_length, + "skip_instruction": self.skip_instruction, + } + + if save_config: + os.makedirs(output_path, exist_ok=True) + with open(f"{output_path}/llm2vec_config.json", "w") as fOut: + json.dump(llm2vec_config, fOut, indent=4) + def _encode(self, sentences_batch, device, convert_to_numpy, multiprocessing=False): if multiprocessing: rank = mp.current_process()._identity[0] if device is None and torch.cuda.is_available(): device = f"cuda:{rank % torch.cuda.device_count()}" self.to(device) - features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch]) + features = self.tokenize( + [self.prepare_for_tokenization(sentence) for sentence in sentences_batch] + ) features = batch_to_device(features, device) with 
torch.no_grad(): @@ -207,16 +376,24 @@ def _text_length(self, text: Union[List[int], List[List[int]]]): a list of ints (which means a single tokenized text), or a tuple of list of ints (representing several text inputs to the model). """ - if isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0: #Single text, list of ints, or empty + if ( + isinstance(text, str) + or (isinstance(text, list) and isinstance(text[0], int)) + or len(text) == 0 + ): # Single text, list of ints, or empty return len(text) - if isinstance(text, dict): #{key: value} case + if isinstance(text, dict): # {key: value} case return len(next(iter(text.values()))) - elif not hasattr(text, '__len__'): #Object has no len() method + elif not hasattr(text, "__len__"): # Object has no len() method return 1 else: return sum([len(t) for t in text]) - + def resize_token_embeddings( - self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None - ) -> nn.Embedding: - return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of) + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + return self.model.resize_token_embeddings( + new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of + ) diff --git a/llm2vec/models/__init__.py b/llm2vec/models/__init__.py new file mode 100644 index 0000000..dbbe3f8 --- /dev/null +++ b/llm2vec/models/__init__.py @@ -0,0 +1,2 @@ +from .bidirectional_mistral import MistralBiModel, MistralBiForMNTP +from .bidirectional_llama import LlamaBiModel, LlamaBiForMNTP diff --git a/llm2vec/models/attn_mask_utils.py b/llm2vec/models/attn_mask_utils.py new file mode 100644 index 0000000..3c74570 --- /dev/null +++ b/llm2vec/models/attn_mask_utils.py @@ -0,0 +1,160 @@ +from typing import List, Optional, Tuple, Union +import torch +from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. 
+ """ + attn_mask_converter = AttentionMaskConverter( + is_causal=False, sliding_window=sliding_window + ) # is_causal=True in original implementation + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + key_value_length=key_value_length, + dtype=inputs_embeds.dtype, + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], + input_shape[-1], + key_value_length, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. + + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter( + is_causal=False, sliding_window=sliding_window + ) # is_causal=True in original implementation + + key_value_length = input_shape[-1] + past_key_values_length + batch_size, query_length = input_shape + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + if attention_mask is not None: + # 4d mask is passed through + if len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." 
+ ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + return attention_mask + + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1: + # For query_length == 1, causal attention and bi-directional attention are the same. + attention_mask = None + elif key_value_length == query_length: + attention_mask = None + else: + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + pass + elif query_length > 1 and key_value_length != query_length: + # See the comment above (https://github.com/pytorch/pytorch/issues/108108). + # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`. + attention_mask = True + elif is_tracing: + raise ValueError( + 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' + ) + + if attention_mask is None: + expanded_4d_mask = None + elif attention_mask is True: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], + input_shape[-1], + key_value_length, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + if not is_tracing and expanded_4d_mask.device.type == "cuda": + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min + ) + + return expanded_4d_mask diff --git a/llm2vec/models/bidirectional_llama.py b/llm2vec/models/bidirectional_llama.py new file mode 100644 index 0000000..028bbcf --- /dev/null +++ b/llm2vec/models/bidirectional_llama.py @@ -0,0 +1,190 @@ +import torch + +from packaging import version +import importlib.metadata + +from transformers import LlamaModel, LlamaForCausalLM, LlamaPreTrainedModel, LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaAttention, + LlamaFlashAttention2, + LlamaSdpaAttention, + LlamaMLP, + LlamaRMSNorm, +) + +from torch import nn +from transformers.utils import logging + +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_38(): + if not _is_package_available("transformers"): + return False + + return version.parse(importlib.metadata.version("transformers")) >= version.parse( + "4.38.0" + ) + + +class ModifiedLlamaAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedLlamaFlashAttention2(LlamaFlashAttention2): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedLlamaSdpaAttention(LlamaSdpaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +LLAMA_ATTENTION_CLASSES = { + "eager": ModifiedLlamaAttention, + "flash_attention_2": ModifiedLlamaFlashAttention2, + "sdpa": ModifiedLlamaSdpaAttention, +} + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + +class LlamaBiModel(LlamaModel): + def __init__(self, config: LlamaConfig): + if not is_transformers_attn_greater_or_equal_4_38(): + raise ValueError( + "The current implementation of LlamaEncoderModel follows modeling_llama.py of transformers version >= 4.38.0" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + ModifiedLlamaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, 
input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if hasattr( + getattr(self.layers[0], "self_attn", {}), "past_key_value" + ): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[-1] + 1 + ) + + causal_mask = torch.zeros( + (sequence_length, target_length), dtype=dtype, device=device + ) # in original implementation - torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + # Commenting out next 2 lines to disable causal masking + # if sequence_length != 1: + # causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand( + input_tensor.shape[0], 1, -1, -1 + ) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[ + :, None, None, : + ].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[ + ..., :mask_length + ].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], + : mask_shape[1], + offset : mask_shape[2] + offset, + : mask_shape[3], + ] = mask_slice + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = ( + torch.jit.is_tracing() + or isinstance(input_tensor, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + if not is_tracing and torch.any(attention_mask != 1): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + +class LlamaBiForMNTP(LlamaForCausalLM): + def __init__(self, config): + LlamaPreTrainedModel.__init__(self, config) + self.model = LlamaBiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() diff --git a/llm2vec/models/bidirectional_mistral.py b/llm2vec/models/bidirectional_mistral.py new file mode 100644 index 0000000..6bed116 --- /dev/null +++ b/llm2vec/models/bidirectional_mistral.py @@ -0,0 +1,281 @@ +from typing import List, Optional, Tuple, Union +import torch + +from transformers import ( + MistralModel, + MistralPreTrainedModel, + MistralForCausalLM, + MistralConfig, +) +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.cache_utils import Cache, DynamicCache +from transformers.models.mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralRMSNorm, + MistralAttention, + MistralFlashAttention2, + MistralSdpaAttention, + MistralMLP, +) +from torch import nn +from transformers.utils import logging +from .attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) + +logger = logging.get_logger(__name__) + + +class ModifiedMistralAttention(MistralAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedMistralFlashAttention2(MistralFlashAttention2): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedMistralSdpaAttention(MistralSdpaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +MISTRAL_ATTENTION_CLASSES = { + "eager": ModifiedMistralAttention, + "flash_attention_2": ModifiedMistralFlashAttention2, + "sdpa": ModifiedMistralSdpaAttention, +} + + +class ModifiedMistralDecoderLayer(MistralDecoderLayer): + def __init__(self, config: MistralConfig, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + +class MistralBiModel(MistralModel): + def __init__(self, config: MistralConfig): + MistralPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + ModifiedMistralDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._attn_implementation = config._attn_implementation + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + # Copied from forward() in transformers.models.mistral.modeling_mistral.MistralModel + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] 
= None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if ( + attention_mask is not None + and self._attn_implementation == "flash_attention_2" + and use_cache + ): + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + elif self._attn_implementation == "sdpa" and not output_attentions: + # The original implementation is by-passed, see attn_mask_utils.py + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MistralBiForMNTP(MistralForCausalLM): + def __init__(self, config): + MistralPreTrainedModel.__init__(self, config) + self.model = MistralBiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() diff --git a/llm2vec/version.py b/llm2vec/version.py index a68927d..3dc1f76 100644 --- a/llm2vec/version.py +++ b/llm2vec/version.py @@ -1 +1 @@ -__version__ = "0.1.0" \ No newline at end of file +__version__ = "0.1.0" diff --git a/setup.py b/setup.py index 6296f2d..fe2babc 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "tqdm", "torch", "peft", - "transformers>=4.38.1" + "transformers>=4.39.1" ], classifiers=[ "Programming Language :: Python :: 3",