From 228ca1c82ece27052eac1e1fc4de11ba5b7f654f Mon Sep 17 00:00:00 2001
From: vince62s
Date: Tue, 24 Sep 2024 13:45:51 +0200
Subject: [PATCH] fix training tensor parallel

---
 eole/bin/run/train.py                          |   3 +-
 eole/inference_engine.py                       |   3 +-
 eole/modules/multi_headed_attn.py              |   6 +-
 eole/modules/transformer_mlp.py                |   4 +-
 eole/utils/distributed.py                      | 153 +-----------------
 eole/utils/distributed_workers.py              | 153 ++++++++++++++++++
 .../llama-instruct-inference.yaml              |   6 +-
 7 files changed, 169 insertions(+), 159 deletions(-)
 create mode 100644 eole/utils/distributed_workers.py

diff --git a/eole/bin/run/train.py b/eole/bin/run/train.py
index c19fa78a..d618773a 100644
--- a/eole/bin/run/train.py
+++ b/eole/bin/run/train.py
@@ -2,7 +2,8 @@
 """Train models with dynamic data."""
 import torch
 from functools import partial
-from eole.utils.distributed import ErrorHandler, spawned_train
+from eole.utils.distributed import ErrorHandler
+from eole.utils.distributed_workers import spawned_train
 from eole.utils.misc import set_random_seed
 from eole.utils.logging import init_logger, logger
 from argparse import ArgumentParser
diff --git a/eole/inference_engine.py b/eole/inference_engine.py
index 4b19f4d2..af4a454b 100755
--- a/eole/inference_engine.py
+++ b/eole/inference_engine.py
@@ -1,7 +1,8 @@
 import json
 from eole.constants import CorpusTask, DefaultTokens, ModelType
 from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
-from eole.utils.distributed import ErrorHandler, spawned_infer
+from eole.utils.distributed import ErrorHandler
+from eole.utils.distributed_workers import spawned_infer
 from eole.utils.logging import init_logger
 from eole.transforms import get_transforms_cls, make_transforms, TransformPipe
 
diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py
index 8e4ec59a..53442695 100644
--- a/eole/modules/multi_headed_attn.py
+++ b/eole/modules/multi_headed_attn.py
@@ -9,7 +9,7 @@
 from torch.utils.checkpoint import checkpoint
 from torch.nn.utils import skip_init
 from .alibi_position_bias import AlibiPositionalBias
-from torch.distributed import all_reduce
+from eole.utils.distributed import all_reduce_and_rescale_tensors
 from importlib import import_module
 from eole.constants import PositionEncodingType
 
@@ -535,7 +535,7 @@ def _compute_attention(
 
         attn_output = self.maybe_ckpt(self.final_linear, context)
 
         if self.parallel_gpu > 1:
-            all_reduce(attn_output)
+            all_reduce_and_rescale_tensors(attn_output, 1)
 
         return attn_output, attn
@@ -686,7 +686,7 @@ def forward(
             ).transpose(1, 2)
             attn_output = self.final_linear(unshape(context))
             if self.parallel_gpu > 1:
-                all_reduce(attn_output)
+                all_reduce_and_rescale_tensors(attn_output, 1)
             return attn_output, None
 
         else:
diff --git a/eole/modules/transformer_mlp.py b/eole/modules/transformer_mlp.py
index ecb6bfee..e03291ba 100644
--- a/eole/modules/transformer_mlp.py
+++ b/eole/modules/transformer_mlp.py
@@ -4,7 +4,7 @@
 from torch.utils.checkpoint import checkpoint
 from torch.nn.utils import skip_init
 
-from torch.distributed import all_reduce
+from eole.utils.distributed import all_reduce_and_rescale_tensors
 
 from eole.constants import ACTIVATION_FUNCTIONS
 
@@ -81,7 +81,7 @@ def forward(self, x):
 
         mlp_out = self.dropout_2(mlp_out)
 
         if self.parallel_gpu > 1:
-            all_reduce(mlp_out)
+            all_reduce_and_rescale_tensors(mlp_out, 1.0)
 
         return mlp_out
diff --git a/eole/utils/distributed.py b/eole/utils/distributed.py
index e7ba2261..d7d59ba8 100644
--- a/eole/utils/distributed.py
+++ b/eole/utils/distributed.py
@@ -7,36 +7,6 @@
 import math
 import pickle
 import torch.distributed
-from datetime import timedelta
-from eole.predict import build_predictor
-from eole.transforms import get_transforms_cls, make_transforms
-from eole.constants import CorpusTask
-from eole.utils.logging import init_logger, logger
-from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
-
-
-def is_master(config, device_id):
-    return config.gpu_ranks[device_id] == 0
-
-
-def multi_init(config, device_id):
-    # config is a running config here
-    dist_init_method = "tcp://{master_ip}:{master_port}".format(
-        master_ip=config.master_ip, master_port=config.master_port
-    )
-    dist_world_size = config.world_size
-    torch.distributed.init_process_group(
-        backend=config.gpu_backend,
-        init_method=dist_init_method,
-        world_size=dist_world_size,
-        rank=config.gpu_ranks[device_id],
-        timeout=timedelta(seconds=config.timeout),
-    )
-    gpu_rank = torch.distributed.get_rank()
-    if not is_master(config, device_id):
-        logger.disabled = True
-
-    return gpu_rank
 
 
 def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=104857600):
@@ -69,7 +39,10 @@ def all_reduce_buffer():
         offset = 0
         for t in buffer:
             numel = t.numel()
-            t.view(-1).copy_(buffer_t[offset : offset + numel])
+            # t.view(-1).copy_(buffer_t[offset : offset + numel])
+            t = (
+                buffer_t[offset : offset + numel].view_as(t).clone()
+            )  # Clone to create a separate tensor
             offset += numel
 
     filled = 0
@@ -164,121 +137,3 @@ def signal_handler(self, signalnum, stackframe):
             be ignored --\n\n"""
         msg += original_trace
         raise Exception(msg)
-
-
-def spawned_train(process_fn, config, device_id, error_queue):  # noqa: E501
-    """Run `process_fn` on `device_id` with data from `batch_queue`."""
-    # config -> full config
-    try:
-        gpu_rank = multi_init(config.training, device_id)
-        if gpu_rank != config.training.gpu_ranks[device_id]:
-            raise AssertionError(
-                "An error occurred in \
-                  Distributed initialization"
-            )
-        process_fn(config, device_id=device_id)
-    except KeyboardInterrupt:
-        pass  # killed by parent, do nothing
-    except Exception:
-        # propagate exception to parent process, keeping original traceback
-        import traceback
-
-        error_queue.put((config.training.gpu_ranks[device_id], traceback.format_exc()))
-
-
-def spawned_infer(
-    config, device_id, error_queue, queue_instruct, queue_result, queue_settings=None
-):
-    """Run various functions for prediction in spawned process on `device_id`."""
-    try:
-        running_config = (
-            config  # will probably switch to config.inference at some point
-        )
-        gpu_rank = multi_init(running_config, device_id)
-        if gpu_rank != running_config.gpu_ranks[device_id]:
-            raise AssertionError(
-                "An error occurred in \
-                  Distributed initialization"
-            )
-        torch.cuda.set_device(device_id)
-        init_logger(config.log_file)
-        predictor = build_predictor(config, device_id, logger=logger, report_score=True)
-        transforms_cls = get_transforms_cls(config._all_transform)
-        transforms = make_transforms(config, transforms_cls, predictor.vocabs)
-        while True:
-            instruction = queue_instruct.get()
-            if queue_settings is not None:
-                settings = queue_settings.get()
-                predictor.update_settings(**settings)
-            if instruction[0] == "stop":
-                break
-            elif instruction[0] == "infer_list":
-                src = instruction[1]
-                infer_iter = build_dynamic_dataset_iter(
-                    config,
-                    transforms,
-                    predictor.vocabs,
-                    task=CorpusTask.INFER,
-                    src=src,
-                    device_id=device_id,
-                )
-                scores, estims, preds = predictor._predict(
-                    infer_iter,
-                    infer_iter.transforms,
-                    config.attn_debug,
-                    config.align_debug,
-                )
-                queue_result.put(scores)
-                queue_result.put(estims)
-                queue_result.put(preds)
-            elif instruction[0] == "infer_file":
-                config.src = instruction[1].src
-                infer_iter = build_dynamic_dataset_iter(
-                    config,
-                    transforms,
-                    predictor.vocabs,
-                    task=CorpusTask.INFER,
-                    device_id=device_id,
-                )
-                scores, estims, preds = predictor._predict(
-                    infer_iter,
-                    infer_iter.transforms,
-                    config.attn_debug,
-                    config.align_debug,
-                )
-                queue_result.put(scores)
-                queue_result.put(estims)
-                queue_result.put(preds)
-            elif instruction[0] == "score_list":
-                tgt = instruction[1]
-                infer_iter = build_dynamic_dataset_iter(
-                    config,
-                    transforms,
-                    predictor.vocabs,
-                    task=CorpusTask.INFER,
-                    src=tgt,
-                    tgt=tgt,
-                    device_id=device_id,
-                )
-                score_results = predictor._score(infer_iter)
-                queue_result.put(score_results)
-            elif instruction[0] == "score_file":
-                config.src = instruction[1].src
-                config.tgt = instruction[1].src
-                infer_iter = build_dynamic_dataset_iter(
-                    config,
-                    transforms,
-                    predictor.vocabs,
-                    task=CorpusTask.INFER,
-                    device_id=device_id,
-                )
-                score_results = predictor._score(infer_iter)
-                queue_result.put(score_results)
-
-    except KeyboardInterrupt:
-        pass  # killed by parent, do nothing
-    except Exception:
-        # propagate exception to parent process, keeping original traceback
-        import traceback
-
-        error_queue.put((running_config.gpu_ranks[device_id], traceback.format_exc()))
diff --git a/eole/utils/distributed_workers.py b/eole/utils/distributed_workers.py
new file mode 100644
index 00000000..6a58d7cf
--- /dev/null
+++ b/eole/utils/distributed_workers.py
@@ -0,0 +1,153 @@
+""" Pytorch Distributed utils
+    This piece of code was heavily inspired by the equivalent of Fairseq-py
+    https://github.com/pytorch/fairseq
+"""
+import torch.distributed
+from datetime import timedelta
+from eole.predict import build_predictor
+from eole.transforms import get_transforms_cls, make_transforms
+from eole.constants import CorpusTask
+from eole.utils.logging import init_logger, logger
+from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
+
+
+def is_master(config, device_id):
+    return config.gpu_ranks[device_id] == 0
+
+
+def multi_init(config, device_id):
+    # config is a running config here
+    dist_init_method = "tcp://{master_ip}:{master_port}".format(
+        master_ip=config.master_ip, master_port=config.master_port
+    )
+    dist_world_size = config.world_size
+    torch.distributed.init_process_group(
+        backend=config.gpu_backend,
+        init_method=dist_init_method,
+        world_size=dist_world_size,
+        rank=config.gpu_ranks[device_id],
+        timeout=timedelta(seconds=config.timeout),
+    )
+    gpu_rank = torch.distributed.get_rank()
+    if not is_master(config, device_id):
+        logger.disabled = True
+
+    return gpu_rank
+
+
+def spawned_train(process_fn, config, device_id, error_queue):  # noqa: E501
+    """Run `process_fn` on `device_id` with data from `batch_queue`."""
+    # config -> full config
+    try:
+        gpu_rank = multi_init(config.training, device_id)
+        if gpu_rank != config.training.gpu_ranks[device_id]:
+            raise AssertionError(
+                "An error occurred in \
+                  Distributed initialization"
+            )
+        process_fn(config, device_id=device_id)
+    except KeyboardInterrupt:
+        pass  # killed by parent, do nothing
+    except Exception:
+        # propagate exception to parent process, keeping original traceback
+        import traceback
+
+        error_queue.put((config.training.gpu_ranks[device_id], traceback.format_exc()))
+
+
+def spawned_infer(
+    config, device_id, error_queue, queue_instruct, queue_result, queue_settings=None
+):
+    """Run various functions for prediction in spawned process on `device_id`."""
+    try:
+        running_config = (
+            config  # will probably switch to config.inference at some point
+        )
+        gpu_rank = multi_init(running_config, device_id)
+        if gpu_rank != running_config.gpu_ranks[device_id]:
+            raise AssertionError(
+                "An error occurred in \
+                  Distributed initialization"
+            )
+        torch.cuda.set_device(device_id)
+        init_logger(config.log_file)
+        predictor = build_predictor(config, device_id, logger=logger, report_score=True)
+        transforms_cls = get_transforms_cls(config._all_transform)
+        transforms = make_transforms(config, transforms_cls, predictor.vocabs)
+        while True:
+            instruction = queue_instruct.get()
+            if queue_settings is not None:
+                settings = queue_settings.get()
+                predictor.update_settings(**settings)
+            if instruction[0] == "stop":
+                break
+            elif instruction[0] == "infer_list":
+                src = instruction[1]
+                infer_iter = build_dynamic_dataset_iter(
+                    config,
+                    transforms,
+                    predictor.vocabs,
+                    task=CorpusTask.INFER,
+                    src=src,
+                    device_id=device_id,
+                )
+                scores, estims, preds = predictor._predict(
+                    infer_iter,
+                    infer_iter.transforms,
+                    config.attn_debug,
+                    config.align_debug,
+                )
+                queue_result.put(scores)
+                queue_result.put(estims)
+                queue_result.put(preds)
+            elif instruction[0] == "infer_file":
+                config.src = instruction[1].src
+                infer_iter = build_dynamic_dataset_iter(
+                    config,
+                    transforms,
+                    predictor.vocabs,
+                    task=CorpusTask.INFER,
+                    device_id=device_id,
+                )
+                scores, estims, preds = predictor._predict(
+                    infer_iter,
+                    infer_iter.transforms,
+                    config.attn_debug,
+                    config.align_debug,
+                )
+                queue_result.put(scores)
+                queue_result.put(estims)
+                queue_result.put(preds)
+            elif instruction[0] == "score_list":
+                tgt = instruction[1]
+                infer_iter = build_dynamic_dataset_iter(
+                    config,
+                    transforms,
+                    predictor.vocabs,
+                    task=CorpusTask.INFER,
+                    src=tgt,
+                    tgt=tgt,
+                    device_id=device_id,
+                )
+                score_results = predictor._score(infer_iter)
+                queue_result.put(score_results)
+            elif instruction[0] == "score_file":
+                config.src = instruction[1].src
+                config.tgt = instruction[1].src
+                infer_iter = build_dynamic_dataset_iter(
+                    config,
+                    transforms,
+                    predictor.vocabs,
+                    task=CorpusTask.INFER,
+                    device_id=device_id,
+                )
+                score_results = predictor._score(infer_iter)
+                queue_result.put(score_results)
+
+    except KeyboardInterrupt:
+        pass  # killed by parent, do nothing
+    except Exception:
+        # propagate exception to parent process, keeping original traceback
+        import traceback
+
+        error_queue.put((running_config.gpu_ranks[device_id], traceback.format_exc()))
diff --git a/recipes/wmt22_with_llama3.1/llama-instruct-inference.yaml b/recipes/wmt22_with_llama3.1/llama-instruct-inference.yaml
index 3ee9e373..e84754db 100755
--- a/recipes/wmt22_with_llama3.1/llama-instruct-inference.yaml
+++ b/recipes/wmt22_with_llama3.1/llama-instruct-inference.yaml
@@ -1,5 +1,5 @@
 # Model info
-model_path: "${EOLE_MODEL_DIR}/llama3.1-70b-instruct"
+model_path: "${EOLE_MODEL_DIR}/llama3.1-8b-instruct"
 
 # Inference
 seed: 42
@@ -12,8 +12,8 @@ batch_size: 512
 world_size: 2
 gpu_ranks: [0, 1]
 parallel_mode: "tensor_parallel"
-quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
-quant_type: "bnb_NF4"
+#quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
+#quant_type: "bnb_NF4"
 compute_dtype: fp16
 top_k: 0
 top_p: 0.0
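
Note on the tensor-parallel all-reduce used above: the attention and MLP forward passes now route their rank-local partial outputs through all_reduce_and_rescale_tensors with a denominator of 1, i.e. a plain sum of each rank's shard of the projection. The snippet below is only a minimal sketch of that pattern, assuming a process group has already been created by multi_init(); it leaves out the flat-buffer packing of the real helper, and the function name all_reduce_partial_output is illustrative, not part of eole's API.

import torch
import torch.distributed as dist


def all_reduce_partial_output(partial, parallel_gpu, rescale_denom=1.0):
    """Sum a rank-local partial output across ranks, then rescale it."""
    if parallel_gpu > 1 and dist.is_available() and dist.is_initialized():
        # Reduce a clone rather than the input itself, mirroring the
        # out-of-place copy the patch introduces in distributed.py.
        reduced = partial.clone()
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
        return reduced / rescale_denom
    return partial


if __name__ == "__main__":
    # Without an initialized process group this is a no-op, so the sketch
    # can be exercised on a single machine.
    print(all_reduce_partial_output(torch.ones(2, 2), parallel_gpu=1))

Reducing into a fresh tensor instead of mutating the layer output in place appears to be the motivation for the clone() change as well, since an in-place collective is awkward to backpropagate through during training.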