diff --git a/examples/msmarco-rankllama/README.md b/examples/msmarco-rankllama/README.md
new file mode 100644
index 0000000..9014366
--- /dev/null
+++ b/examples/msmarco-rankllama/README.md
@@ -0,0 +1,33 @@
+# PECOS XMR Reranker on the MS MARCO Dataset
+
+This example uses the PECOS-based RankingModel to reproduce the [RankLLaMA paper](https://arxiv.org/abs/2310.08319).
+
+## How to run
+
+### Training
+```bash
+torchrun --nnodes 1 --nproc-per-node 8 \
+    -m pecos.xmr.reranker.train \
+    --config_json_path ./msmarco_qwen2-7B.train.json
+```
+
+### Predictions
+```bash
+python -m pecos.xmr.reranker.predict \
+    --config_json_path ./msmarco_qwen2-7B.pred.json
+```
+
+## Evaluation
+We first convert the predictions from Parquet to TREC format:
+```bash
+python -u parquet_to_trec_eval.py -i inference_outputs/ms_marco/qwen2-7B -o inference_outputs/ms_marco/qwen2-7B.pred.trec
+```
+
+We then follow the [Pyserini](https://github.com/castorini/pyserini) evaluation protocol to evaluate NDCG@10,
+and you should see results like:
+```bash
+python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 dl19-passage inference_outputs/ms_marco/qwen2-7B.pred.trec
+
+Results:
+ndcg_cut_10 all 0.7619
+```
diff --git a/examples/msmarco-rankllama/msmarco_qwen2-7B.pred.json b/examples/msmarco-rankllama/msmarco_qwen2-7B.pred.json
new file mode 100644
index 0000000..b7b8bae
--- /dev/null
+++ b/examples/msmarco-rankllama/msmarco_qwen2-7B.pred.json
@@ -0,0 +1,21 @@
+{
+    "target_data_folder": "./datasets/ms_marco/eval_aux/target",
+    "input_data_folder": "./datasets/ms_marco/eval_aux/input",
+    "label_data_folder": "./datasets/ms_marco/eval_aux/label",
+    "model_path": "./models/ms_marco/qwen2-7B/",
+    "output_dir": "./inference_outputs/ms_marco/qwen2-7B/",
+    "per_device_eval_batch_size": 1024,
+    "dataloader_num_workers": 1,
+    "dataloader_prefetch_factor": 10,
+    "rerank_max_len": 196,
+    "query_prefix": "query: ",
+    "passage_prefix": "document: ",
+    "inp_id_col": "inp_id",
+    "lbl_id_col": "lbl_id",
+    "inp_id_orig_col": "inp_id_orig",
+    "lbl_id_orig_col": "lbl_id_orig",
+    "keyword_col_name": "keywords",
+    "content_col_names": ["title", "contents"],
+    "append_eos_token": false,
+    "pad_to_multiple_of": 8
+}
diff --git a/examples/msmarco-rankllama/msmarco_qwen2-7B.train.json b/examples/msmarco-rankllama/msmarco_qwen2-7B.train.json
new file mode 100644
index 0000000..9f2063c
--- /dev/null
+++ b/examples/msmarco-rankllama/msmarco_qwen2-7B.train.json
@@ -0,0 +1,140 @@
+{
+    "train_params": {
+        "__meta__": {
+            "class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams"
+        },
+        "target_data_folder": "./datasets/ms_marco/train/target",
+        "input_data_folder": "./datasets/ms_marco/train/input",
+        "label_data_folder": "./datasets/ms_marco/train/label",
+        "hf_trainer_args": {
+            "__meta__": {
+                "class_fullname": "pecos.xmr.reranker.trainer###RankingTrainer.TrainingArgs"
+            },
+            "output_dir": "./models/ms_marco/qwen2-7B",
+            "ddp_find_unused_parameters": false,
+            "loss_fn": "listwise",
+            "loss_alpha": 1.0,
+            "group_size": 16,
+            "per_device_train_batch_size": 6,
+            "gradient_accumulation_steps": 8,
+            "disable_tqdm": false,
+            "logging_strategy": "steps",
+            "logging_first_step": false,
+            "learning_rate": 1e-4,
+            "max_steps": 1500,
+            "save_steps": 50,
+            "logging_steps": 10,
+            "save_strategy": "steps",
+            "save_total_limit": 5,
+            "seed": 42,
+            "data_seed": 42,
+            "bf16": true,
+            "dataloader_num_workers": 2,
+            "dataloader_prefetch_factor": 10,
+            "gradient_checkpointing": true,
+            "deepspeed": {
+                "zero_optimization": {
+                    "stage": 3,
+                    "offload_optimizer": {
+                        "device": "none",
"pin_memory": true + }, + "offload_param": { + "device": "none", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 1e6, + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 10, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 10, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto", + "torch_adam": true + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto" + } + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 1000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false + } + } + }, + "model_params": { + "__meta__": { + "class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams" + }, + "encoder_config": { + "text_config": { + "model_type": "qwen2", + "name_or_path": "Qwen/Qwen2-7B", + "attn_implementation": "sdpa", + "trust_remote_code": true, + "token": null + }, + "numr_config": null, + "text_pooling_type": "last", + "head_size_list": [128] + }, + "model_modifier": { + "modifier_type": "peft", + "config_type": "LoraConfig" , + "config": { + "r": 16, + "lora_alpha": 32, + "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + "modules_to_save": ["head_layers", "scorer"], + "lora_dropout": 0.1 + } + }, + "positive_passage_no_shuffle": false, + "negative_passage_no_shuffle": false, + "rerank_max_len": 196, + "query_prefix": "query: ", + "passage_prefix": "document: ", + "inp_id_col": "inp_id", + "lbl_idxs_col": "ret_idxs", + "score_col": "rel", + "keyword_col_name": "keywords", + "content_col_names": ["title", "contents"], + "append_eos_token": false, + "pad_to_multiple_of": 16 + } +} diff --git a/examples/msmarco-rankllama/parquet_to_trec_eval.py b/examples/msmarco-rankllama/parquet_to_trec_eval.py new file mode 100644 index 0000000..82c9a88 --- /dev/null +++ b/examples/msmarco-rankllama/parquet_to_trec_eval.py @@ -0,0 +1,36 @@ + +import argparse +import os +import pandas as pd + + +def main(args): + """ + Combine all results from the results folder and write them to the output file. 
+ """ + result_files = [ + os.path.join(args.input_parquet_path, x) + for x in os.listdir(args.input_parquet_path) + ] + all_results = pd.read_parquet(result_files[0]) + for f in result_files[1:]: + all_results = pd.concat([all_results, pd.read_parquet(f)]) + # sort all results by 'inp_id' and then 'score' in descending order + all_results = all_results.sort_values(by=['inp_id', 'score'], ascending=[True, False]) + + cur_inp_id = None + with open(args.output_trec_path, "w") as fout: + for row in all_results.itertuples(): + if cur_inp_id != row.inp_id: + cur_inp_id = row.inp_id + rank = 0 + rank += 1 + fout.write(f"{row.inp_id} Q0 {row.lbl_id} {rank} {row.score} dense\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input-parquet-path", type=str, required=True) + parser.add_argument("-o", "--output-trec-path", type=str, required=True) + args = parser.parse_args() + main(args) diff --git a/pecos/xmr/reranker/README.md b/pecos/xmr/reranker/README.md index f2a11d6..a60387f 100644 --- a/pecos/xmr/reranker/README.md +++ b/pecos/xmr/reranker/README.md @@ -4,22 +4,14 @@ This is a reranker for the PECOS XMR model. It is based on huggingface's transfo single process and distributed mode. It is based on the paper [Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319). ## How to run -### Single process -To run the reranker in single process mode, you can use the following command: -```bash -python -m pecos.xmr.reranker.train --config_json_path -``` - -### Distributed mode -To run the reranker in distributed mode, you can use the following command to initialize the distributed configuration: -```bash -accelerate config -``` +### Training +To train the reranker, we suggest to use the `torchrun` command: -Then you can run the reranker using the following command: ```bash -accelerate launch -m pecos.xmr.reranker.train --config_json_path +torchrun --nnodes 1 --nproc-per-node 8 \ + -m pecos.xmr.reranker.train \ + --config_json_path ``` ### Predictions @@ -28,112 +20,12 @@ To run the reranker in prediction mode, you can use the following command: python -m pecos.xmr.reranker.predict --config_json_path ``` -## Configuration file - -### Training -Here is an example of the configuration file for training: -```json -{ - "train_params": { - "__meta__": { - "class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams" - }, - "target_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/target", - "input_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/input", - "label_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/label", - "training_args": { - "__meta__": { - "class_fullname": "pecos.xmr.reranker.trainer###RankLlamaTrainer.TrainingArgs" - }, - "learning_rate": 1e-4, - "output_dir": "./ds_model", - "per_device_train_batch_size": 8, - "gradient_accumulation_steps": 8, - "max_steps": -1, - "logging_strategy": "steps", - "logging_first_step": false, - "logging_steps": 10, - "save_strategy": "steps", - "save_steps": 50, - "save_total_limit": 5, - "seed": 42, - "data_seed": 42, - "bf16": true, - "dataloader_num_workers": 2, - "dataloader_prefetch_factor": 10, - "gradient_checkpointing": true, - "train_group_size": 16 - } - }, - "model_params": { - "__meta__": { - "class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams" - }, - "encoder_args": { - "__meta__": { - "class_fullname": "pecos.xmr.reranker.model###CrossEncoder.Config" - }, - 
"model_shortcut": "meta-llama/Llama-2-7b-hf", - "model_init_kwargs": {}, - "model_modifier": { - "modifier_type": "peft", - "config_type": "LoraConfig" , - "config": { - "r": 8, - "lora_alpha": 64, - "target_modules": ["q_proj", "v_proj"], - "modules_to_save": ["score", "classifier"], - "lora_dropout": 0.1 - } - } - }, - "positive_passage_no_shuffle": false, - "negative_passage_no_shuffle": false, - "rerank_max_len": 196, - "query_prefix": "query: ", - "passage_prefix": "document: ", - "inp_id_col": "inp_id", - "lbl_idxs_col": "ret_idxs", - "score_col": "rel", - "keyword_col_name": "keywords", - "content_col_names": ["title", "contents"], - "append_eos_token": false, - "pad_to_multiple_of": 16 - } -} -``` - -### Prediction -Following is the example of the configuration file for prediction: -```json -{ - "model_name_or_path": "/tmp/pecosdev/ds_model", - "target_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/target", - "input_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/input", - "label_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/label", - "output_dir": "/tmp/xmrout", - "per_device_eval_batch_size": 512, - "dataloader_num_workers": 1, - "dataloader_prefetch_factor": 10, - "rerank_max_len": 196, - "query_prefix": "query: ", - "passage_prefix": "document: ", - "inp_id_col": "inp_id", - "lbl_id_col": "lbl_id", - "keyword_col_name": "keywords", - "content_col_names": ["title", "contents"], - "append_eos_token": false, - "pad_to_multiple_of": 8, - "device": "cuda", - "model_init_kwargs": { - "device_map": "auto" - } -} -``` +## Config JSON Files +See example training/predict JSON files in `pecos/examples/msmarco-rankllama` folders. ## Data Schema -The column names for the data schema are configurable through the json configuration file. Following -are the various schemas that are supported by the reranker: +The column names for the data schema are configurable through the json configuration file. 
+Following are the various schemas that are supported by the reranker: (1) Learning Target Schema ``` diff --git a/pecos/xmr/reranker/data_utils.py b/pecos/xmr/reranker/data_utils.py index 2ec652a..6601b02 100644 --- a/pecos/xmr/reranker/data_utils.py +++ b/pecos/xmr/reranker/data_utils.py @@ -1,3 +1,4 @@ +import bisect import os import random from collections import OrderedDict @@ -10,6 +11,140 @@ import pecos +def get_pairwise_batch( + batch, + inp_col, + lbl_col, + rel_col, + group_size, + pos_label_sampling="weighted", +): + """ + Convert listwise batch into pairwise batch by sampling 1 pos_lbl/neg_lbl per input + + Example: + batch = { + "inp_col": ["q1", "q2"], + "lbl_col": [ + ["p11", "p12", "p13"], + ["p21", "p22", "p23", "p24"], + ], + "rel_col": [ + [1.0, 0.5, 0.1], + [0.9, 0.9, 0.7, 0.6], + ], + } + pairwise_batch = { + "inp_col": ["q1", "q1", "q2", "q2"], + "lbl_col": ["a11", "a13", "a22", "a23"], + "rel_col": [1.0, 0.1, 0.9, 0.7], + } + + """ + + def sample_unequal_pair_from_array(data): + "sample a pair of indices from a decent sorted array" + non_min_sum = len(data) - 1 + while data[non_min_sum] == data[-1]: + non_min_sum -= 1 + # assert non_min_sum > 0, (non_min_sum, data) + + if pos_label_sampling == "weighted": + psum = np.sum(data[: (non_min_sum + 1)]) + pos_idx = np.random.choice(range(non_min_sum + 1), p=data[: (non_min_sum + 1)] / psum) + elif pos_label_sampling == "uniform": + pos_idx = np.random.randint(non_min_sum + 1) + else: + raise NotImplementedError(f"Unknown positive sampling method {pos_label_sampling}") + + # neg_num = np.sum(data < data[pos_idx]) + neg_start = pos_idx + while data[neg_start] == data[pos_idx]: + neg_start += 1 + + # assert neg_num > 0, (neg_num, data) + neg_idx = np.random.randint(neg_start, len(data)) + return pos_idx, neg_idx + + assert group_size == 2 + + pair_indices = [sample_unequal_pair_from_array(scores) for scores in batch[rel_col]] + pos_lbls = [alist[pa[0]] for alist, pa in zip(batch[lbl_col], pair_indices)] + neg_lbls = [alist[pa[1]] for alist, pa in zip(batch[lbl_col], pair_indices)] + pos_rels = [slist[pa[0]] for slist, pa in zip(batch[rel_col], pair_indices)] + neg_rels = [slist[pa[1]] for slist, pa in zip(batch[rel_col], pair_indices)] + + return { + inp_col: np.repeat(batch[inp_col], group_size), + lbl_col: np.vstack([pos_lbls, neg_lbls]).T.flatten(), + rel_col: np.vstack([pos_rels, neg_rels]).T.flatten(), + } + + +def get_listwise_batch( + batch, + inp_col, + lbl_col, + rel_col, + group_size, + pos_label_sampling="weighted", + neg_rel_val=0.0, +): + """ + Convert listwise batch into sampled listwise batch + + Example: + batch = { + "inp_col": ["q1", "q2"], + "lbl_col": [ + ["p11", "p12", "p13", "p14"], + ["p21", "p22", "a23", "p24", "p25"], + ], + "rel_col": [ + [0.8, 0.5, 0.0, 0.0], + [0.9, 0.0, 0.0, 0.0, 0.0], + ], + } + + sampled listwise batch (group_size=3) + listwise_batch = { + "inp_col": ["q1", "q1", "q1", "q2", "q2", "q2"], + "lbl_col": ["p12", "p14", "p13", "p21", "q23", "q25"], + "rel_col": [0.5, 0.0, 0.0, 0.9, 0.0, 0.0], + } + """ + assert group_size >= 2 + + all_lbl_list, all_rel_list = [], [] + for lbl_arr, rel_arr in zip(batch[lbl_col], batch[rel_col]): + # note that bisect assumes lbl_arr/rel_arr are sorted ascendingly (by values in rel_arr) + # so we add negative sign to flip the order from descendingly to ascendingly, + # and find rightmost index (i.e., pos_ptr) whose value less than -neg_label_val. 
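+        # e.g., rel_arr = [0.9, 0.0, 0.0, 0.0, 0.0] with neg_rel_val = 0.0 gives pos_ptr = 1,
+        # i.e., indices [0, pos_ptr) are positives and [pos_ptr, len(rel_arr)) are negatives.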
+ # see https://docs.python.org/3/library/bisect.html + pos_ptr = bisect.bisect_left(-rel_arr, -neg_rel_val) + + # sample 1 positive + indices = np.random.randint(0, high=pos_ptr, size=1).tolist() + + # smaple group_size - 1 negatives + num_true_neg = min(len(lbl_arr) - pos_ptr, group_size - 1) + if num_true_neg > 0: + indices += np.random.randint(pos_ptr, high=len(lbl_arr), size=num_true_neg).tolist() + num_rand_neg = (group_size - 1) - num_true_neg + if num_rand_neg > 0: + assert NotImplementedError(f"within batch negative not support for Listwise Ranking!") + # indices += np.random.randint(0, high=lbl_space, size=num_rand_neg).tolist() + + all_lbl_list.append(lbl_arr[indices].tolist()) + all_rel_list.append(rel_arr[indices].tolist()) + # end for loop + return { + inp_col: np.repeat(batch[inp_col], group_size), + lbl_col: np.vstack(all_lbl_list).flatten(), + rel_col: np.vstack(all_rel_list).flatten(), + } + + class RankingDataUtils(pecos.BaseClass): """ Utility class for handling data related tasks @@ -59,7 +194,7 @@ def _create_sample( ret_idxs: List[int], scores: List[float], table_stores, - train_group_size: int, + group_size: int, inp_prefix: str, passage_prefix: str, keyword_col_name: str, @@ -73,7 +208,7 @@ def _create_sample( ret_idxs: The retrieved indices scores: Scores for the retrieved indices table_stores: Dictionary of table stores for input and label data - train_group_size: The number of passages used to train for each query + group_size: The number of passages used to train for each query inp_prefix: The input prefix passage_prefix: The passage prefix keyword_col_name: The column name for the query text @@ -97,17 +232,15 @@ def _create_sample( random.shuffle(pos_idxs) random.shuffle(neg_idxs) - num_positives = train_group_size // 2 + num_positives = group_size // 2 all_selections = pos_idxs[:num_positives] num_positives = len(all_selections) - num_negatives = train_group_size - num_positives + num_negatives = group_size - num_positives all_selections.extend(neg_idxs[:num_negatives]) - if len(all_selections) < train_group_size: - all_selections.extend( - random.choices(neg_idxs, k=train_group_size - len(all_selections)) - ) + if len(all_selections) < group_size: + all_selections.extend(random.choices(neg_idxs, k=group_size - len(all_selections))) all_scores = [s for s, _ in all_selections] all_pids = [pid for _, pid in all_selections] diff --git a/pecos/xmr/reranker/model.py b/pecos/xmr/reranker/model.py index 6e9212d..79a692f 100644 --- a/pecos/xmr/reranker/model.py +++ b/pecos/xmr/reranker/model.py @@ -1,3 +1,4 @@ +import copy import dataclasses as dc import json import logging @@ -6,25 +7,77 @@ from functools import partial from typing import Dict, List, Tuple, Any, Optional, Union -import peft import torch -from datasets import IterableDataset, Dataset +from torch import nn, Tensor from torch.utils.data import DataLoader -from peft import AutoPeftModelForSequenceClassification, get_peft_model + +import peft +from datasets import IterableDataset, Dataset +from peft import get_peft_model from peft.config import PeftConfig from peft.mixed_model import PeftMixedModel from peft.peft_model import PeftModel -from transformers import AutoModelForSequenceClassification, PreTrainedModel -from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig +from transformers import ( + AutoTokenizer, + AutoModel, + CONFIG_MAPPING, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizer, +) +from transformers.utils import ModelOutput import pecos -from 
pecos.xmr.reranker.trainer import RankLlamaTrainer, PARAM_FILENAME -from .data_utils import RankingDataUtils +from pecos.xmr.reranker.trainer import ( + RankingTrainer, + PARAM_FILENAME, +) +from pecos.xmr.reranker.data_utils import RankingDataUtils + logger = logging.getLogger(__name__) -class CrossEncoderConfig(PretrainedConfig): +ACT_FCT_DICT = { + "identity": nn.Identity, + "tanh": nn.Tanh, + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "elu": nn.ELU, + "gelu": nn.GELU, + "leaky_relu": nn.LeakyReLU, +} + + +@dc.dataclass +class RerankerOutput(ModelOutput): + text_emb: Optional[Tensor] = None + numr_emb: Optional[Tensor] = None + scores: Optional[Tensor] = None + + +class NumrMLPEncoderConfig(PretrainedConfig): + + model_type = "numr_mlp_encoder" + + def __init__( + self, + inp_feat_dim: int = 1, + inp_dropout_prob: float = 0.1, + hid_dropout_prob: float = 0.1, + hid_actv_type: str = "gelu", + hid_size_list: list = [64, 128, 256], + **kwargs, + ): + self.inp_feat_dim = inp_feat_dim + self.hid_size_list = hid_size_list + self.hid_actv_type = hid_actv_type + self.inp_dropout_prob = inp_dropout_prob + self.hid_dropout_prob = hid_dropout_prob + super().__init__(**kwargs) + + +class TextNumrEncoderConfig(PretrainedConfig): """ The configuration class for the cross encoder model. This class contains the model shortcut, model modifier and model initialization arguments for the model. The model shortcut is the name of the huggingface model. The @@ -32,143 +85,218 @@ class CrossEncoderConfig(PretrainedConfig): for the model. """ - model_type = "reranker_crossencoder" + default_text_model_type = "xlm-roberta" + model_type = "text_numr_crossencoder" def __init__( self, - model_shortcut: str = "", - model_modifier: Dict = {}, - model_init_kwargs: dict = {}, + text_config=None, + numr_config=None, + text_pooling_type="cls", + head_actv_type="gelu", + head_dropout_prob=0.1, + head_size_list=[128, 64], **kwargs, ): """ Initialize the cross encoder configuration - Args: - model_shortcut: The model shortcut for the huggingface model - model_modifier: The model modifier configuration (e.g. PEFT) - model_init_kwargs: The model initialization arguments. These are the arguments for the huggingface model """ + + if text_config is None: + pass + elif isinstance(text_config, PretrainedConfig): + text_config = copy.deepcopy(text_config) + elif isinstance(text_config, dict): + text_model_type = text_config.get("model_type", self.default_text_model_type) + text_config = CONFIG_MAPPING[text_model_type](**text_config) + else: + raise TypeError(f"Type(text_config) is not valid, got {type(text_config)}!") + self.text_config = text_config + self.text_pooling_type = text_pooling_type + + if numr_config is None: + pass + elif isinstance(numr_config, PretrainedConfig): + numr_config = copy.deepcopy(numr_config) + elif isinstance(numr_config, dict): + numr_config = NumrMLPEncoderConfig(**numr_config) + else: + raise TypeError(f"Type(numr_config) is not valid, got {type(numr_config)}!") + self.numr_config = numr_config + + self.head_size_list = head_size_list + self.head_actv_type = head_actv_type + self.head_dropout_prob = head_dropout_prob + super().__init__(**kwargs) - self.model_shortcut = model_shortcut - self.model_modifier = model_modifier - self.model_init_kwargs = model_init_kwargs +class MLPBlock(nn.Module): + def __init__(self, inp_size, dropout_prob, actv_type, hid_size_list): + super(MLPBlock, self).__init__() -class CrossEncoder(PreTrainedModel): - """ - The cross encoder model for ranking tasks (retrieval-based). 
This model is used for training and evaluation. - It is a wrapper around the huggingface transformer model. - """ + cur_inp_size = inp_size + self.mlp_layers = nn.ModuleList() + for cur_hid_size in hid_size_list: + self.mlp_layers.append(nn.Linear(cur_inp_size, cur_hid_size, bias=True)) + self.mlp_layers.append(ACT_FCT_DICT[actv_type]()) + self.mlp_layers.append(nn.Dropout(dropout_prob)) + cur_inp_size = cur_hid_size - TRANSFORMER_CLS = AutoModelForSequenceClassification - TRANSFORMER_PEFT_CLS = AutoPeftModelForSequenceClassification + def forward(self, x): + for cur_layer in self.mlp_layers: + x = cur_layer(x) + return x - @dataclass - class Config(pecos.BaseParams): - """Encoder configuration - model_shortcut (str): the model shortcut of the HuggingFace model - model_init_kwargs (dict): model initialization kwargs - model_modifier (dict): model modifier configuration - """ - model_shortcut: str = "" - model_init_kwargs: dict = dc.field(default_factory=lambda: dict()) - model_modifier: dict = dc.field(default_factory=lambda: dict()) +class NumrMLPEncoder(PreTrainedModel): + + config_class = NumrMLPEncoderConfig + + def __init__(self, config: NumrMLPEncoderConfig): + super().__init__(config) + + self.inp_dropout = nn.Dropout(config.inp_dropout_prob) + self.mlp_block = MLPBlock( + config.inp_feat_dim, + config.hid_dropout_prob, + config.hid_actv_type, + config.hid_size_list, + ) + self.layer_norm = nn.LayerNorm(config.hid_size_list[-1]) + + def forward(self, numeric_inputs): + numr_emb = self.inp_dropout(numeric_inputs) + numr_emb = self.mlp_block(numr_emb) + return self.layer_norm(numr_emb) - config_class = CrossEncoderConfig - def __init__(self, config: CrossEncoderConfig): +class TextNumrEncoder(PreTrainedModel): + + config_class = TextNumrEncoderConfig + + def __init__(self, config: TextNumrEncoderConfig): """ Initialize the cross encoder model Args: config: The configuration for the cross encoder """ + + # sanity check + if config.text_pooling_type not in ["cls", "avg", "last"]: + raise NotImplementedError( + f"text_pooling_type={config.text_pooling_type} is not support!" 
+ ) + if config.text_config is None and config.numr_config is None: + raise ValueError(f"text_config and numr_config can not be None at the same time!") super().__init__(config) - base_model = AutoModelForSequenceClassification.from_pretrained( - config.model_shortcut, num_labels=1, **config.model_init_kwargs - ) - base_model.config.pad_token_id = ( - 0 if base_model.config.pad_token_id is None else base_model.config.pad_token_id + + # text encoder + if config.text_config: + text_encoder = AutoModel.from_pretrained( + config.text_config._name_or_path, + attn_implementation=config.text_config._attn_implementation, + trust_remote_code=config.text_config.trust_remote_code, + token=getattr(config.text_config, "token", None), + ) + text_encoder.config.pad_token_id = ( + 0 if text_encoder.config.pad_token_id is None else text_encoder.config.pad_token_id + ) + self.text_encoder = text_encoder + self.text_emb_dim = self.text_encoder.config.hidden_size + self.text_pooling_type = config.text_pooling_type + else: + self.text_encoder = None # type: ignore + self.text_emb_dim = 0 + + # numeric encoder + if config.numr_config: + self.numr_encoder = NumrMLPEncoder(config.numr_config) + self.numr_emb_dim = self.numr_encoder.config.hid_size_list[-1] + else: + self.numr_encoder = None # type: ignore + self.numr_emb_dim = 0 + + # head layer + cur_feat_dim = self.text_emb_dim + self.numr_emb_dim + self.head_layers = MLPBlock( + cur_feat_dim, + config.head_dropout_prob, + config.head_actv_type, + config.head_size_list, ) - self.hf_model = base_model + self.scorer = nn.Linear(config.head_size_list[-1], 1, bias=True) - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Union[str, os.PathLike], - *model_args, - config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, - cache_dir: Optional[Union[str, os.PathLike]] = None, - ignore_mismatched_sizes: bool = False, - force_download: bool = False, - local_files_only: bool = False, - token: Optional[Union[str, bool]] = None, - revision: str = "main", - use_safetensors: Optional[bool] = None, - **kwargs, + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + numeric_inputs=None, ): """ - Load the model from the pretrained model name or path. Override the `from_pretrained` method of the - `PreTrainedModel` class. 
+ Returns the forward output of the huggingface model """ - is_local = os.path.isdir(pretrained_model_name_or_path) - param_folder = pretrained_model_name_or_path - - def super_return(): - return PreTrainedModel.from_pretrained( - pretrained_model_name_or_path, - *model_args, - config, - cache_dir, - ignore_mismatched_sizes, - force_download, - local_files_only, - token, - revision, - use_safetensors, - **kwargs, - ) - - if not is_local: - raise NotImplementedError(f"{cls} can only load local models") - - with open(os.path.join(param_folder, PARAM_FILENAME), "r") as param_file: - params = json.load(param_file) - - xe_config = CrossEncoder.Config.from_dict(params["model_params"]["encoder_args"]) - xe_config = CrossEncoderConfig(**xe_config.to_dict()) - for k, v in kwargs.items(): - xe_config.model_init_kwargs[k] = v - model = CrossEncoder(xe_config) - try: - if xe_config.model_modifier["modifier_type"] == "peft": - model = PeftModel.from_pretrained(model, param_folder) + # get text embedding from HF Pretrained Transformers encoder + text_emb = None + if self.text_encoder: + text_input_dict = {"input_ids": input_ids, "attention_mask": attention_mask} + if token_type_ids: + text_input_dict["token_type_ids"] = token_type_ids + text_outputs = self.text_encoder(**text_input_dict, return_dict=True) + if hasattr(text_outputs, "pooler_output"): + text_emb = text_outputs.pooler_output else: - super_return() - except KeyError: - logger.info("No peft configuration found") - - return model + text_emb = self.text_pooler(text_outputs.last_hidden_state, attention_mask) + + # get numr embedding from Numerical MLP Encoder + numr_emb = None + if self.numr_encoder: + numr_emb = self.numr_encoder(numeric_inputs) + + # head layer + scorer + if self.text_encoder and self.numr_encoder: + head_emb = torch.cat((text_emb, numr_emb), 1) + elif self.text_encoder is not None: + head_emb = text_emb + elif self.numr_encoder is not None: + head_emb = numr_emb + head_emb = self.head_layers(head_emb) + scores = self.scorer(head_emb) + + return RerankerOutput( + text_emb=text_emb, + numr_emb=numr_emb, + scores=scores, + ) - def forward(self, *args, **kwargs): - """ - Returns the forward output of the huggingface model - """ - return self.hf_model(*args, **kwargs) + def text_pooler(self, last_hidden_states, attention_mask): + if self.text_pooling_type == "cls": + text_emb = last_hidden_states[:, 0, :] + elif self.text_pooling_type == "avg": + # https://huggingface.co/intfloat/multilingual-e5-base + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + text_emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif self.text_pooling_type == "last": + # https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + text_emb = last_hidden_states[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + text_emb = last_hidden_states[ + torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths + ] + return text_emb def gradient_checkpointing_enable(self, **kwargs): """ Enable gradient checkpointing for the model """ - try: - if self.config.model_modifier["modifier_type"] == "peft": - self.hf_model.enable_input_require_grads() - except KeyError: - pass - self.hf_model.gradient_checkpointing_enable(**kwargs) + self.text_encoder.gradient_checkpointing_enable(**kwargs) class RankingModel(pecos.BaseClass): @@ -183,13 +311,13 @@ 
class TrainParams(pecos.BaseParams): """ The training parameters for the ranking model. Args: - training_args (RankLlamaTrainer.TrainingArgs): The training arguments for the model + hf_trainer_args (RankingTrainer.TrainingArgs): The training arguments for the model target_data_folder (str): The path to the target data folder input_data_folder (str): The path to the input data folder label_data_folder (str): The path to the label data folder """ - training_args: RankLlamaTrainer.TrainingArgs + hf_trainer_args: RankingTrainer.TrainingArgs target_data_folder: str = field( metadata={ "help": "Path to folder containing target parquet files (inp_id, [lbl_id], [rel_val])" @@ -210,7 +338,8 @@ class ModelParams(pecos.BaseParams): The parameters for the ranking model. This class contains the data, encoder and training arguments for the model. """ - encoder_args: CrossEncoder.Config + encoder_config: TextNumrEncoderConfig = None # type: ignore + model_modifier: dict = dc.field(default_factory=lambda: dict()) positive_passage_no_shuffle: bool = False negative_passage_no_shuffle: bool = False @@ -232,10 +361,10 @@ class EvalParams(pecos.BaseParams): Evaluation parameters """ - model_name_or_path: str target_data_folder: str input_data_folder: str label_data_folder: str + model_path: str output_dir: str output_file_prefix: str = "output_" output_file_suffix: str = "" @@ -247,18 +376,18 @@ class EvalParams(pecos.BaseParams): passage_prefix: str = "document: " inp_id_col: str = "inp_id" lbl_id_col: str = "lbl_id" + inp_id_orig_col: Optional[str] = None + lbl_id_orig_col: Optional[str] = None keyword_col_name: str = "keywords" content_col_names: List[str] = field(default_factory=lambda: ["title", "contents"]) content_sep: str = " " append_eos_token: bool = False pad_to_multiple_of: int = 16 bf16: bool = True - device: str = "cuda" - model_init_kwargs: dict = dc.field(default_factory=lambda: dict()) def __init__( self, - encoder: Union[CrossEncoder, PeftModel, PeftMixedModel], + encoder: Union[PreTrainedModel, PeftModel, PeftMixedModel], tokenizer: AutoTokenizer, model_params: ModelParams, train_params: Optional[TrainParams] = None, @@ -266,38 +395,34 @@ def __init__( """ Initialize the ranking model. The model contains the encoder, tokenizer, model parameters and training parameters. Args: - encoder (Union[CrossEncoder, PeftModel, PeftMixedModel]): The encoder model + encoder (Union[PreTrainedModel, PeftModel, PeftMixedModel]): The encoder model tokenizer (AutoTokenizer): The tokenizer for the model model_params (RankingModel.ModelParams): The model parameters train_params (Optional[RankingModel.TrainParams]): The training parameters """ self.tokenizer = tokenizer - self.cross_encoder = encoder - + self.encoder = encoder self.model_params = self.ModelParams.from_dict(model_params) self.train_params = self.TrainParams.from_dict(train_params) if train_params else None @classmethod - def get_modified_model(cls, model: CrossEncoder, mod_config: Dict): + def get_modified_model(cls, model: PreTrainedModel, config: Dict): """ Takes a pretrained Huggingface model and modifies it to include new features. Currently, the `modifier_type` supported by this method is limited to the `peft` package. Args: - model (CrossEncoder): A PreTrainedModel from the transformers package. - mod_config (Dict): A dictionary containing the configuration for the model modifier. + model (PreTrainedModel): A PreTrainedModel from the transformers package. + config (Dict): A dictionary containing the configuration for the model modifier. 
Returns: The modified model """ - if mod_config["modifier_type"] == "peft": - config_type = getattr(peft, mod_config["config_type"]) - peft_config: PeftConfig = config_type(**mod_config["config"]) - - model = get_peft_model(model, peft_config) - - return model + if config["modifier_type"] == "peft": + config_cls = getattr(peft, config["config_type"]) + peft_config: PeftConfig = config_cls(**config["config"]) + modified_model = get_peft_model(model, peft_config) + return modified_model else: - logger.warn("Using model without modifiers (e.g. LoRA)") - return model + raise NotImplementedError(f"We only support modifier_type==peft for now!") @classmethod def init_model(cls, model_params: ModelParams, train_params: TrainParams): @@ -309,35 +434,93 @@ def init_model(cls, model_params: ModelParams, train_params: TrainParams): Returns: An instance of RankingModel """ - hf_trainer_args = train_params.training_args + hf_trainer_args = train_params.hf_trainer_args if hf_trainer_args.local_rank > 0: torch.distributed.barrier() - config = model_params.encoder_args.to_dict() - config = CrossEncoderConfig(**config) - encoder = CrossEncoder( - config=config, - ) + encoder_config = TextNumrEncoderConfig(**model_params.encoder_config) + encoder = TextNumrEncoder(encoder_config) if hf_trainer_args.bf16: encoder = encoder.bfloat16() - if config.model_modifier: - encoder = cls.get_modified_model(model=encoder, mod_config=config.model_modifier) + if model_params.model_modifier: + if hf_trainer_args.gradient_checkpointing: + encoder.text_encoder.enable_input_require_grads() + encoder = cls.get_modified_model( + model=encoder, + config=model_params.model_modifier, + ) tokenizer = AutoTokenizer.from_pretrained( - model_params.encoder_args.model_shortcut, + encoder_config.text_config._name_or_path, + trust_remote_code=encoder_config.text_config.trust_remote_code, ) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.unk_token_id tokenizer.padding_side = "right" - if torch.distributed.is_initialized(): - if hf_trainer_args.local_rank == 0: - torch.distributed.barrier() + if hf_trainer_args.local_rank == 0: + torch.distributed.barrier() return cls(encoder, tokenizer, model_params, train_params=train_params) + @classmethod + def load(cls, load_dir): + """Load the model from the folder load_dir + + Args: + load_dir (str): path of the loading folder + """ + + param_file = os.path.join(load_dir, PARAM_FILENAME) + if not os.path.exists(param_file): + raise FileNotFoundError(f"The model {load_dir} does not exists.") + + param = json.loads(open(param_file, "r").read()) + model_params = cls.ModelParams.from_dict(param.get("model_params", None)) + + if model_params.model_modifier: + encoder_config = TextNumrEncoderConfig(**model_params.encoder_config) + encoder = TextNumrEncoder(encoder_config) + encoder = PeftModel.from_pretrained(encoder, load_dir) + encoder = encoder.merge_and_unload() + else: + encoder = TextNumrEncoder.from_pretrained(load_dir) + + tokenizer = AutoTokenizer.from_pretrained(load_dir) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.unk_token_id + tokenizer.padding_side = "right" + + return cls(encoder, tokenizer, model_params) + + def get_param_to_save(self): + param = { + "model": self.__class__.__name__, + "model_params": self.model_params.to_dict(), + "train_params": self.train_params.to_dict(), + } + param = self.append_meta(param) + return param + + def save(self, save_dir): + """Save the model to the folder save_dir + + Args: + save_dir (str): path of the 
saving folder + """ + + os.makedirs(save_dir, exist_ok=True) + param_file = os.path.join(save_dir, PARAM_FILENAME) + + param_to_save = self.get_param_to_save() + with open(param_file, "w", encoding="utf-8") as fout: + fout.write(json.dumps(param_to_save, indent=True)) + + self.encoder.save_pretrained(save_dir) + self.tokenizer.save_pretrained(save_dir) + @classmethod def _collate_sharded( cls, @@ -368,7 +551,7 @@ def _collate_sharded( retr_idxs, scores, table_stores, - train_params.training_args.train_group_size, + train_params.hf_trainer_args.group_size, model_params.query_prefix, model_params.passage_prefix, model_params.keyword_col_name, @@ -431,7 +614,9 @@ def _collate( # incorrectly across devices in distributed training. m_scores = torch.tensor(scores, dtype=torch.float).flatten() - return {"input": pairs_collated, "scores": m_scores} + ret_dict = dict(target=m_scores) + ret_dict.update(pairs_collated) + return ret_dict @classmethod def _collate_sharded_eval( @@ -453,11 +638,17 @@ def _collate_sharded_eval( fts = [] inp_idxs = [] lbl_idxs = [] + inp_id_orig_col = ( + eval_params.inp_id_orig_col if eval_params.inp_id_orig_col else eval_params.inp_id_col + ) + lbl_id_orig_col = ( + eval_params.lbl_id_orig_col if eval_params.lbl_id_orig_col else eval_params.lbl_id_col + ) for s in data: inp_id = s[eval_params.inp_id_col] retr_id = s[eval_params.lbl_id_col] - inp_idxs.append(inp_id) - lbl_idxs.append(retr_id) + inp_idxs.append(table_stores["input"][inp_id][inp_id_orig_col]) + lbl_idxs.append(table_stores["label"][retr_id][lbl_id_orig_col]) fts.append( RankingDataUtils._format_sample( @@ -522,11 +713,9 @@ def _collate_eval( return_tensors="pt", ) - return { - "inp_idxs": inp_idxs, - "lbl_idxs": lbl_idxs, - "inputs": pairs_collated, - } + ret_dict = dict(inp_idxs=inp_idxs, lbl_idxs=lbl_idxs) + ret_dict.update(pairs_collated) + return ret_dict @classmethod def predict( @@ -534,10 +723,15 @@ def predict( eval_dataset: IterableDataset, table_stores: Dict[str, Dataset], eval_params: EvalParams, + encoder: PreTrainedModel, tokenizer: AutoTokenizer, - model: CrossEncoder, ): - model.eval() + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info("Use pytorch device: {}".format(device)) + device = torch.device(device) # type: ignore + if torch.cuda.device_count() > 1 and not isinstance(encoder, torch.nn.DataParallel): + encoder = torch.nn.DataParallel(encoder) + encoder = encoder.to(device) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = 0 @@ -552,16 +746,29 @@ def predict( # To ensure efficiency we prefetch samples in parallel prefetch_factor=eval_params.dataloader_prefetch_factor, collate_fn=partial(cls._collate_sharded_eval, tokenizer, eval_params, table_stores), + shuffle=False, ) + encoder.eval() all_results = [] - for batch in eval_dataloader: - with torch.inference_mode(): + with torch.inference_mode(): + for batch in eval_dataloader: inp_ids = batch["inp_idxs"] lbl_ids = batch["lbl_idxs"] - inputs = batch["inputs"].to(eval_params.device) - model_output = model(**inputs).logits - scores = model_output.cpu().detach().float().numpy() + + # place inputs to the device + for k in batch.keys(): + if torch.is_tensor(batch[k]): + batch[k] = batch[k].to(device) + + # forward + output = encoder( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + numeric_inputs=batch.get("numeric_inputs", None), + token_type_ids=batch.get("token_type_ids", None), + ).scores + scores = output.cpu().detach().float().numpy() for i in range(len(scores)): inp_id = 
inp_ids[i] ret_id = lbl_ids[i] @@ -586,42 +793,29 @@ def train( model_params: The model parameters train_params: The training parameters """ - training_args = train_params.training_args + hf_trainer_args = train_params.hf_trainer_args # we need to have 'unused' columns to maintain information about # group and scores coming from the collator - training_args.remove_unused_columns = False - outer_model = cls.init_model(model_params, train_params) - inner_model = outer_model.cross_encoder - - logger.info("Model loading...") - if torch.distributed.is_initialized(): - torch.distributed.barrier() - else: - # NOTE This is needed for the case where the program is run in a single process mode - if training_args.bf16 and not torch.distributed.is_initialized(): - inner_model = inner_model.bfloat16() - - logger.info("=" * 50) - logger.info( - f"Memory used by model: {round(inner_model.get_memory_footprint() / 1024 / 1024 / 1024, 2)} GB" - ) - - trainer = RankLlamaTrainer( - model=inner_model, - args=training_args, + hf_trainer_args.remove_unused_columns = False + model = cls.init_model(model_params, train_params) + param_to_save = model.get_param_to_save() + + trainer = RankingTrainer( + model=model.encoder, + args=hf_trainer_args, + tokenizer=model.tokenizer, train_dataset=train_dataset, data_collator=partial( cls._collate_sharded, - outer_model.tokenizer, + model.tokenizer, model_params, train_params, table_stores, ), - outer_model=outer_model, + param_to_save=param_to_save, ) - # NOTE: in the huggingface trainers `_prepare_input` method, the inputs are converted from - # mps device to cpu. To run on Apple Silicon, the method should be overridden. It is not - # clear if training is supported for Apple Silicon devices. - trainer.train() - trainer.save_model() + trainer.train( + resume_from_checkpoint=hf_trainer_args.resume_from_checkpoint, + ) + return model diff --git a/pecos/xmr/reranker/predict.py b/pecos/xmr/reranker/predict.py index 89de357..6906c86 100644 --- a/pecos/xmr/reranker/predict.py +++ b/pecos/xmr/reranker/predict.py @@ -2,9 +2,8 @@ import argparse import os from datasets import load_dataset -from .data_utils import RankingDataUtils -from .model import RankingModel, CrossEncoder -from transformers import AutoTokenizer +from pecos.xmr.reranker.data_utils import RankingDataUtils +from pecos.xmr.reranker.model import RankingModel from tqdm import tqdm import pandas as pd @@ -38,11 +37,12 @@ def construct_file_list(folder): if not os.path.exists(params.output_dir): os.makedirs(params.output_dir) - model = CrossEncoder.from_pretrained(params.model_name_or_path, **params.model_init_kwargs) + outer_model = RankingModel.load(params.model_path) + inner_model = outer_model.encoder if params.bf16: - model = model.bfloat16() + inner_model = inner_model.bfloat16() - tokenizer = AutoTokenizer.from_pretrained(model.config.model_shortcut) + tokenizer = outer_model.tokenizer for target_file in tqdm(target_files): target_filename = os.path.basename(target_file) @@ -60,8 +60,8 @@ def construct_file_list(folder): eval_dataset, table_stores, params, + encoder=inner_model, tokenizer=tokenizer, - model=model, ) # Save the results to a parquet with (inp_id, lbl_id, score) columns diff --git a/pecos/xmr/reranker/train.py b/pecos/xmr/reranker/train.py index aaa1f01..b77d9d7 100644 --- a/pecos/xmr/reranker/train.py +++ b/pecos/xmr/reranker/train.py @@ -31,15 +31,14 @@ def main(config_json_path: str): param = json.load(fin) model_params: RankingModel.ModelParams = RankingModel.ModelParams.from_dict( 
param.get("model_params", None), - recursive=True, + recursive=False, ) - train_params: RankingModel.TrainParams = RankingModel.TrainParams.from_dict( param.get("train_params", None), recursive=True, ) - set_seed(train_params.training_args.seed) + set_seed(train_params.hf_trainer_args.seed) # helper function for getting the list of filepaths in a folder def construct_file_list(folder): @@ -47,44 +46,47 @@ def construct_file_list(folder): input_files = construct_file_list(train_params.input_data_folder) label_files = construct_file_list(train_params.label_data_folder) - input_files, label_files = RankingDataUtils.get_sorted_data_files( - input_files, "inp_id" - ), RankingDataUtils.get_sorted_data_files(label_files, "lbl_id") - + input_files = RankingDataUtils.get_sorted_data_files(input_files, "inp_id") + label_files = RankingDataUtils.get_sorted_data_files(label_files, "lbl_id") train_dataset = load_dataset( "parquet", data_dir=train_params.target_data_folder, streaming=True, split="train" ) - train_dataset_rows = RankingDataUtils.get_parquet_rows(train_params.target_data_folder) - logger.info(f"total target inputs: {train_dataset_rows}") - training_args = train_params.training_args + hf_trainer_args = train_params.hf_trainer_args # set the max_steps in accordance with the number of num_rows - if training_args.max_steps <= 0: - ws = training_args.world_size - bs = training_args.per_device_train_batch_size - gas = training_args.gradient_accumulation_steps + if hf_trainer_args.max_steps <= 0: + train_dataset_rows = RankingDataUtils.get_parquet_rows(train_params.target_data_folder) + logger.info(f"total target inputs: {train_dataset_rows}") + ws = hf_trainer_args.world_size + bs = hf_trainer_args.per_device_train_batch_size + gas = hf_trainer_args.gradient_accumulation_steps batch_size = ws * bs * gas max_steps = train_dataset_rows // batch_size - training_args.max_steps = max_steps + hf_trainer_args.max_steps = max_steps logger.info(f"total batch size: {batch_size}, train steps: {max_steps}") else: - logger.info(f"max steps: {training_args.max_steps}") + logger.info(f"max steps: {hf_trainer_args.max_steps}") table_stores = { "input": load_dataset("parquet", data_files=input_files, split="train"), "label": load_dataset("parquet", data_files=label_files, split="train"), } - train_dataset = train_dataset.shuffle(buffer_size=5000, seed=training_args.data_seed) + train_dataset = train_dataset.shuffle(buffer_size=5000, seed=hf_trainer_args.data_seed) train_dataset = datasets.distributed.split_dataset_by_node( - train_dataset, training_args.local_rank, training_args.world_size + train_dataset, hf_trainer_args.local_rank, hf_trainer_args.world_size ) logger.info("Waiting for main process to perform the mapping") if torch.distributed.is_initialized(): torch.distributed.barrier() - RankingModel.train(train_dataset, table_stores, model_params, train_params) + model = RankingModel.train(train_dataset, table_stores, model_params, train_params) + + # Save model only for world process 0 + if hf_trainer_args.process_index == 0: + logger.info(f"Saving model to {hf_trainer_args.output_dir}") + model.save(hf_trainer_args.output_dir) if __name__ == "__main__": diff --git a/pecos/xmr/reranker/trainer.py b/pecos/xmr/reranker/trainer.py index 9ba2809..b386a03 100644 --- a/pecos/xmr/reranker/trainer.py +++ b/pecos/xmr/reranker/trainer.py @@ -6,9 +6,18 @@ from typing import Optional, Any, Tuple, Dict import torch -from torch.utils.data import DataLoader -from transformers import Trainer, TrainingArguments, 
HfArgumentParser - +import torch.nn as nn +from torch.utils.data import DataLoader, IterableDataset +from transformers import ( + Trainer, + TrainingArguments, + HfArgumentParser, +) +from transformers.trainer_callback import ( + TrainerCallback, + TrainerControl, + TrainerState, +) import pecos PARAM_FILENAME: str = "param.json" @@ -16,21 +25,112 @@ logger = logging.getLogger(__name__) -class RankLlamaTrainer(Trainer, pecos.BaseClass): - """ - Trainer class for the RankLlama model. This class extends the Trainer class. - """ +class PairwisePointwiseHybridLoss(nn.Module): + def __init__(self, pairwise_loss, pointwise_loss): + super(PairwisePointwiseHybridLoss, self).__init__() + self.pairwise_loss = pairwise_loss + self.pointwise_loss = pointwise_loss + + def forward(self, preds, target, alpha=0.5): + """ + Args: + preds (torch.Tensor): prediction of shape (B, 2) + target (torch.Tensor): gt target of shape (B, 2) + target[:, 0] corresponds to relevance scores of positive labels + target[:, 1] correspodns to relevance scores of negative labels + """ + pairwise_target = torch.ones(preds.shape[0], device=preds.device).long() + loss1 = self.pairwise_loss(preds[:, 0], preds[:, 1], pairwise_target) - loss_fn = torch.nn.CrossEntropyLoss(reduction="mean") - outer_model = None + if self.pointwise_loss is not None: + loss2 = self.pointwise_loss(preds.flatten(), target.flatten()) + return alpha * loss1 + (1.0 - alpha) * loss2 + else: + return loss1 - def __init__(self, *args, **kwargs): - self.outer_model = kwargs.pop("outer_model") - super(RankLlamaTrainer, self).__init__(*args, **kwargs) + +class ListwisePointwiseHybridLoss(nn.Module): + def __init__(self, listwise_loss, pointwise_loss): + super(ListwisePointwiseHybridLoss, self).__init__() + self.listwise_loss = listwise_loss + self.pointwise_loss = pointwise_loss + + def forward(self, preds, target, alpha=0.5): + """ + Args: + preds (torch.Tensor): prediction of shape (B, M) + target (torch.Tensor): gt target of shape (B, M) + target[:, 0] corresponds to the relevance scores of positive labels + target[:, 1:] corresponds to the relevance scores of negative labels + """ + listwise_target = torch.zeros(preds.shape[0], device=preds.device).long() + loss1 = self.listwise_loss(preds, listwise_target) + + if self.pointwise_loss is not None: + loss2 = self.pointwise_loss(preds.flatten(), target.flatten()) + return alpha * loss1 + (1.0 - alpha) * loss2 + else: + return loss1 + + +LOSS_FN_DICT = { + "pairwise": PairwisePointwiseHybridLoss( + nn.MarginRankingLoss(reduction="mean", margin=0.1), + nn.MSELoss(reduction="mean"), + ), + "listwise": ListwisePointwiseHybridLoss( + nn.CrossEntropyLoss(reduction="mean"), + nn.BCEWithLogitsLoss(reduction="mean"), + ), +} + + +class LoggerCallback(TrainerCallback): + def on_epoch_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + train_dataloader, + **kwargs, + ): + if isinstance(train_dataloader.dataset, IterableDataset): + train_dataloader._IterableDataset_len_called = None + else: + pass + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs=None, + **kwargs, + ): + # avoid modifying the logs object as it is shared between callbacks + logs = copy.deepcopy(logs) + _ = logs.pop("total_flos", None) + # round numbers so that it looks better in console + if "loss" in logs: + logs["loss"] = round(logs["loss"], 6) + if "grad_norm" in logs: + logs["grad_norm"] = round(logs["grad_norm"], 6) + if "epoch" in logs: + logs["epoch"] 
= round(logs["epoch"], 2) + if state.is_world_process_zero: + logger.info(logs) + + +class RankingTrainer(Trainer, pecos.BaseClass): + """ + Trainer class for the pecos.xmr.reranker.RankingModel. + """ @dataclass class TrainingArgs(TrainingArguments, pecos.BaseParams): - train_group_size: int = 8 + loss_fn: str = "listwise" + loss_alpha: float = 1.0 + group_size: int = 8 @classmethod def from_dict(cls, param=None): @@ -47,6 +147,14 @@ def to_dict(self, with_meta=True): d = super().to_dict() return self.append_meta(d) if with_meta else d + def __init__(self, *args, **kwargs): + param_to_save = kwargs.pop("param_to_save") + super(RankingTrainer, self).__init__(*args, **kwargs) + + self.loss_fn = LOSS_FN_DICT[self.args.loss_fn] + self.loss_alpha = self.args.loss_alpha + self.param_to_save = param_to_save + def get_train_dataloader(self) -> DataLoader: """ Returns the training dataloader. This function is called by the Trainer class. @@ -55,7 +163,7 @@ def get_train_dataloader(self) -> DataLoader: prefetch_factor = prefetch_factor if prefetch_factor else 10 return DataLoader( self.train_dataset, - batch_size=self.args.per_device_train_batch_size, + batch_size=self._train_batch_size, # Ensure that at least one worker is creating batches # parallel to the model compute num_workers=max(self.args.dataloader_num_workers, 1), @@ -75,35 +183,18 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): if output_dir is not None: os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving to {output_dir}") - - outer_model: Any = self.outer_model super()._save(output_dir, state_dict) # save the config - param = { - "model": outer_model.__class__.__name__, - "model_params": outer_model.model_params.to_dict(), - "train_params": outer_model.train_params.to_dict(), - } - output_dir = output_dir if output_dir is not None else self.args.output_dir - - param = outer_model.append_meta(param) with open(os.path.join(output_dir, PARAM_FILENAME), "w", encoding="utf-8") as f: - f.write(json.dumps(param, indent=True)) - - def _prepare_inputs(self, inputs): - """ - Prepare the inputs for the model. This function is called by the Trainer class. Converts the inputs to mps - tensors if available. - """ - super_inputs = super(RankLlamaTrainer, self)._prepare_inputs(inputs) - if torch.backends.mps.is_available(): - super_inputs = {k: v.to("mps") for k, v in super_inputs.items()} - return super_inputs + f.write(json.dumps(self.param_to_save, indent=True)) def compute_loss( - self, model, inputs: Dict[str, Any], return_outputs: bool = False + self, + model, + inputs: Dict[str, Any], + return_outputs: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute the loss for the model. This function is called by the Trainer class. 
@@ -112,16 +203,23 @@ def compute_loss( inputs: The inputs to the model return_outputs: Whether to return the outputs """ - self.args: RankLlamaTrainer.TrainingArgs - train_group_size = self.args.train_group_size - if not train_group_size: - raise NotImplementedError("Cannot perform ranking without train group") - gt_scores = inputs["scores"].reshape(-1, train_group_size) - ranker_logits = model(**inputs["input"], return_dict=True).logits - batch_size = gt_scores.shape[0] - - grouped_logits = ranker_logits.view(batch_size, -1) - assert grouped_logits.shape == gt_scores.shape - loss = self.loss_fn(grouped_logits, gt_scores) - - return (loss, ranker_logits) if return_outputs else loss + self.args: RankingTrainer.TrainingArgs + group_size = self.args.group_size + + # ground truth target + target = inputs["target"] + target = target.view(-1, group_size) # [B, M] + batch_size = target.shape[0] + + # model prediction scores + preds_1d = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + numeric_inputs=inputs.get("numeric_inputs", None), + token_type_ids=inputs.get("token_type_ids", None), + ).scores + preds_2d = preds_1d.view(batch_size, -1) # [B, M] + assert preds_2d.shape == target.shape + + loss = self.loss_fn(preds_2d, target, alpha=self.loss_alpha) + return (loss, preds_1d) if return_outputs else loss
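
For reference, a minimal self-contained sketch (plain PyTorch, with hypothetical tensor values) of how the new `compute_loss` groups the flat scores and applies the listwise/pointwise hybrid loss registered in `LOSS_FN_DICT["listwise"]`:

```python
import torch
import torch.nn as nn

# Same components as LOSS_FN_DICT["listwise"] in trainer.py
listwise_loss = nn.CrossEntropyLoss(reduction="mean")
pointwise_loss = nn.BCEWithLogitsLoss(reduction="mean")

B, M, alpha = 4, 16, 1.0           # batch size, group_size, loss_alpha (hypothetical values)
preds_1d = torch.randn(B * M)      # flat scores, as returned by encoder(...).scores
target = torch.zeros(B, M)         # relevance labels; column 0 holds each group's positive
target[:, 0] = 1.0

preds_2d = preds_1d.view(B, -1)                      # [B, M], one row per query group
listwise_target = torch.zeros(B, dtype=torch.long)   # positive label always sits at index 0
loss = alpha * listwise_loss(preds_2d, listwise_target) + (1.0 - alpha) * pointwise_loss(
    preds_2d.flatten(), target.flatten()
)
```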