Skip to content

Commit

Permalink
Adding distributed training functionality in bert embedding model and… (
Browse files Browse the repository at this point in the history
NVIDIA#9822)

* Adding distributed training functionality in bert embedding model and adding a embedding generation script

Signed-off-by: adityavavre <avavre@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: adityavavre <adityavavre@users.noreply.github.com>

* Removing unused import

Signed-off-by: adityavavre <avavre@nvidia.com>

* Removing apex dependency

Signed-off-by: adityavavre <avavre@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: adityavavre <adityavavre@users.noreply.github.com>

* Adding new arguments to bert mebedding config

Signed-off-by: adityavavre <avavre@nvidia.com>

* Adding constrastive score clamping and editing config

Signed-off-by: adityavavre <avavre@nvidia.com>

* Adding sampling feature for hard negatives in bert embedding dataset

Signed-off-by: adityavavre <avavre@nvidia.com>

* Removing unecessary seed

Signed-off-by: adityavavre <avavre@nvidia.com>

---------

Signed-off-by: adityavavre <avavre@nvidia.com>
Signed-off-by: adityavavre <adityavavre@users.noreply.github.com>
Co-authored-by: adityavavre <adityavavre@users.noreply.github.com>
Co-authored-by: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 29, 2024
1 parent 72f630d commit 558cf2e
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ model:
vocab_file: null
merge_file: null

# embedding-specific arguemnts
softmax_temp: 0.02 # softmax temp for contrastive loss
global_inbatch_negatives: True # whether to use in-batch negatives from other ranks during training
backprop_type: 'global' # whether to use `global` or `local` backpropagation during training. Refer to Flava paper for details.

# precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
Expand All @@ -93,7 +98,7 @@ model:
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
# These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
Expand Down Expand Up @@ -127,7 +132,7 @@ model:
# Path to data must be specified by the user.
data_train: null
data_validation: null
hard_negatives_to_train: 4
hard_negatives_to_train: 4 # number of hard negatives to use per example for training
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_impl: mmap
splits_string: 900,50,50
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main(cfg) -> None:
model_cfg = MegatronBertEmbeddingModel.merge_cfg_with(cfg.restore_from_path, cfg)

assert (
model_cfg.micro_batch_size * cfg.trainer.devices == model_cfg.global_batch_size
model_cfg.micro_batch_size * cfg.trainer.devices * cfg.trainer.num_nodes == model_cfg.global_batch_size
), "Gradiant accumulation is not supported for contrastive learning yet"

OmegaConf.set_struct(model_cfg, True)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.models.information_retrieval.megatron_bert_embedding_model import MegatronBertEmbeddingModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronBertTrainerBuilder
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="megatron_bert_embedding_config")
def main(cfg) -> None:
if cfg.model.data.dataloader_type != "LDDL":
mp.set_start_method("spawn", force=True)

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

trainer = MegatronBertTrainerBuilder(cfg).create_trainer()
exp_manager(trainer, cfg.exp_manager)

model_cfg = MegatronBertEmbeddingModel.merge_cfg_with(cfg.restore_from_path, cfg)

OmegaConf.set_struct(model_cfg, True)
with open_dict(model_cfg):
model_cfg.precision = trainer.precision

logging.info(f"Loading model from {cfg.restore_from_path}")
model = MegatronBertEmbeddingModel.restore_from(
restore_path=cfg.restore_from_path,
trainer=trainer,
save_restore_connector=NLPSaveRestoreConnector(),
override_config_path=model_cfg,
strict=True,
)

trainer.test(model)


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from random import choices, sample
from typing import Mapping, Optional

import datasets
Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__(
num_hard_negatives: int = 4,
):
"""
file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format.
file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format.
tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated.
min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements.
Expand Down Expand Up @@ -132,7 +133,10 @@ def __getitem__(self, idx):
if isinstance(idx, np.uint32):
idx = idx.item()

assert idx < len(self.indexed_dataset)
if idx is not None:
assert idx < len(self.indexed_dataset)
else:
idx = -1
# idx may < 0 because we pad_samples_to_global_batch_size, e.g. id = -1
if idx < 0:
idx = len(self) + idx
Expand All @@ -159,10 +163,16 @@ def _process_example(self, example):
if self.data_type == 'train':
q = self.tokenizer.text_to_ids("query: " + example['query'].strip())
d = self.tokenizer.text_to_ids("passage: " + example['pos_doc'].strip())
nd = [
self.tokenizer.text_to_ids("passage: " + example['neg_doc'][i].strip())
for i in range(self.num_hard_negatives)
]
# handle cases where the required number of hard negatives are not present
if len(example['neg_doc']) < self.num_hard_negatives:
nd = example['neg_doc']
# sample rest with replacement
nd = nd + choices(example['neg_doc'], k=self.num_hard_negatives - len(example['neg_doc']))
else:
# sample without replacement
nd = sample(example['neg_doc'], k=self.num_hard_negatives)
assert len(nd) == self.num_hard_negatives, "Error in sampling required number of hard negatives"
nd = [self.tokenizer.text_to_ids("passage: " + ex.strip()) for ex in nd]

elif self.data_type == 'query':
q = self.tokenizer.text_to_ids("query: " + example['query'].strip())
Expand Down Expand Up @@ -292,6 +302,7 @@ def collate_fn(self, batch):
'input_ids': input_ids,
'token_type_ids': torch.zeros_like(input_ids),
'attention_mask': attention_mask,
'metadata': metadata,
}

return processed_batch
Loading

0 comments on commit 558cf2e

Please sign in to comment.