Skip to content

Commit

Permalink
fix training tensor parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Sep 24, 2024
1 parent 759e226 commit 228ca1c
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 159 deletions.
3 changes: 2 additions & 1 deletion eole/bin/run/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"""Train models with dynamic data."""
import torch
from functools import partial
from eole.utils.distributed import ErrorHandler, spawned_train
from eole.utils.distributed import ErrorHandler
from eole.utils.distributed_workers import spawned_train
from eole.utils.misc import set_random_seed
from eole.utils.logging import init_logger, logger
from argparse import ArgumentParser
Expand Down
3 changes: 2 additions & 1 deletion eole/inference_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from eole.constants import CorpusTask, DefaultTokens, ModelType
from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter
from eole.utils.distributed import ErrorHandler, spawned_infer
from eole.utils.distributed import ErrorHandler
from eole.utils.distributed_workers import spawned_infer
from eole.utils.logging import init_logger
from eole.transforms import get_transforms_cls, make_transforms, TransformPipe

Expand Down
6 changes: 3 additions & 3 deletions eole/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.checkpoint import checkpoint
from torch.nn.utils import skip_init
from .alibi_position_bias import AlibiPositionalBias
from torch.distributed import all_reduce
from eole.utils.distributed import all_reduce_and_rescale_tensors
from importlib import import_module
from eole.constants import PositionEncodingType

Expand Down Expand Up @@ -535,7 +535,7 @@ def _compute_attention(
attn_output = self.maybe_ckpt(self.final_linear, context)

if self.parallel_gpu > 1:
all_reduce(attn_output)
all_reduce_and_rescale_tensors(attn_output, 1)

return attn_output, attn

Expand Down Expand Up @@ -686,7 +686,7 @@ def forward(
).transpose(1, 2)
attn_output = self.final_linear(unshape(context))
if self.parallel_gpu > 1:
all_reduce(attn_output)
all_reduce_and_rescale_tensors(attn_output, 1)
return attn_output, None

else:
Expand Down
4 changes: 2 additions & 2 deletions eole/modules/transformer_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torch.utils.checkpoint import checkpoint
from torch.nn.utils import skip_init
from torch.distributed import all_reduce
from eole.utils.distributed import all_reduce_and_rescale_tensors
from eole.constants import ACTIVATION_FUNCTIONS


Expand Down Expand Up @@ -81,7 +81,7 @@ def forward(self, x):
mlp_out = self.dropout_2(mlp_out)

if self.parallel_gpu > 1:
all_reduce(mlp_out)
all_reduce_and_rescale_tensors(mlp_out, 1.0)

return mlp_out

Expand Down
153 changes: 4 additions & 149 deletions eole/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,6 @@
import math
import pickle
import torch.distributed
from datetime import timedelta
from eole.predict import build_predictor
from eole.transforms import get_transforms_cls, make_transforms
from eole.constants import CorpusTask
from eole.utils.logging import init_logger, logger
from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter


def is_master(config, device_id):
    # The master process is the one whose configured gpu rank is 0.
    return config.gpu_ranks[device_id] == 0


def multi_init(config, device_id):
    """Initialize the torch.distributed process group for one worker.

    Builds a TCP rendezvous address from ``config.master_ip`` and
    ``config.master_port`` and joins the group with the rank configured
    for ``device_id``. Non-master workers get the shared logger disabled
    (global side effect) so only rank 0 emits log lines.

    Args:
        config: running config exposing master_ip, master_port, world_size,
            gpu_backend, gpu_ranks and timeout (in seconds).
        device_id: index into ``config.gpu_ranks`` for this process.

    Returns:
        The rank reported by ``torch.distributed.get_rank()``.
    """
    # config is a running config here
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip=config.master_ip, master_port=config.master_port
    )
    dist_world_size = config.world_size
    torch.distributed.init_process_group(
        backend=config.gpu_backend,
        init_method=dist_init_method,
        world_size=dist_world_size,
        rank=config.gpu_ranks[device_id],
        timeout=timedelta(seconds=config.timeout),
    )
    gpu_rank = torch.distributed.get_rank()
    # Silence logging on non-master ranks (mutates the module-level logger).
    if not is_master(config, device_id):
        logger.disabled = True

    return gpu_rank


