Skip to content

Commit

Permalink
allow training resume from wandb ckpt
Browse files Browse the repository at this point in the history
  • Loading branch information
KuoHaoZeng committed Jul 10, 2024
1 parent 7e68e4a commit ad213e2
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 17 deletions.
9 changes: 9 additions & 0 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
TrainingPipeline,
set_deterministic_cudnn,
set_seed,
download_checkpoint_from_wandb,
)
from allenact.utils.system import get_logger
from allenact.utils.tensor_utils import batch_observations, detach_recursively
Expand Down Expand Up @@ -1900,6 +1901,14 @@ def train(
# noinspection PyBroadException
try:
if checkpoint_file_name is not None:
if "wandb://" == checkpoint_file_name[:8]:
ckpt_dir = "wandb_ckpts"
os.makedirs(ckpt_dir, exist_ok=True)
checkpoint_file_name = download_checkpoint_from_wandb(
checkpoint_path_dir_or_pattern,
ckpt_dir,
only_allow_one_ckpt=True
)
self.checkpoint_load(checkpoint_file_name, restart_pipeline)

self.run_pipeline(valid_on_initial_weights=valid_on_initial_weights)
Expand Down
19 changes: 2 additions & 17 deletions allenact/algorithms/onpolicy_sync/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ScalarMeanTracker,
set_deterministic_cudnn,
set_seed,
download_checkpoint_from_wandb,
)
from allenact.utils.misc_utils import (
NumpyJSONEncoder,
Expand Down Expand Up @@ -1501,25 +1502,9 @@ def get_checkpoint_files(
approx_ckpt_step_interval: Optional[int] = None,
):
if "wandb://" == checkpoint_path_dir_or_pattern[:8]:
import wandb
import shutil
eval_dir = "wandb_ckpts_to_eval/{}".format(self.local_start_time_str)
os.makedirs(eval_dir, exist_ok=True)
api = wandb.Api()
run_token = checkpoint_path_dir_or_pattern.split("//")[1]
ckpt_steps = checkpoint_path_dir_or_pattern.split("//")[2:]
if ckpt_steps[-1] == "":
ckpt_steps = ckpt_steps[:-1]
ckpts_paths = []
for steps in ckpt_steps:
ckpt_fn = "{}-step-{}:latest".format(run_token, steps)
artifact = api.artifact(ckpt_fn)
_ = artifact.download("tmp")
ckpt_dir = "{}/ckpt-{}.pt".format(eval_dir, steps)
shutil.move("tmp/ckpt.pt", ckpt_dir)
ckpts_paths.append(ckpt_dir)
shutil.rmtree("tmp")
return ckpts_paths
return download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, eval_dir, only_allow_one_ckpt=False)

if os.path.isdir(checkpoint_path_dir_or_pattern):
# The fragment is a path to a directory, lets use this directory
Expand Down
30 changes: 30 additions & 0 deletions allenact/utils/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import numpy as np
import torch
import torch.optim as optim
import wandb
import shutil

from allenact.algorithms.offpolicy_sync.losses.abstract_offpolicy_loss import Memory
from allenact.algorithms.onpolicy_sync.losses.abstract_loss import (
Expand Down Expand Up @@ -1186,3 +1188,31 @@ def current_stage_losses(
)
for loss_name in self.current_stage.loss_names
}


def _fetch_wandb_checkpoint(api, run_token, steps, all_ckpt_dir):
    """Download the checkpoint artifact for one step count.

    The artifact `<run_token>-step-<steps>:latest` is fetched through `api`
    and its `ckpt.pt` file is moved to `<all_ckpt_dir>/ckpt-<steps>.pt`.

    # Returns

    The path of the moved checkpoint file.
    """
    import os
    import tempfile

    artifact = api.artifact("{}-step-{}:latest".format(run_token, steps))
    ckpt_path = "{}/ckpt-{}.pt".format(all_ckpt_dir, steps)
    # Use a private scratch directory rather than a shared "./tmp": this
    # avoids deleting an unrelated directory of that name and keeps
    # concurrent downloads from clobbering each other. The directory is
    # removed even when the download or the move fails.
    with tempfile.TemporaryDirectory() as scratch_dir:
        artifact.download(scratch_dir)
        shutil.move(os.path.join(scratch_dir, "ckpt.pt"), ckpt_path)
    return ckpt_path


def download_checkpoint_from_wandb(
    checkpoint_path_dir_or_pattern, all_ckpt_dir, only_allow_one_ckpt=False
):
    """Download checkpoint(s) stored as wandb artifacts.

    # Parameters

    checkpoint_path_dir_or_pattern : String that is split on `"//"`; the
        second piece is used as the run token and every following piece as a
        step count, e.g. `wandb://<run_token>//<steps>[//<steps>...]`.
        A single trailing `"//"` is ignored.
    all_ckpt_dir : Directory receiving the files, each saved as
        `ckpt-<steps>.pt`. It is not created here (NOTE(review): callers are
        expected to create it beforehand - confirm).
    only_allow_one_ckpt : If `True`, exactly one step count must be given and
        a single path is returned instead of a list.

    # Returns

    A list of checkpoint paths (one per step count, in the order given), or
    a single path string when `only_allow_one_ckpt` is `True`.

    # Raises

    ValueError : If the pattern contains no run token / no step count, or if
        `only_allow_one_ckpt` is `True` and more than one step count is given.
    """
    pieces = checkpoint_path_dir_or_pattern.split("//")
    if len(pieces) < 3:
        raise ValueError(
            "Expected a pattern of the form 'wandb://<run>//<steps>', got"
            " {!r}.".format(checkpoint_path_dir_or_pattern)
        )

    run_token = pieces[1]
    ckpt_steps = pieces[2:]
    if ckpt_steps[-1] == "":
        # Tolerate a trailing "//" after the last step count.
        ckpt_steps = ckpt_steps[:-1]

    if not ckpt_steps:
        raise ValueError(
            "No checkpoint steps found in {!r}.".format(
                checkpoint_path_dir_or_pattern
            )
        )
    if only_allow_one_ckpt and len(ckpt_steps) != 1:
        raise ValueError(
            "Exactly one checkpoint step is allowed, got {}: {!r}.".format(
                len(ckpt_steps), checkpoint_path_dir_or_pattern
            )
        )

    # The pattern is validated before any wandb call so that malformed input
    # fails fast without contacting the service.
    api = wandb.Api()

    if only_allow_one_ckpt:
        # Previously this branch read the loop variable `steps`, which is
        # never bound here; use the single validated entry instead.
        return _fetch_wandb_checkpoint(api, run_token, ckpt_steps[0], all_ckpt_dir)

    return [
        _fetch_wandb_checkpoint(api, run_token, steps, all_ckpt_dir)
        for steps in ckpt_steps
    ]

0 comments on commit ad213e2

Please sign in to comment.