
Commit

add kill switch file support to gracefully exit training at runtime (#412)

* Update arguments.py

* Update training.py

* Update utils.py

* add copyrights

* add copyrights

* add copyrights

* Update arguments.py help

* Update arguments.py

* Update training.py

* Update utils.py

* Update arguments.py
polisettyvarma authored Jul 17, 2024
1 parent 73029ed commit fc989b8
Showing 3 changed files with 28 additions and 2 deletions.
3 changes: 3 additions & 0 deletions megatron/arguments.py
@@ -680,6 +680,9 @@ def _add_network_size_args(parser):
                       help='Untie embeddings and output weights.'),
    group.add_argument('--embedding-weights-in-fp32', action='store_true',
                       help='Cast word embedding weights to fp32 before embedding fwd.'),
    group.add_argument('--kill-switch-file', type=str, default=None,
                       help='Location of kill switch file. '
                            'If found will automatically exit the program at runtime.')
    return parser


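For context, here is a minimal standalone argparse sketch of the new flag; the option name and help text are taken from the diff above, while the bare ArgumentParser and the sample path are illustrative only (in Megatron the option is registered on an argument group inside _add_network_size_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--kill-switch-file', type=str, default=None,
                    help='Location of kill switch file. '
                         'If found will automatically exit the program at runtime.')

# A launch script might pass e.g. --kill-switch-file /tmp/kill_run_42
args = parser.parse_args(['--kill-switch-file', '/tmp/kill_run_42'])
print(args.kill_switch_file)  # -> /tmp/kill_run_42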
19 changes: 17 additions & 2 deletions megatron/training.py
@@ -43,7 +43,7 @@
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.utils import unwrap_model, found_kill_switch
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.core.pipeline_parallel import get_forward_backward_func
@@ -128,6 +128,13 @@ def pretrain(train_valid_test_dataset_provider,
    # Initalize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults, external_args=external_args)

    args = get_args()

    if found_kill_switch():
        print_datetime(f"Detected kill switch at {args.kill_switch_file}. Exiting")
        sys.exit()

    # Set pytorch JIT layer fusion options and warmup JIT functions.
    if get_accelerator().device_name() == 'cuda':
        set_jit_fusion_options()
@@ -144,7 +151,6 @@ def pretrain(train_valid_test_dataset_provider,
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    if args.deepspeed:
@@ -1358,6 +1364,15 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            sys.exit()
        trigger(on_step_end)

        # Exiting based on kill switch file
        if found_kill_switch():
            if args.save and not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler)
            torch.distributed.barrier()
            print_datetime(f"Detected kill switch at {args.kill_switch_file}, "
                           f"iteration={iteration}. Exiting")
            sys.exit()

    return iteration

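Taken together, the two training.py hooks behave as in the condensed, self-contained sketch below; the path, the bare loop, and the print calls are stand-ins for illustration, while the real code uses Megatron's get_args, print_datetime, save_checkpoint_and_time, and torch.distributed.barrier exactly as shown in the diff:

import os
import sys

KILL_SWITCH_FILE = '/tmp/stop_training'  # hypothetical value of --kill-switch-file

def found_kill_switch(path=KILL_SWITCH_FILE):
    # Mirrors megatron.utils.found_kill_switch: a path is configured and exists on disk.
    return path is not None and os.path.exists(path)

def pretrain_sketch(train_iters=1000):
    # Startup check: do not (re)start a run whose kill switch file is still present.
    if found_kill_switch():
        print(f"Detected kill switch at {KILL_SWITCH_FILE}. Exiting")
        sys.exit()

    for iteration in range(1, train_iters + 1):
        # ... one training step would run here ...

        # Per-iteration check: checkpoint first (in the real code), then exit cleanly.
        if found_kill_switch():
            # save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler)
            print(f"Detected kill switch at {KILL_SWITCH_FILE}, iteration={iteration}. Exiting")
            sys.exit()

if __name__ == '__main__':
    pretrain_sketch()

To stop a running job gracefully, an operator creates the configured file from another shell (for example, touch /tmp/stop_training); the loop notices it at the next iteration boundary, saves a checkpoint when --save is set, synchronizes ranks on a barrier, and exits instead of being killed mid-step.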
8 changes: 8 additions & 0 deletions megatron/utils.py
@@ -381,3 +381,11 @@ def dump_weights(preamble, iteration, model, optimizer, tensor=None):
            p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param
            fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n")


def found_kill_switch():
    args = get_args()
    if args.kill_switch_file is not None and os.path.exists(args.kill_switch_file):
        return True
    else:
        return False
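As a usage note (the path below is hypothetical and must match whatever is passed to --kill-switch-file): requesting a graceful stop is just creating the agreed file, and because the startup check in pretrain() exits immediately when the file is found, the file has to be removed again before the next run is launched:

from pathlib import Path

kill_switch = Path('/tmp/kill_run_42')   # must match the --kill-switch-file argument

kill_switch.touch()                      # ask the running job to checkpoint and exit
# ... once the job has stopped ...
kill_switch.unlink(missing_ok=True)      # re-arm: otherwise the next run exits at startup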
