diff --git a/pecos/xmr/reranker/README.md b/pecos/xmr/reranker/README.md
new file mode 100644
index 0000000..f2a11d6
--- /dev/null
+++ b/pecos/xmr/reranker/README.md
@@ -0,0 +1,204 @@
+# PECOS XMR Reranker
+
+This is a reranker for the PECOS XMR model, built on Hugging Face's `transformers` library. It can be run in both
+single-process and distributed mode, and follows the approach of [Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319).
+
+## How to run
+### Single process
+To run the reranker in single-process mode, use the following command:
+
+```bash
+python -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
+```
+
+### Distributed mode
+To run the reranker in distributed mode, first initialize the distributed configuration:
+```bash
+accelerate config
+```
+
+Then launch the reranker with:
+```bash
+accelerate launch -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
+```
+
+### Predictions
+To run the reranker in prediction mode, use the following command:
+```bash
+python -m pecos.xmr.reranker.predict --config_json_path <path_to_config_file>
+```
+
+## Configuration file
+
+### Training
+Here is an example of the configuration file for training:
+```json
+{
+    "train_params": {
+        "__meta__": {
+            "class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams"
+        },
+        "target_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/target",
+        "input_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/input",
+        "label_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/label",
+        "training_args": {
+            "__meta__": {
+                "class_fullname": "pecos.xmr.reranker.trainer###RankLlamaTrainer.TrainingArgs"
+            },
+            "learning_rate": 1e-4,
+            "output_dir": "./ds_model",
+            "per_device_train_batch_size": 8,
+            "gradient_accumulation_steps": 8,
+            "max_steps": -1,
+            "logging_strategy": "steps",
+            "logging_first_step": false,
+            "logging_steps": 10,
+            "save_strategy": "steps",
+            "save_steps": 50,
+            "save_total_limit": 5,
+            "seed": 42,
+            "data_seed": 42,
+            "bf16": true,
+            "dataloader_num_workers": 2,
+            "dataloader_prefetch_factor": 10,
+            "gradient_checkpointing": true,
+            "train_group_size": 16
+        }
+    },
+    "model_params": {
+        "__meta__": {
+            "class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams"
+        },
+        "encoder_args": {
+            "__meta__": {
+                "class_fullname": "pecos.xmr.reranker.model###CrossEncoder.Config"
+            },
+            "model_shortcut": "meta-llama/Llama-2-7b-hf",
+            "model_init_kwargs": {},
+            "model_modifier": {
+                "modifier_type": "peft",
+                "config_type": "LoraConfig",
+                "config": {
+                    "r": 8,
+                    "lora_alpha": 64,
+                    "target_modules": ["q_proj", "v_proj"],
+                    "modules_to_save": ["score", "classifier"],
+                    "lora_dropout": 0.1
+                }
+            }
+        },
+        "positive_passage_no_shuffle": false,
+        "negative_passage_no_shuffle": false,
+        "rerank_max_len": 196,
+        "query_prefix": "query: ",
+        "passage_prefix": "document: ",
+        "inp_id_col": "inp_id",
+        "lbl_idxs_col": "ret_idxs",
+        "score_col": "rel",
+        "keyword_col_name": "keywords",
+        "content_col_names": ["title", "contents"],
+        "append_eos_token": false,
+        "pad_to_multiple_of": 16
+    }
+}
+```
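+
+The `__meta__.class_fullname` entries tell PECOS which parameter dataclass each block maps to.
+As a minimal sketch (the file name is a placeholder), the train entry point parses this file roughly as follows:
+
+```python
+import json
+
+from pecos.xmr.reranker.model import RankingModel
+
+with open("config.json", "r") as fin:
+    param = json.load(fin)
+
+# Recursively instantiate the nested parameter dataclasses declared via "__meta__"
+model_params = RankingModel.ModelParams.from_dict(param["model_params"], recursive=True)
+train_params = RankingModel.TrainParams.from_dict(param["train_params"], recursive=True)
+```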
"/home/ec2-user/docker_disk/datasets/msmarcoeval/label", + "output_dir": "/tmp/xmrout", + "per_device_eval_batch_size": 512, + "dataloader_num_workers": 1, + "dataloader_prefetch_factor": 10, + "rerank_max_len": 196, + "query_prefix": "query: ", + "passage_prefix": "document: ", + "inp_id_col": "inp_id", + "lbl_id_col": "lbl_id", + "keyword_col_name": "keywords", + "content_col_names": ["title", "contents"], + "append_eos_token": false, + "pad_to_multiple_of": 8, + "device": "cuda", + "model_init_kwargs": { + "device_map": "auto" + } +} +``` + +## Data Schema +The column names for the data schema are configurable through the json configuration file. Following +are the various schemas that are supported by the reranker: + +(1) Learning Target Schema +``` +# +-----------------+---------------+-----------------------+ +# | Column Name | Data Type | Description | +# +-----------------+---------------+-----------------------+ +# | inp_id | int32 | input id | +# | lbl_id | array | an array of label_id | +# | score | array | an array of rel_score | +# +-----------------+---------------+-----------------------+ +``` + +(2) Input Feature Store Schema +``` +# +-----------------+---------------+-----------------------+ +# | Column Name | Data Type | Description | +# +-----------------+---------------+-----------------------+ +# | inp_id | int32 | input id | +# | keywords | string | keyword string | +# +-----------------+---------------+-----------------------+ +``` + +(3) Label Feature Store Schema + +The label feature store supports variable number of columns. The column names +can be provided in the configuration file. +``` +# +-----------------+---------------+-----------------------+ +# | Column Name | Data Type | Description | +# +-----------------+---------------+-----------------------+ +# | lbl_id | int32 | input id | +# | title | string | title text | +# | content | string | content string | +# | ... | string | content string | +# +-----------------+---------------+-----------------------+ +``` + +(4) Evaluation Schema +``` +# +-----------------+---------------+-----------------------+ +# | Column Name | Data Type | Description | +# +-----------------+---------------+-----------------------+ +# | inp_id | int32 | input id | +# | lbl_id | int32 | label_id | +# +-----------------+---------------+-----------------------+ +``` + +(5) Evaluation Input Feature Store Schema +``` +# +-----------------+---------------+-----------------------+ +# | Column Name | Data Type | Description | +# +-----------------+---------------+-----------------------+ +# | inp_id | int32 | input id | +# | keywords | string | keyword string | +# +-----------------+---------------+-----------------------+ +``` + +(6) Evaluation Label Feature Store Schema +``` +# +-----------------+---------------+-----------------------+ +# | Column Name | Data Type | Description | +# +-----------------+---------------+-----------------------+ +# | lbl_id | int32 | input id | +# | title | string | title text | +# | content | string | content string | +# | ... 
diff --git a/pecos/xmr/reranker/data_utils.py b/pecos/xmr/reranker/data_utils.py
new file mode 100644
index 0000000..2ec652a
--- /dev/null
+++ b/pecos/xmr/reranker/data_utils.py
@@ -0,0 +1,164 @@
+import os
+import random
+from collections import OrderedDict
+from typing import List, Tuple, Callable
+
+import numpy as np
+import pyarrow.parquet as pq
+from datasets import load_dataset
+
+import pecos
+
+
+class RankingDataUtils(pecos.BaseClass):
+    """
+    Utility class for data-related tasks
+    """
+
+    @classmethod
+    def remap_ordereddict(cls, od: OrderedDict, keymap_fn: Callable):
+        """
+        Remap the keys of an ordered dictionary
+        Args:
+            od: The ordered dictionary to remap
+            keymap_fn: The function to map the keys
+        """
+        new_od = OrderedDict()
+        for k, v in od.items():
+            new_od[keymap_fn(k)] = v
+        return new_od
+
+    @classmethod
+    def _format_sample(
+        cls,
+        inp_text: str,
+        lbl_contents: List[str],
+        inp_prefix: str = "...",
+        passage_prefix: str = "...",
+        content_sep=" ",
+    ) -> str:
+        """
+        Convert the text fields into a formatted string that the model understands.
+        Args:
+            inp_text: The input text
+            lbl_contents: The list of content fields
+            inp_prefix: The input prefix
+            passage_prefix: The passage prefix
+            content_sep: The separator between the content fields
+        Returns: The formatted string
+        """
+        # Convention from rankllama is to replace hyphens in the title
+        lbl_contents[0] = lbl_contents[0].replace("-", " ").strip()
+        return f"{inp_prefix} {inp_text} {passage_prefix} {content_sep.join(lbl_contents)}".strip()
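+
+    # Illustrative example of the formatting convention above (hypothetical values):
+    #   _format_sample("what is pecos", ["pecos-library", "PECOS is a library ..."],
+    #                  "query:", "document:")
+    #   returns: "query: what is pecos document: pecos library PECOS is a library ..."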
+
+    @classmethod
+    def _create_sample(
+        cls,
+        inp_id: int,
+        ret_idxs: List[int],
+        scores: List[float],
+        table_stores,
+        train_group_size: int,
+        inp_prefix: str,
+        passage_prefix: str,
+        keyword_col_name: str,
+        content_col_names: List[str],
+        content_sep,
+    ) -> Tuple[List[str], List[float]]:
+        """
+        Create a training sample for a query.
+        Args:
+            inp_id: The input id
+            ret_idxs: The retrieved indices
+            scores: Scores for the retrieved indices
+            table_stores: Dictionary of table stores for input and label data
+            train_group_size: The number of passages used to train for each query
+            inp_prefix: The input prefix
+            passage_prefix: The passage prefix
+            keyword_col_name: The column name for the query text
+            content_col_names: The column names for the content fields
+            content_sep: The separator between the content fields
+        Returns: A tuple of formatted samples and scores
+        """
+        qid = inp_id
+        pidxs = ret_idxs
+
+        input_store = table_stores["input"]
+        label_store = table_stores["label"]
+
+        # get the query text
+        query = input_store[qid][keyword_col_name]
+        mean_score = np.mean(scores)
+
+        # split retrieved items into positives and negatives around the mean score
+        pos_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x > mean_score]
+        neg_idxs = [(x, pid) for x, pid in zip(scores, pidxs) if x <= mean_score]
+        random.shuffle(pos_idxs)
+        random.shuffle(neg_idxs)
+
+        num_positives = train_group_size // 2
+
+        all_selections = pos_idxs[:num_positives]
+        num_positives = len(all_selections)
+        num_negatives = train_group_size - num_positives
+        all_selections.extend(neg_idxs[:num_negatives])
+
+        if len(all_selections) < train_group_size:
+            all_selections.extend(
+                random.choices(neg_idxs, k=train_group_size - len(all_selections))
+            )
+
+        all_scores = [s for s, _ in all_selections]
+        all_pids = [pid for _, pid in all_selections]
+
+        # fetch the content fields for the retrieved items
+        ret_info = [label_store[i] for i in all_pids]
+
+        formatted_pairs = []
+        for info in ret_info:
+            formatted_pairs.append(
+                cls._format_sample(
+                    query,
+                    [info[c] for c in content_col_names],
+                    inp_prefix,
+                    passage_prefix,
+                    content_sep,
+                )
+            )
+        return formatted_pairs, all_scores
+
+    @classmethod
+    def get_parquet_rows(cls, folder_path: str) -> int:
+        """
+        Return the total row count of the parquet files in a folder by reading
+        the file metadata
+        Args:
+            folder_path: The folder containing the parquet files
+        Returns: The count of rows in the parquet files
+        """
+        file_list = os.listdir(folder_path)
+        file_list = [os.path.join(folder_path, x) for x in file_list]
+        cumulative_rowcount = sum([pq.read_metadata(fp).num_rows for fp in file_list])
+
+        return cumulative_rowcount
+
+    @classmethod
+    def get_sorted_data_files(cls, filenames: List[str], idx_colname) -> List[str]:
+        """
+        Return the list of files sorted by the id in the first row of each file
+        Args:
+            filenames: The list of filenames
+            idx_colname: The column name of the id
+        Returns: The sorted list of filenames
+        """
+        # Load the datasets in streaming format and read the first id
+        fn_ordered = []  # this contains tuples of (idx, filename)
+        for fn in filenames:
+            tmp_ds = load_dataset("parquet", data_files=fn, streaming=True, split="train")
+            row = next(iter(tmp_ds.take(1)))
+            fn_ordered.append((row[idx_colname], fn))
+            del tmp_ds
+        fn_ordered = sorted(fn_ordered, key=lambda x: x[0])
+
+        return [x[1] for x in fn_ordered]
diff --git a/pecos/xmr/reranker/model.py b/pecos/xmr/reranker/model.py
new file mode 100644
index 0000000..6e9212d
--- /dev/null
+++ b/pecos/xmr/reranker/model.py
@@ -0,0 +1,627 @@
+import dataclasses as dc
+import json
+import logging
+import os
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Dict, List, Tuple, Any, Optional, Union
+
+import peft
+import torch
+from datasets import IterableDataset, Dataset
+from torch.utils.data import DataLoader
+from peft import 
AutoPeftModelForSequenceClassification, get_peft_model +from peft.config import PeftConfig +from peft.mixed_model import PeftMixedModel +from peft.peft_model import PeftModel +from transformers import AutoModelForSequenceClassification, PreTrainedModel +from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig + +import pecos +from pecos.xmr.reranker.trainer import RankLlamaTrainer, PARAM_FILENAME +from .data_utils import RankingDataUtils + +logger = logging.getLogger(__name__) + + +class CrossEncoderConfig(PretrainedConfig): + """ + The configuration class for the cross encoder model. This class contains the model shortcut, model modifier and + model initialization arguments for the model. The model shortcut is the name of the huggingface model. The + `model_modifier` is the configuration of the modifier (e.g. PEFT) and the `model_init_kwargs` are the arguments + for the model. + """ + + model_type = "reranker_crossencoder" + + def __init__( + self, + model_shortcut: str = "", + model_modifier: Dict = {}, + model_init_kwargs: dict = {}, + **kwargs, + ): + """ + Initialize the cross encoder configuration + Args: + model_shortcut: The model shortcut for the huggingface model + model_modifier: The model modifier configuration (e.g. PEFT) + model_init_kwargs: The model initialization arguments. These are the arguments for the huggingface model + """ + super().__init__(**kwargs) + + self.model_shortcut = model_shortcut + self.model_modifier = model_modifier + self.model_init_kwargs = model_init_kwargs + + +class CrossEncoder(PreTrainedModel): + """ + The cross encoder model for ranking tasks (retrieval-based). This model is used for training and evaluation. + It is a wrapper around the huggingface transformer model. + """ + + TRANSFORMER_CLS = AutoModelForSequenceClassification + TRANSFORMER_PEFT_CLS = AutoPeftModelForSequenceClassification + + @dataclass + class Config(pecos.BaseParams): + """Encoder configuration + model_shortcut (str): the model shortcut of the HuggingFace model + model_init_kwargs (dict): model initialization kwargs + model_modifier (dict): model modifier configuration + """ + + model_shortcut: str = "" + model_init_kwargs: dict = dc.field(default_factory=lambda: dict()) + model_modifier: dict = dc.field(default_factory=lambda: dict()) + + config_class = CrossEncoderConfig + + def __init__(self, config: CrossEncoderConfig): + """ + Initialize the cross encoder model + Args: + config: The configuration for the cross encoder + """ + super().__init__(config) + base_model = AutoModelForSequenceClassification.from_pretrained( + config.model_shortcut, num_labels=1, **config.model_init_kwargs + ) + base_model.config.pad_token_id = ( + 0 if base_model.config.pad_token_id is None else base_model.config.pad_token_id + ) + self.hf_model = base_model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: Optional[bool] = None, + **kwargs, + ): + """ + Load the model from the pretrained model name or path. Override the `from_pretrained` method of the + `PreTrainedModel` class. 
+ """ + is_local = os.path.isdir(pretrained_model_name_or_path) + param_folder = pretrained_model_name_or_path + + def super_return(): + return PreTrainedModel.from_pretrained( + pretrained_model_name_or_path, + *model_args, + config, + cache_dir, + ignore_mismatched_sizes, + force_download, + local_files_only, + token, + revision, + use_safetensors, + **kwargs, + ) + + if not is_local: + raise NotImplementedError(f"{cls} can only load local models") + + with open(os.path.join(param_folder, PARAM_FILENAME), "r") as param_file: + params = json.load(param_file) + + xe_config = CrossEncoder.Config.from_dict(params["model_params"]["encoder_args"]) + xe_config = CrossEncoderConfig(**xe_config.to_dict()) + for k, v in kwargs.items(): + xe_config.model_init_kwargs[k] = v + model = CrossEncoder(xe_config) + + try: + if xe_config.model_modifier["modifier_type"] == "peft": + model = PeftModel.from_pretrained(model, param_folder) + else: + super_return() + except KeyError: + logger.info("No peft configuration found") + + return model + + def forward(self, *args, **kwargs): + """ + Returns the forward output of the huggingface model + """ + return self.hf_model(*args, **kwargs) + + def gradient_checkpointing_enable(self, **kwargs): + """ + Enable gradient checkpointing for the model + """ + + try: + if self.config.model_modifier["modifier_type"] == "peft": + self.hf_model.enable_input_require_grads() + except KeyError: + pass + self.hf_model.gradient_checkpointing_enable(**kwargs) + + +class RankingModel(pecos.BaseClass): + """ + The ranking model class for training and evaluation of the cross encoder model. This class is used for training + and evaluation of the cross encoder model. It is a wrapper around the cross encoder model. It also contains the + parameters for the model. The model can be used for training and evaluation. + """ + + @dataclass + class TrainParams(pecos.BaseParams): + """ + The training parameters for the ranking model. + Args: + training_args (RankLlamaTrainer.TrainingArgs): The training arguments for the model + target_data_folder (str): The path to the target data folder + input_data_folder (str): The path to the input data folder + label_data_folder (str): The path to the label data folder + """ + + training_args: RankLlamaTrainer.TrainingArgs + target_data_folder: str = field( + metadata={ + "help": "Path to folder containing target parquet files (inp_id, [lbl_id], [rel_val])" + } + ) + input_data_folder: str = field( + metadata={"help": "Path to folder containing input parquet files (inp_id, keywords)"} + ) + label_data_folder: str = field( + metadata={ + "help": "Path to folder containing label parquet files (lbl_id, title, contents)" + } + ) + + @dataclass + class ModelParams(pecos.BaseParams): + """ + The parameters for the ranking model. This class contains the data, encoder and training arguments for the model. 
+ """ + + encoder_args: CrossEncoder.Config + + positive_passage_no_shuffle: bool = False + negative_passage_no_shuffle: bool = False + rerank_max_len: int = 20000 + query_prefix: str = "query: " + passage_prefix: str = "document: " + inp_id_col: str = "inp_id" + lbl_idxs_col: str = "ret_idxs" + score_col: str = "rel" + keyword_col_name: str = "keywords" + content_col_names: List[str] = field(default_factory=lambda: ["title", "contents"]) + content_sep: str = " " + append_eos_token: bool = False + pad_to_multiple_of: Optional[int] = 8 + + @dataclass + class EvalParams(pecos.BaseParams): + """ + Evaluation parameters + """ + + model_name_or_path: str + target_data_folder: str + input_data_folder: str + label_data_folder: str + output_dir: str + output_file_prefix: str = "output_" + output_file_suffix: str = "" + per_device_eval_batch_size: int = 128 + dataloader_num_workers: int = 2 + dataloader_prefetch_factor: int = 10 + rerank_max_len: int = 196 + query_prefix: str = "query: " + passage_prefix: str = "document: " + inp_id_col: str = "inp_id" + lbl_id_col: str = "lbl_id" + keyword_col_name: str = "keywords" + content_col_names: List[str] = field(default_factory=lambda: ["title", "contents"]) + content_sep: str = " " + append_eos_token: bool = False + pad_to_multiple_of: int = 16 + bf16: bool = True + device: str = "cuda" + model_init_kwargs: dict = dc.field(default_factory=lambda: dict()) + + def __init__( + self, + encoder: Union[CrossEncoder, PeftModel, PeftMixedModel], + tokenizer: AutoTokenizer, + model_params: ModelParams, + train_params: Optional[TrainParams] = None, + ): + """ + Initialize the ranking model. The model contains the encoder, tokenizer, model parameters and training parameters. + Args: + encoder (Union[CrossEncoder, PeftModel, PeftMixedModel]): The encoder model + tokenizer (AutoTokenizer): The tokenizer for the model + model_params (RankingModel.ModelParams): The model parameters + train_params (Optional[RankingModel.TrainParams]): The training parameters + """ + self.tokenizer = tokenizer + self.cross_encoder = encoder + + self.model_params = self.ModelParams.from_dict(model_params) + self.train_params = self.TrainParams.from_dict(train_params) if train_params else None + + @classmethod + def get_modified_model(cls, model: CrossEncoder, mod_config: Dict): + """ + Takes a pretrained Huggingface model and modifies it to include new features. Currently, the `modifier_type` + supported by this method is limited to the `peft` package. + + Args: + model (CrossEncoder): A PreTrainedModel from the transformers package. + mod_config (Dict): A dictionary containing the configuration for the model modifier. + Returns: The modified model + """ + if mod_config["modifier_type"] == "peft": + config_type = getattr(peft, mod_config["config_type"]) + peft_config: PeftConfig = config_type(**mod_config["config"]) + + model = get_peft_model(model, peft_config) + + return model + else: + logger.warn("Using model without modifiers (e.g. 
LoRA)") + return model + + @classmethod + def init_model(cls, model_params: ModelParams, train_params: TrainParams): + """Initiate a model with training parameters + + Args: + model_params (RankingModel.ModelParams): the model parameters + train_params (RankingModel.TrainParams): the training parameters + Returns: + An instance of RankingModel + """ + hf_trainer_args = train_params.training_args + if hf_trainer_args.local_rank > 0: + torch.distributed.barrier() + + config = model_params.encoder_args.to_dict() + config = CrossEncoderConfig(**config) + encoder = CrossEncoder( + config=config, + ) + + if hf_trainer_args.bf16: + encoder = encoder.bfloat16() + + if config.model_modifier: + encoder = cls.get_modified_model(model=encoder, mod_config=config.model_modifier) + + tokenizer = AutoTokenizer.from_pretrained( + model_params.encoder_args.model_shortcut, + ) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.unk_token_id + tokenizer.padding_side = "right" + + if torch.distributed.is_initialized(): + if hf_trainer_args.local_rank == 0: + torch.distributed.barrier() + + return cls(encoder, tokenizer, model_params, train_params=train_params) + + @classmethod + def _collate_sharded( + cls, + tokenizer: Union[PreTrainedTokenizer, AutoTokenizer], + model_params: ModelParams, + train_params: TrainParams, + table_stores: Dict[str, Dataset], + data: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """ + Collate function for training. Tokenizes the input and return features and returns the collated batch. + Args: + tokenizer: The huggingface tokenizer + params: The model parameters + table_stores: The table stores for the input and label data + data: The data to be collated + Returns: The collated batch in the form of a dictionary with input and scores + """ + fts_w_scores = [] + for s in data: + inp_id = s[model_params.inp_id_col] + retr_idxs = s[model_params.lbl_idxs_col] + scores = s[model_params.score_col] + + fts_w_scores.append( + RankingDataUtils._create_sample( + inp_id, + retr_idxs, + scores, + table_stores, + train_params.training_args.train_group_size, + model_params.query_prefix, + model_params.passage_prefix, + model_params.keyword_col_name, + model_params.content_col_names, + model_params.content_sep, + ) + ) + + return cls._collate(tokenizer, model_params, fts_w_scores) + + @classmethod + def _collate( + cls, + tokenizer: Union[PreTrainedTokenizer, AutoTokenizer], + model_params: ModelParams, + features_w_scores: List[Tuple[Any, Any]], + ): + """ + Collate function for training. Tokenizes the input and return features and returns the collated batch. 
+
+    @classmethod
+    def _collate(
+        cls,
+        tokenizer: Union[PreTrainedTokenizer, AutoTokenizer],
+        model_params: ModelParams,
+        features_w_scores: List[Tuple[Any, Any]],
+    ):
+        """
+        Collate function for training. Tokenizes the input features and returns the collated batch.
+        Args:
+            tokenizer: The huggingface tokenizer
+            model_params: The model parameters
+            features_w_scores: Tuple of features list and scores list
+        Returns: The collated batch in the form of a dictionary with input and scores
+        """
+        features = [f for f, _ in features_w_scores]
+        scores = [s for _, s in features_w_scores]
+
+        all_pairs = []
+        for pairs in features:
+            all_pairs.extend(pairs)
+
+        tokenized_pairs = tokenizer(
+            all_pairs,
+            padding=False,
+            truncation=True,
+            max_length=(
+                model_params.rerank_max_len - 1
+                if model_params.append_eos_token
+                else model_params.rerank_max_len
+            ),
+            return_attention_mask=False,
+            return_token_type_ids=False,
+            add_special_tokens=True,
+        )
+
+        if model_params.append_eos_token:
+            tokenized_pairs["input_ids"] = [
+                p + [tokenizer.eos_token_id] for p in tokenized_pairs["input_ids"]
+            ]
+
+        pairs_collated = tokenizer.pad(
+            tokenized_pairs,
+            padding=True,
+            pad_to_multiple_of=model_params.pad_to_multiple_of,
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+        # NOTE: The scores have to be flattened here, otherwise the huggingface trainer
+        # will distribute them incorrectly across devices in distributed training.
+        m_scores = torch.tensor(scores, dtype=torch.float).flatten()
+
+        return {"input": pairs_collated, "scores": m_scores}
+
+    @classmethod
+    def _collate_sharded_eval(
+        cls,
+        tokenizer: Union[PreTrainedTokenizer, AutoTokenizer],
+        eval_params: EvalParams,
+        table_stores: Dict[str, Dataset],
+        data: List[Dict[str, Any]],
+    ) -> Dict[str, Any]:
+        """
+        Collate function for evaluation. Tokenizes the input features and returns the collated batch.
+        Args:
+            tokenizer: The huggingface tokenizer
+            eval_params: The evaluation parameters
+            table_stores: The table stores for the input and label datasets
+            data: The data to be collated
+        Returns: The collated batch in the form of a dictionary with the tokenized texts together with the input and label indices
+        """
+        fts = []
+        inp_idxs = []
+        lbl_idxs = []
+        for s in data:
+            inp_id = s[eval_params.inp_id_col]
+            retr_id = s[eval_params.lbl_id_col]
+            inp_idxs.append(inp_id)
+            lbl_idxs.append(retr_id)
+
+            fts.append(
+                RankingDataUtils._format_sample(
+                    table_stores["input"][inp_id][eval_params.keyword_col_name],
+                    [table_stores["label"][retr_id][col] for col in eval_params.content_col_names],
+                    eval_params.query_prefix,
+                    eval_params.passage_prefix,
+                    eval_params.content_sep,
+                )
+            )
+
+        return cls._collate_eval(tokenizer, eval_params, fts, inp_idxs, lbl_idxs)
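+
+    # Illustrative structure of a collated evaluation batch (hypothetical values):
+    #   {
+    #       "inp_idxs": [3, 3],        # input ids, carried through to the output tuples
+    #       "lbl_idxs": [17, 42],      # label ids of the scored (query, passage) pairs
+    #       "inputs":   BatchEncoding with tensors of shape (2, L),
+    #   }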
+
+    @classmethod
+    def _collate_eval(
+        cls,
+        tokenizer: Union[PreTrainedTokenizer, AutoTokenizer],
+        eval_params: EvalParams,
+        features: List[str],
+        inp_idxs: List[int],
+        lbl_idxs: List[int],
+    ):
+        """
+        Collate function for evaluation. Tokenizes the input features and returns the collated batch.
+        Args:
+            tokenizer: The huggingface tokenizer
+            eval_params: The evaluation parameters
+            features: The list of features
+            inp_idxs: The list of input indices
+            lbl_idxs: The list of label indices
+        Returns: The collated batch in the form of a dictionary with tokenized input, input indices and label indices
+        """
+
+        all_pairs = []
+        for pair in features:
+            all_pairs.append(pair)
+
+        tokenized_pairs = tokenizer(
+            all_pairs,
+            padding=False,
+            truncation=True,
+            max_length=(
+                eval_params.rerank_max_len - 1
+                if eval_params.append_eos_token
+                else eval_params.rerank_max_len
+            ),
+            return_attention_mask=False,
+            return_token_type_ids=False,
+            add_special_tokens=True,
+        )
+
+        if eval_params.append_eos_token:
+            tokenized_pairs["input_ids"] = [
+                p + [tokenizer.eos_token_id] for p in tokenized_pairs["input_ids"]
+            ]
+
+        pairs_collated = tokenizer.pad(
+            tokenized_pairs,
+            padding=True,
+            pad_to_multiple_of=eval_params.pad_to_multiple_of,
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+
+        return {
+            "inp_idxs": inp_idxs,
+            "lbl_idxs": lbl_idxs,
+            "inputs": pairs_collated,
+        }
+
+    @classmethod
+    def predict(
+        cls,
+        eval_dataset: IterableDataset,
+        table_stores: Dict[str, Dataset],
+        eval_params: EvalParams,
+        tokenizer: AutoTokenizer,
+        model: CrossEncoder,
+    ):
+        model.eval()
+
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token_id = 0
+        tokenizer.padding_side = "right"
+
+        eval_dataloader = DataLoader(
+            eval_dataset,
+            batch_size=eval_params.per_device_eval_batch_size,
+            # Ensure that at least one worker is creating batches
+            # parallel to the model compute
+            num_workers=max(eval_params.dataloader_num_workers, 1),
+            # To ensure efficiency we prefetch samples in parallel
+            prefetch_factor=eval_params.dataloader_prefetch_factor,
+            collate_fn=partial(cls._collate_sharded_eval, tokenizer, eval_params, table_stores),
+        )
+
+        all_results = []
+        for batch in eval_dataloader:
+            with torch.inference_mode():
+                inp_ids = batch["inp_idxs"]
+                lbl_ids = batch["lbl_idxs"]
+                inputs = batch["inputs"].to(eval_params.device)
+                model_output = model(**inputs).logits
+                scores = model_output.cpu().detach().float().numpy()
+                for i in range(len(scores)):
+                    inp_id = inp_ids[i]
+                    ret_id = lbl_ids[i]
+                    score = scores[i][0]
+                    all_results.append((inp_id, ret_id, score))
+
+        return all_results
+
+    @classmethod
+    def train(
+        cls,
+        train_dataset: IterableDataset,
+        table_stores: Dict[str, Dataset],
+        model_params: ModelParams,
+        train_params: TrainParams,
+    ):
+        """
+        Train the ranking model
+        Args:
+            train_dataset: The training dataset
+            table_stores: The table stores for the input and label data
+            model_params: The model parameters
+            train_params: The training parameters
+        """
+        training_args = train_params.training_args
+        # we need to keep the 'unused' columns to maintain the group and score
+        # information coming from the collator
+        training_args.remove_unused_columns = False
+        outer_model = cls.init_model(model_params, train_params)
+        inner_model = outer_model.cross_encoder
+
+        logger.info("Model loading...")
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+        else:
+            # NOTE: This is needed for the case where the program is run in single-process mode
+            if training_args.bf16 and not torch.distributed.is_initialized():
+                inner_model = inner_model.bfloat16()
+
+        logger.info("=" * 50)
+        logger.info(
+            f"Memory used by model: {round(inner_model.get_memory_footprint() / 1024 / 1024 / 1024, 2)} GB"
+        )
+
+        trainer = RankLlamaTrainer(
+            model=inner_model,
+            args=training_args,
train_dataset=train_dataset, + data_collator=partial( + cls._collate_sharded, + outer_model.tokenizer, + model_params, + train_params, + table_stores, + ), + outer_model=outer_model, + ) + + # NOTE: in the huggingface trainers `_prepare_input` method, the inputs are converted from + # mps device to cpu. To run on Apple Silicon, the method should be overridden. It is not + # clear if training is supported for Apple Silicon devices. + trainer.train() + trainer.save_model() diff --git a/pecos/xmr/reranker/predict.py b/pecos/xmr/reranker/predict.py new file mode 100644 index 0000000..89de357 --- /dev/null +++ b/pecos/xmr/reranker/predict.py @@ -0,0 +1,77 @@ +import json +import argparse +import os +from datasets import load_dataset +from .data_utils import RankingDataUtils +from .model import RankingModel, CrossEncoder +from transformers import AutoTokenizer +from tqdm import tqdm +import pandas as pd + + +def main(config_json_path: str): + with open(config_json_path, "r") as fin: + params = json.load(fin) + + params = RankingModel.EvalParams.from_dict(params) + + # helper function for getting the list of filepaths in a folder + def construct_file_list(folder): + return [os.path.join(folder, x) for x in os.listdir(folder)] + + input_files = construct_file_list(params.input_data_folder) + label_files = construct_file_list(params.label_data_folder) + target_files = construct_file_list(params.target_data_folder) + inp_id_col = params.inp_id_col + lbl_id_col = params.lbl_id_col + + input_files, label_files = RankingDataUtils.get_sorted_data_files( + input_files, inp_id_col + ), RankingDataUtils.get_sorted_data_files(label_files, lbl_id_col) + + table_stores = { + "input": load_dataset("parquet", data_files=input_files, split="train"), + "label": load_dataset("parquet", data_files=label_files, split="train"), + } + + # Create output folder if it does not exist + if not os.path.exists(params.output_dir): + os.makedirs(params.output_dir) + + model = CrossEncoder.from_pretrained(params.model_name_or_path, **params.model_init_kwargs) + if params.bf16: + model = model.bfloat16() + + tokenizer = AutoTokenizer.from_pretrained(model.config.model_shortcut) + + for target_file in tqdm(target_files): + target_filename = os.path.basename(target_file) + target_shucked_filename = ".".join(target_filename.split(".")[:-1]) + out_pre = params.output_file_prefix + out_suff = params.output_file_suffix + ext = ".parquet" + save_filename = out_pre + target_shucked_filename + out_suff + ext + + eval_dataset = load_dataset( + "parquet", data_files=[target_file], streaming=True, split="train" + ) + + results = RankingModel.predict( + eval_dataset, + table_stores, + params, + tokenizer=tokenizer, + model=model, + ) + + # Save the results to a parquet with (inp_id, lbl_id, score) columns + # `results` is a list of tuple (inp_id, lbl_id, score) + df = pd.DataFrame(results, columns=["inp_id", "lbl_id", "score"]) + df.to_parquet(os.path.join(params.output_dir, save_filename)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_json_path", type=str, required=True) + args = parser.parse_args() + main(args.config_json_path) diff --git a/pecos/xmr/reranker/train.py b/pecos/xmr/reranker/train.py new file mode 100644 index 0000000..aaa1f01 --- /dev/null +++ b/pecos/xmr/reranker/train.py @@ -0,0 +1,94 @@ +import argparse +import json +import logging +import os + +import datasets.distributed +import torch +from datasets import load_dataset +from transformers import set_seed + +from 
.data_utils import RankingDataUtils +from .model import RankingModel + +logger = logging.getLogger(__name__) + +""" +Usage: +```bash +python -m pecos.xmr.reranker.train --config_json_path config.json +``` +""" + + +def main(config_json_path: str): + """ + Args: + config_json_path: JSON configuration for running the training + """ + # parse train_params and model_params from json + with open(config_json_path, "r") as fin: + param = json.load(fin) + model_params: RankingModel.ModelParams = RankingModel.ModelParams.from_dict( + param.get("model_params", None), + recursive=True, + ) + + train_params: RankingModel.TrainParams = RankingModel.TrainParams.from_dict( + param.get("train_params", None), + recursive=True, + ) + + set_seed(train_params.training_args.seed) + + # helper function for getting the list of filepaths in a folder + def construct_file_list(folder): + return [os.path.join(folder, x) for x in os.listdir(folder)] + + input_files = construct_file_list(train_params.input_data_folder) + label_files = construct_file_list(train_params.label_data_folder) + input_files, label_files = RankingDataUtils.get_sorted_data_files( + input_files, "inp_id" + ), RankingDataUtils.get_sorted_data_files(label_files, "lbl_id") + + train_dataset = load_dataset( + "parquet", data_dir=train_params.target_data_folder, streaming=True, split="train" + ) + train_dataset_rows = RankingDataUtils.get_parquet_rows(train_params.target_data_folder) + logger.info(f"total target inputs: {train_dataset_rows}") + + training_args = train_params.training_args + # set the max_steps in accordance with the number of num_rows + if training_args.max_steps <= 0: + ws = training_args.world_size + bs = training_args.per_device_train_batch_size + gas = training_args.gradient_accumulation_steps + batch_size = ws * bs * gas + max_steps = train_dataset_rows // batch_size + training_args.max_steps = max_steps + logger.info(f"total batch size: {batch_size}, train steps: {max_steps}") + else: + logger.info(f"max steps: {training_args.max_steps}") + + table_stores = { + "input": load_dataset("parquet", data_files=input_files, split="train"), + "label": load_dataset("parquet", data_files=label_files, split="train"), + } + + train_dataset = train_dataset.shuffle(buffer_size=5000, seed=training_args.data_seed) + train_dataset = datasets.distributed.split_dataset_by_node( + train_dataset, training_args.local_rank, training_args.world_size + ) + + logger.info("Waiting for main process to perform the mapping") + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + RankingModel.train(train_dataset, table_stores, model_params, train_params) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_json_path", type=str, required=True) + args = parser.parse_args() + main(args.config_json_path) diff --git a/pecos/xmr/reranker/trainer.py b/pecos/xmr/reranker/trainer.py new file mode 100644 index 0000000..9ba2809 --- /dev/null +++ b/pecos/xmr/reranker/trainer.py @@ -0,0 +1,127 @@ +import copy +import json +import logging +import os +from dataclasses import dataclass +from typing import Optional, Any, Tuple, Dict + +import torch +from torch.utils.data import DataLoader +from transformers import Trainer, TrainingArguments, HfArgumentParser + +import pecos + +PARAM_FILENAME: str = "param.json" + +logger = logging.getLogger(__name__) + + +class RankLlamaTrainer(Trainer, pecos.BaseClass): + """ + Trainer class for the RankLlama model. This class extends the Trainer class. 
+ """ + + loss_fn = torch.nn.CrossEntropyLoss(reduction="mean") + outer_model = None + + def __init__(self, *args, **kwargs): + self.outer_model = kwargs.pop("outer_model") + super(RankLlamaTrainer, self).__init__(*args, **kwargs) + + @dataclass + class TrainingArgs(TrainingArguments, pecos.BaseParams): + train_group_size: int = 8 + + @classmethod + def from_dict(cls, param=None): + if param is None: + return cls() + elif isinstance(param, cls): + return copy.deepcopy(param) + elif isinstance(param, dict): + parser = HfArgumentParser(cls) + return parser.parse_dict(param, allow_extra_keys=True)[0] + raise ValueError(f"{param} is not a valid parameter dictionary for {cls.name}") + + def to_dict(self, with_meta=True): + d = super().to_dict() + return self.append_meta(d) if with_meta else d + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training dataloader. This function is called by the Trainer class. + """ + prefetch_factor = self.args.dataloader_prefetch_factor + prefetch_factor = prefetch_factor if prefetch_factor else 10 + return DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + # Ensure that at least one worker is creating batches + # parallel to the model compute + num_workers=max(self.args.dataloader_num_workers, 1), + # To ensure efficiency we prefetch samples in parallel + prefetch_factor=prefetch_factor, + collate_fn=self.data_collator, + ) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + """ + Save the model and tokenizer to the output directory. Makes sure the huggingface model is saved correctly. + Args: + output_dir: The output directory to save the model and tokenizer. + state_dict: The state dictionary to save + """ + # If we are executing this function, we are the process zero, so we don't check for that. + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving to {output_dir}") + + outer_model: Any = self.outer_model + super()._save(output_dir, state_dict) + + # save the config + param = { + "model": outer_model.__class__.__name__, + "model_params": outer_model.model_params.to_dict(), + "train_params": outer_model.train_params.to_dict(), + } + + output_dir = output_dir if output_dir is not None else self.args.output_dir + + param = outer_model.append_meta(param) + with open(os.path.join(output_dir, PARAM_FILENAME), "w", encoding="utf-8") as f: + f.write(json.dumps(param, indent=True)) + + def _prepare_inputs(self, inputs): + """ + Prepare the inputs for the model. This function is called by the Trainer class. Converts the inputs to mps + tensors if available. + """ + super_inputs = super(RankLlamaTrainer, self)._prepare_inputs(inputs) + if torch.backends.mps.is_available(): + super_inputs = {k: v.to("mps") for k, v in super_inputs.items()} + return super_inputs + + def compute_loss( + self, model, inputs: Dict[str, Any], return_outputs: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the loss for the model. This function is called by the Trainer class. 
+ Args: + model: The model to compute the loss for + inputs: The inputs to the model + return_outputs: Whether to return the outputs + """ + self.args: RankLlamaTrainer.TrainingArgs + train_group_size = self.args.train_group_size + if not train_group_size: + raise NotImplementedError("Cannot perform ranking without train group") + gt_scores = inputs["scores"].reshape(-1, train_group_size) + ranker_logits = model(**inputs["input"], return_dict=True).logits + batch_size = gt_scores.shape[0] + + grouped_logits = ranker_logits.view(batch_size, -1) + assert grouped_logits.shape == gt_scores.shape + loss = self.loss_fn(grouped_logits, gt_scores) + + return (loss, ranker_logits) if return_outputs else loss
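+
+    # Illustrative numbers for the grouped loss above (hypothetical): with
+    # train_group_size=2 and 2 queries per device, ranker_logits has shape (4, 1);
+    # view(batch_size, -1) regroups it to (2, 2) so each row holds one query's
+    # passage logits, and the matching rows of gt_scores act as soft target
+    # distributions for CrossEntropyLoss.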