def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=104857600):
Expand Down Expand Up @@ -69,7 +39,10 @@ def all_reduce_buffer():
offset = 0
for t in buffer:
numel = t.numel()
t.view(-1).copy_(buffer_t[offset : offset + numel])
# t.view(-1).copy_(buffer_t[offset : offset + numel])
t = (
buffer_t[offset : offset + numel].view_as(t).clone()
) # Clone to create a separate tensor
offset += numel

filled = 0
Expand Down Expand Up @@ -164,121 +137,3 @@ def signal_handler(self, signalnum, stackframe):
be ignored --\n\n"""
msg += original_trace
raise Exception(msg)


def spawned_train(process_fn, config, device_id, error_queue): # noqa: E501
    """Run `process_fn` on `device_id` with data from `batch_queue`.

    Args:
        process_fn: training entry point, invoked as
            ``process_fn(config, device_id=device_id)``.
        config: full config; its ``training`` section is the running config
            used for distributed initialization.
        device_id: index into ``config.training.gpu_ranks`` for this worker.
        error_queue: queue used to forward (rank, traceback) to the parent.
    """
    # config -> full config
    try:
        gpu_rank = multi_init(config.training, device_id)
        # Sanity check: the rank handed back by torch.distributed must match
        # the rank configured for this device.
        if gpu_rank != config.training.gpu_ranks[device_id]:
            raise AssertionError(
                "An error occurred in \
                Distributed initialization"
            )
        process_fn(config, device_id=device_id)
    except KeyboardInterrupt:
        pass # killed by parent, do nothing
    except Exception:
        # propagate exception to parent process, keeping original traceback
        import traceback

        error_queue.put((config.training.gpu_ranks[device_id], traceback.format_exc()))


def spawned_infer(
    config, device_id, error_queue, queue_instruct, queue_result, queue_settings=None
):
    """Run various functions for prediction in spawned process on `device_id`.

    Worker loop: joins the distributed group, builds a predictor and its
    transforms once, then serves instructions read from `queue_instruct`
    until a ("stop",) message arrives. Results are pushed onto
    `queue_result`; a fatal error is forwarded to the parent through
    `error_queue` together with this worker's configured gpu rank.

    Supported instructions (opcode, payload):
      - "infer_list": payload is a list of source segments to translate.
      - "infer_file": payload carries a `.src` file path to translate.
      - "score_list": payload is a list scored against itself (src == tgt).
      - "score_file": payload carries a `.src` path used as both src and tgt.
    """
    try:
        running_config = (
            config # will probably switch to config.inference at some point
        )
        gpu_rank = multi_init(running_config, device_id)
        # The rank returned by torch.distributed must match the configured one.
        if gpu_rank != running_config.gpu_ranks[device_id]:
            raise AssertionError(
                "An error occurred in \
                Distributed initialization"
            )
        torch.cuda.set_device(device_id)
        init_logger(config.log_file)
        predictor = build_predictor(config, device_id, logger=logger, report_score=True)
        transforms_cls = get_transforms_cls(config._all_transform)
        transforms = make_transforms(config, transforms_cls, predictor.vocabs)
        while True:
            instruction = queue_instruct.get()
            # Optional settings update, consumed once per instruction.
            if queue_settings is not None:
                settings = queue_settings.get()
                predictor.update_settings(**settings)
            if instruction[0] == "stop":
                break
            elif instruction[0] == "infer_list":
                # Translate an in-memory list of examples.
                src = instruction[1]
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    src=src,
                    device_id=device_id,
                )
                scores, estims, preds = predictor._predict(
                    infer_iter,
                    infer_iter.transforms,
                    config.attn_debug,
                    config.align_debug,
                )
                # The parent reads results back in this exact order.
                queue_result.put(scores)
                queue_result.put(estims)
                queue_result.put(preds)
            elif instruction[0] == "infer_file":
                # Translate a file; the dataset iterator reads config.src.
                config.src = instruction[1].src
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    device_id=device_id,
                )
                scores, estims, preds = predictor._predict(
                    infer_iter,
                    infer_iter.transforms,
                    config.attn_debug,
                    config.align_debug,
                )
                queue_result.put(scores)
                queue_result.put(estims)
                queue_result.put(preds)
            elif instruction[0] == "score_list":
                # Score a list: the same segments serve as src and tgt.
                tgt = instruction[1]
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    src=tgt,
                    tgt=tgt,
                    device_id=device_id,
                )
                score_results = predictor._score(infer_iter)
                queue_result.put(score_results)
            elif instruction[0] == "score_file":
                # Score a file: the same path serves as src and tgt.
                config.src = instruction[1].src
                config.tgt = instruction[1].src
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    device_id=device_id,
                )
                score_results = predictor._score(infer_iter)
                queue_result.put(score_results)

    except KeyboardInterrupt:
        pass # killed by parent, do nothing
    except Exception:
        # propagate exception to parent process, keeping original traceback
        import traceback

        error_queue.put((running_config.gpu_ranks[device_id], traceback.format_exc()))
153 changes: 153 additions & 0 deletions eole/utils/distributed_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
""" Pytorch Distributed utils
This piece of code was heavily inspired by the equivalent of Fairseq-py
https://github.com/pytorch/fairseq
"""
import torch.distributed
from datetime import timedelta
from eole.predict import build_predictor
from eole.transforms import get_transforms_cls, make_transforms
from eole.constants import CorpusTask
from eole.utils.logging import init_logger, logger
from eole.inputters.dynamic_iterator import build_dynamic_dataset_iter


def is_master(config, device_id):
    """Return True iff the rank configured for `device_id` is the master (rank 0)."""
    rank = config.gpu_ranks[device_id]
    return rank == 0


def multi_init(config, device_id):
    """Initialize the torch.distributed process group for one worker.

    Builds a TCP rendezvous address from ``config.master_ip`` and
    ``config.master_port`` and joins the group with the rank configured
    for ``device_id``. Non-master workers get the shared logger disabled
    (global side effect) so only rank 0 emits log lines.

    Args:
        config: running config exposing master_ip, master_port, world_size,
            gpu_backend, gpu_ranks and timeout (in seconds).
        device_id: index into ``config.gpu_ranks`` for this process.

    Returns:
        The rank reported by ``torch.distributed.get_rank()``.
    """
    # config is a running config here
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip=config.master_ip, master_port=config.master_port
    )
    dist_world_size = config.world_size
    torch.distributed.init_process_group(
        backend=config.gpu_backend,
        init_method=dist_init_method,
        world_size=dist_world_size,
        rank=config.gpu_ranks[device_id],
        timeout=timedelta(seconds=config.timeout),
    )
    gpu_rank = torch.distributed.get_rank()
    # Silence logging on non-master ranks (mutates the module-level logger).
    if not is_master(config, device_id):
        logger.disabled = True

    return gpu_rank


def spawned_train(process_fn, config, device_id, error_queue): # noqa: E501
    """Spawned-process wrapper around `process_fn` for one training device.

    Joins the distributed group for `device_id`, verifies the assigned rank
    against the configured one, then hands control to `process_fn`. Any
    unexpected exception is shipped back to the parent through `error_queue`
    as a (rank, traceback) pair.
    """
    try:
        # `config` is the full config; distributed init uses its training part.
        training_cfg = config.training
        rank = multi_init(training_cfg, device_id)
        if rank != training_cfg.gpu_ranks[device_id]:
            raise AssertionError(
                "An error occurred in \
                Distributed initialization"
            )
        process_fn(config, device_id=device_id)
    except KeyboardInterrupt:
        # Killed by the parent process: exit quietly.
        pass
    except Exception:
        # Forward the failure to the parent, keeping the original traceback.
        import traceback

        error_queue.put((config.training.gpu_ranks[device_id], traceback.format_exc()))


def spawned_infer(
    config, device_id, error_queue, queue_instruct, queue_result, queue_settings=None
):
    """Run various functions for prediction in spawned process on `device_id`.

    Worker loop: joins the distributed group, builds a predictor and its
    transforms once, then serves instructions read from `queue_instruct`
    until a ("stop",) message arrives. Results are pushed onto
    `queue_result`; a fatal error is forwarded to the parent through
    `error_queue` together with this worker's configured gpu rank.

    Supported instructions (opcode, payload):
      - "infer_list": payload is a list of source segments to translate.
      - "infer_file": payload carries a `.src` file path to translate.
      - "score_list": payload is a list scored against itself (src == tgt).
      - "score_file": payload carries a `.src` path used as both src and tgt.
    """
    try:
        running_config = (
            config # will probably switch to config.inference at some point
        )
        gpu_rank = multi_init(running_config, device_id)
        # The rank returned by torch.distributed must match the configured one.
        if gpu_rank != running_config.gpu_ranks[device_id]:
            raise AssertionError(
                "An error occurred in \
                Distributed initialization"
            )
        torch.cuda.set_device(device_id)
        init_logger(config.log_file)
        predictor = build_predictor(config, device_id, logger=logger, report_score=True)
        transforms_cls = get_transforms_cls(config._all_transform)
        transforms = make_transforms(config, transforms_cls, predictor.vocabs)
        while True:
            instruction = queue_instruct.get()
            # Optional settings update, consumed once per instruction.
            if queue_settings is not None:
                settings = queue_settings.get()
                predictor.update_settings(**settings)
            if instruction[0] == "stop":
                break
            elif instruction[0] == "infer_list":
                # Translate an in-memory list of examples.
                src = instruction[1]
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    src=src,
                    device_id=device_id,
                )
                scores, estims, preds = predictor._predict(
                    infer_iter,
                    infer_iter.transforms,
                    config.attn_debug,
                    config.align_debug,
                )
                # The parent reads results back in this exact order.
                queue_result.put(scores)
                queue_result.put(estims)
                queue_result.put(preds)
            elif instruction[0] == "infer_file":
                # Translate a file; the dataset iterator reads config.src.
                config.src = instruction[1].src
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    device_id=device_id,
                )
                scores, estims, preds = predictor._predict(
                    infer_iter,
                    infer_iter.transforms,
                    config.attn_debug,
                    config.align_debug,
                )
                queue_result.put(scores)
                queue_result.put(estims)
                queue_result.put(preds)
            elif instruction[0] == "score_list":
                # Score a list: the same segments serve as src and tgt.
                tgt = instruction[1]
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    src=tgt,
                    tgt=tgt,
                    device_id=device_id,
                )
                score_results = predictor._score(infer_iter)
                queue_result.put(score_results)
            elif instruction[0] == "score_file":
                # Score a file: the same path serves as src and tgt.
                config.src = instruction[1].src
                config.tgt = instruction[1].src
                infer_iter = build_dynamic_dataset_iter(
                    config,
                    transforms,
                    predictor.vocabs,
                    task=CorpusTask.INFER,
                    device_id=device_id,
                )
                score_results = predictor._score(infer_iter)
                queue_result.put(score_results)

    except KeyboardInterrupt:
        pass # killed by parent, do nothing
    except Exception:
        # propagate exception to parent process, keeping original traceback
        import traceback

        error_queue.put((running_config.gpu_ranks[device_id], traceback.format_exc()))
Loading

0 comments on commit 228ca1c

Please sign in to comment.