From 52dccc88ff8a50124e4e3d65ef49a6202398ca59 Mon Sep 17 00:00:00 2001 From: KuoHaoZeng Date: Mon, 15 Jul 2024 10:45:37 -0700 Subject: [PATCH] make ckpt saving at every host an option --- allenact/algorithms/onpolicy_sync/engine.py | 9 ++++----- allenact/algorithms/onpolicy_sync/runner.py | 2 ++ allenact/main.py | 11 +++++++++++ allenact/utils/experiment_utils.py | 8 ++++---- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py index fa4390de..64b8e164 100644 --- a/allenact/algorithms/onpolicy_sync/engine.py +++ b/allenact/algorithms/onpolicy_sync/engine.py @@ -66,7 +66,6 @@ TrainingPipeline, set_deterministic_cudnn, set_seed, - download_checkpoint_from_wandb, ) from allenact.utils.system import get_logger from allenact.utils.tensor_utils import batch_observations, detach_recursively @@ -1176,6 +1175,7 @@ def __init__( max_sampler_processes_per_worker: Optional[int] = None, save_ckpt_after_every_pipeline_stage: bool = True, first_local_worker_id: int = 0, + save_ckpt_at_every_host: bool = False, **kwargs, ): kwargs["mode"] = TRAIN_MODE_STR @@ -1267,6 +1267,7 @@ def __init__( ) self.first_local_worker_id = first_local_worker_id + self.save_ckpt_at_every_host = save_ckpt_at_every_host def advance_seed( self, seed: Optional[int], return_same_seed_per_worker=False @@ -1539,8 +1540,7 @@ def _save_checkpoint_then_send_checkpoint_for_validation_and_update_last_save_co ): model_path = None self.deterministic_seeds() - # if self.worker_id == self.first_local_worker_id: - if self.worker_id == 0: + if (self.save_ckpt_at_every_host and self.worker_id == self.first_local_worker_id) or self.worker_id == 0: model_path = self.checkpoint_save(pipeline_stage_index=pipeline_stage_index) if self.checkpoints_queue is not None: self.checkpoints_queue.put(("eval", model_path)) @@ -1581,8 +1581,7 @@ def run_pipeline(self, valid_on_initial_weights: bool = False): and should_save_checkpoints 
and self.checkpoints_queue is not None ): - # if self.worker_id == self.first_local_worker_id: - if self.worker_id == 0: + if (self.save_ckpt_at_every_host and self.worker_id == self.first_local_worker_id) or self.worker_id == 0: model_path = self.checkpoint_save() if self.checkpoints_queue is not None: self.checkpoints_queue.put(("eval", model_path)) diff --git a/allenact/algorithms/onpolicy_sync/runner.py b/allenact/algorithms/onpolicy_sync/runner.py index f5f3bd64..021a36e0 100644 --- a/allenact/algorithms/onpolicy_sync/runner.py +++ b/allenact/algorithms/onpolicy_sync/runner.py @@ -501,6 +501,7 @@ def start_train( collect_valid_results: bool = False, valid_on_initial_weights: bool = False, try_restart_after_task_error: bool = False, + save_ckpt_at_every_host: bool = False, ): self._initialize_start_train_or_start_test() @@ -574,6 +575,7 @@ def start_train( distributed_preemption_threshold=self.distributed_preemption_threshold, valid_on_initial_weights=valid_on_initial_weights, try_restart_after_task_error=try_restart_after_task_error, + save_ckpt_at_every_host=save_ckpt_at_every_host, ) train: BaseProcess = self.mp_ctx.Process( target=self.train_loop, diff --git a/allenact/main.py b/allenact/main.py index 138b5c6f..cfb85250 100755 --- a/allenact/main.py +++ b/allenact/main.py @@ -274,6 +274,16 @@ def get_argument_parser(): " tutorial https://allenact.org/tutorials/distributed-objectnav-tutorial/", ) + parser.add_argument( + "--save_ckpt_at_every_host", + dest="save_ckpt_at_every_host", + action="store_true", + required=False, + help="if you pass the `--save_ckpt_at_every_host` flag, AllenAct will save checkpoints at every host as" + " the training progresses in distributed training mode.", + ) + parser.set_defaults(save_ckpt_at_every_host=False) + parser.add_argument( "--callbacks", dest="callbacks", @@ -484,6 +494,7 @@ def main(): collect_valid_results=args.collect_valid_results, valid_on_initial_weights=args.valid_on_initial_weights, 
try_restart_after_task_error=args.enable_crash_recovery, + save_ckpt_at_every_host=args.save_ckpt_at_every_host, ) else: OnPolicyRunner( diff --git a/allenact/utils/experiment_utils.py b/allenact/utils/experiment_utils.py index a20ef6d4..0ace2770 100644 --- a/allenact/utils/experiment_utils.py +++ b/allenact/utils/experiment_utils.py @@ -1201,9 +1201,9 @@ def download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, all_ckpt_dir, for steps in ckpt_steps: ckpt_fn = "{}-step-{}:latest".format(run_token, steps) artifact = api.artifact(ckpt_fn) - _ = artifact.download("/tmp") + _ = artifact.download(all_ckpt_dir) ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, steps) - shutil.move("/tmp/ckpt.pt", ckpt_dir) + shutil.move("{}/ckpt.pt".format(all_ckpt_dir), ckpt_dir) ckpts_paths.append(ckpt_dir) return ckpts_paths else: @@ -1211,7 +1211,7 @@ def download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, all_ckpt_dir, step = ckpt_steps[0] ckpt_fn = "{}-step-{}:latest".format(run_token, step) artifact = api.artifact(ckpt_fn) - _ = artifact.download("/tmp") + _ = artifact.download(all_ckpt_dir) ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, step) - shutil.move("/tmp/ckpt.pt", ckpt_dir) + shutil.move("{}/ckpt.pt".format(all_ckpt_dir), ckpt_dir) return ckpt_dir \ No newline at end of file