diff --git a/megatron/arguments.py b/megatron/arguments.py index 30c3d669d7..af39f0b0e3 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -680,6 +680,9 @@ def _add_network_size_args(parser): help='Untie embeddings and output weights.'), group.add_argument('--embedding-weights-in-fp32', action='store_true', help='Cast word embedding weights to fp32 before embedding fwd.'), + group.add_argument('--kill-switch-file', type=str, default=None, + help='Location of kill switch file. ' + 'If found will automatically exit the program at runtime.') return parser diff --git a/megatron/training.py b/megatron/training.py index 697d62f7b8..6ba26f3944 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -43,7 +43,7 @@ from megatron.optimizer_param_scheduler import OptimizerParamScheduler from megatron.model import DistributedDataParallel as LocalDDP from megatron.utils import check_adlr_autoresume_termination -from megatron.utils import unwrap_model +from megatron.utils import unwrap_model, found_kill_switch from megatron.data.data_samplers import build_pretraining_data_loader from megatron.utils import calc_params_l2_norm from megatron.core.pipeline_parallel import get_forward_backward_func @@ -128,6 +128,13 @@ def pretrain(train_valid_test_dataset_provider, # Initalize and get arguments, timers, and Tensorboard writer. initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults, external_args=external_args) + + args = get_args() + + if found_kill_switch(): + print_datetime(f"Detected kill switch at {args.kill_switch_file}. Exiting") + sys.exit() + # Set pytorch JIT layer fusion options and warmup JIT functions. if get_accelerator().device_name() == 'cuda': set_jit_fusion_options() @@ -144,7 +151,6 @@ def pretrain(train_valid_test_dataset_provider, time.time() - _TRAIN_START_TIME)) print_datetime('after megatron is initialized') - args = get_args() timers = get_timers() if args.deepspeed: @@ -1358,6 +1364,15 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, sys.exit() trigger(on_step_end) + # Exiting based on kill switch file + if found_kill_switch(): + if args.save and not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + torch.distributed.barrier() + print_datetime(f"Detected kill switch at {args.kill_switch_file}, " + f"iteration={iteration}. Exiting") + sys.exit() return iteration diff --git a/megatron/utils.py b/megatron/utils.py index 9033d6402a..cbb7aa6426 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -381,3 +381,11 @@ def dump_weights(preamble, iteration, model, optimizer, tensor=None): p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n") + +def found_kill_switch(): + args = get_args() + if args.kill_switch_file is not None and os.path.exists(args.kill_switch_file): + return True + else: + return False +