From ad213e233b83ee20f75bd0b75d0b10e3e2f085ce Mon Sep 17 00:00:00 2001
From: KuoHaoZeng
Date: Wed, 10 Jul 2024 15:13:04 -0700
Subject: [PATCH] allow training resume from wandb ckpt

---
 allenact/algorithms/onpolicy_sync/engine.py |  9 +++++++++
 allenact/algorithms/onpolicy_sync/runner.py | 19 ++-----------------
 allenact/utils/experiment_utils.py          | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 47 insertions(+), 17 deletions(-)

diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py
index b481afc8..cf2783f5 100644
--- a/allenact/algorithms/onpolicy_sync/engine.py
+++ b/allenact/algorithms/onpolicy_sync/engine.py
@@ -66,6 +66,7 @@
     TrainingPipeline,
     set_deterministic_cudnn,
     set_seed,
+    download_checkpoint_from_wandb,
 )
 from allenact.utils.system import get_logger
 from allenact.utils.tensor_utils import batch_observations, detach_recursively
@@ -1900,6 +1901,14 @@ def train(
         # noinspection PyBroadException
         try:
             if checkpoint_file_name is not None:
+                if "wandb://" == checkpoint_file_name[:8]:
+                    ckpt_dir = "wandb_ckpts"
+                    os.makedirs(ckpt_dir, exist_ok=True)
+                    checkpoint_file_name = download_checkpoint_from_wandb(
+                        checkpoint_file_name,
+                        ckpt_dir,
+                        only_allow_one_ckpt=True,
+                    )
                 self.checkpoint_load(checkpoint_file_name, restart_pipeline)
 
             self.run_pipeline(valid_on_initial_weights=valid_on_initial_weights)
diff --git a/allenact/algorithms/onpolicy_sync/runner.py b/allenact/algorithms/onpolicy_sync/runner.py
index dd33e956..8a02bba2 100644
--- a/allenact/algorithms/onpolicy_sync/runner.py
+++ b/allenact/algorithms/onpolicy_sync/runner.py
@@ -45,6 +45,7 @@
     ScalarMeanTracker,
     set_deterministic_cudnn,
     set_seed,
+    download_checkpoint_from_wandb,
 )
 from allenact.utils.misc_utils import (
     NumpyJSONEncoder,
@@ -1501,25 +1502,9 @@ def get_checkpoint_files(
         approx_ckpt_step_interval: Optional[int] = None,
     ):
         if "wandb://" == checkpoint_path_dir_or_pattern[:8]:
-            import wandb
-            import shutil
             eval_dir = "wandb_ckpts_to_eval/{}".format(self.local_start_time_str)
             os.makedirs(eval_dir, exist_ok=True)
-            api = wandb.Api()
-            run_token = checkpoint_path_dir_or_pattern.split("//")[1]
-            ckpt_steps = checkpoint_path_dir_or_pattern.split("//")[2:]
-            if ckpt_steps[-1] == "":
-                ckpt_steps = ckpt_steps[:-1]
-            ckpts_paths = []
-            for steps in ckpt_steps:
-                ckpt_fn = "{}-step-{}:latest".format(run_token, steps)
-                artifact = api.artifact(ckpt_fn)
-                _ = artifact.download("tmp")
-                ckpt_dir = "{}/ckpt-{}.pt".format(eval_dir, steps)
-                shutil.move("tmp/ckpt.pt", ckpt_dir)
-                ckpts_paths.append(ckpt_dir)
-            shutil.rmtree("tmp")
-            return ckpts_paths
+            return download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, eval_dir, only_allow_one_ckpt=False)
 
         if os.path.isdir(checkpoint_path_dir_or_pattern):
             # The fragment is a path to a directory, lets use this directory
diff --git a/allenact/utils/experiment_utils.py b/allenact/utils/experiment_utils.py
index 87f6bbba..99525743 100644
--- a/allenact/utils/experiment_utils.py
+++ b/allenact/utils/experiment_utils.py
@@ -26,6 +26,8 @@
 import numpy as np
 import torch
 import torch.optim as optim
+import wandb
+import shutil
 
 from allenact.algorithms.offpolicy_sync.losses.abstract_offpolicy_loss import Memory
 from allenact.algorithms.onpolicy_sync.losses.abstract_loss import (
@@ -1186,3 +1188,37 @@ def current_stage_losses(
             )
             for loss_name in self.current_stage.loss_names
         }
+
+
+def download_checkpoint_from_wandb(
+    checkpoint_path_dir_or_pattern, all_ckpt_dir, only_allow_one_ckpt=False
+):
+    """Download checkpoint(s) named by a `wandb://<run_token>//<step>//...`
+    pattern, returning a single local path when `only_allow_one_ckpt` is True
+    and a list of local paths otherwise.
+    """
+    api = wandb.Api()
+    run_token = checkpoint_path_dir_or_pattern.split("//")[1]
+    ckpt_steps = checkpoint_path_dir_or_pattern.split("//")[2:]
+    if ckpt_steps[-1] == "":  # tolerate a trailing "//"
+        ckpt_steps = ckpt_steps[:-1]
+    if not only_allow_one_ckpt:
+        ckpts_paths = []
+        for steps in ckpt_steps:
+            ckpt_fn = "{}-step-{}:latest".format(run_token, steps)
+            artifact = api.artifact(ckpt_fn)
+            _ = artifact.download("tmp")
+            ckpt_path = "{}/ckpt-{}.pt".format(all_ckpt_dir, steps)
+            shutil.move("tmp/ckpt.pt", ckpt_path)
+            ckpts_paths.append(ckpt_path)
+        shutil.rmtree("tmp")
+        return ckpts_paths
+    else:
+        assert len(ckpt_steps) == 1
+        ckpt_fn = "{}-step-{}:latest".format(run_token, ckpt_steps[0])
+        artifact = api.artifact(ckpt_fn)
+        _ = artifact.download("tmp")
+        ckpt_path = "{}/ckpt-{}.pt".format(all_ckpt_dir, ckpt_steps[0])
+        shutil.move("tmp/ckpt.pt", ckpt_path)
+        shutil.rmtree("tmp")
+        return ckpt_path
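
For reference, the helper assumes checkpoint references of the form
`wandb://<run_token>//<step>[//<step>...]`. Below is a minimal, wandb-free
sketch of how that pattern is parsed and which artifact names are then
requested from the wandb API; the entity/project/run token shown is
hypothetical:

    # Mirrors the parsing in download_checkpoint_from_wandb;
    # the run token below is made up for illustration.
    pattern = "wandb://my-entity/my-project/my-run//100000//200000"

    parts = pattern.split("//")
    run_token = parts[1]    # "my-entity/my-project/my-run"
    ckpt_steps = parts[2:]  # ["100000", "200000"]
    if ckpt_steps[-1] == "":  # tolerate a trailing "//"
        ckpt_steps = ckpt_steps[:-1]

    for steps in ckpt_steps:
        # One checkpoint artifact per requested step, at the :latest alias.
        print("{}-step-{}:latest".format(run_token, steps))
    # -> my-entity/my-project/my-run-step-100000:latest
    # -> my-entity/my-project/my-run-step-200000:latest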