Skip to content

Commit

Permalink
make ckpt saving at every host an option
Browse files Browse the repository at this point in the history
  • Loading branch information
KuoHaoZeng committed Jul 15, 2024
1 parent ecc24a9 commit 52dccc8
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
9 changes: 4 additions & 5 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
TrainingPipeline,
set_deterministic_cudnn,
set_seed,
download_checkpoint_from_wandb,
)
from allenact.utils.system import get_logger
from allenact.utils.tensor_utils import batch_observations, detach_recursively
Expand Down Expand Up @@ -1176,6 +1175,7 @@ def __init__(
max_sampler_processes_per_worker: Optional[int] = None,
save_ckpt_after_every_pipeline_stage: bool = True,
first_local_worker_id: int = 0,
save_ckpt_at_every_host: bool = False,
**kwargs,
):
kwargs["mode"] = TRAIN_MODE_STR
Expand Down Expand Up @@ -1267,6 +1267,7 @@ def __init__(
)

self.first_local_worker_id = first_local_worker_id
self.save_ckpt_at_every_host = save_ckpt_at_every_host

def advance_seed(
self, seed: Optional[int], return_same_seed_per_worker=False
Expand Down Expand Up @@ -1539,8 +1540,7 @@ def _save_checkpoint_then_send_checkpoint_for_validation_and_update_last_save_co
):
model_path = None
self.deterministic_seeds()
# if self.worker_id == self.first_local_worker_id:
if self.worker_id == 0:
if (self.save_ckpt_at_every_host and self.worker_id == self.first_local_worker_id) or self.worker_id == 0:
model_path = self.checkpoint_save(pipeline_stage_index=pipeline_stage_index)
if self.checkpoints_queue is not None:
self.checkpoints_queue.put(("eval", model_path))
Expand Down Expand Up @@ -1581,8 +1581,7 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
and should_save_checkpoints
and self.checkpoints_queue is not None
):
# if self.worker_id == self.first_local_worker_id:
if self.worker_id == 0:
if (self.save_ckpt_at_every_host and self.worker_id == self.first_local_worker_id) or self.worker_id == 0:
model_path = self.checkpoint_save()
if self.checkpoints_queue is not None:
self.checkpoints_queue.put(("eval", model_path))
Expand Down
2 changes: 2 additions & 0 deletions allenact/algorithms/onpolicy_sync/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def start_train(
collect_valid_results: bool = False,
valid_on_initial_weights: bool = False,
try_restart_after_task_error: bool = False,
save_ckpt_at_every_host: bool = False,
):
self._initialize_start_train_or_start_test()

Expand Down Expand Up @@ -574,6 +575,7 @@ def start_train(
distributed_preemption_threshold=self.distributed_preemption_threshold,
valid_on_initial_weights=valid_on_initial_weights,
try_restart_after_task_error=try_restart_after_task_error,
save_ckpt_at_every_host=save_ckpt_at_every_host,
)
train: BaseProcess = self.mp_ctx.Process(
target=self.train_loop,
Expand Down
11 changes: 11 additions & 0 deletions allenact/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,16 @@ def get_argument_parser():
" tutorial https://allenact.org/tutorials/distributed-objectnav-tutorial/",
)

parser.add_argument(
"--save_ckpt_at_every_host",
dest="save_ckpt_at_every_host",
action="store_true",
required=False,
    help="if you pass the `--save_ckpt_at_every_host` flag, AllenAct will save checkpoints at every host as"
    " the training progresses in distributed training mode.",
)
parser.set_defaults(save_ckpt_at_every_host=False)

parser.add_argument(
"--callbacks",
dest="callbacks",
Expand Down Expand Up @@ -484,6 +494,7 @@ def main():
collect_valid_results=args.collect_valid_results,
valid_on_initial_weights=args.valid_on_initial_weights,
try_restart_after_task_error=args.enable_crash_recovery,
            save_ckpt_at_every_host=args.save_ckpt_at_every_host,
)
else:
OnPolicyRunner(
Expand Down
8 changes: 4 additions & 4 deletions allenact/utils/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,17 +1201,17 @@ def download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, all_ckpt_dir,
for steps in ckpt_steps:
ckpt_fn = "{}-step-{}:latest".format(run_token, steps)
artifact = api.artifact(ckpt_fn)
_ = artifact.download("/tmp")
_ = artifact.download(all_ckpt_dir)
ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, steps)
shutil.move("/tmp/ckpt.pt", ckpt_dir)
shutil.move("{}/ckpt.pt".format(all_ckpt_dir), ckpt_dir)
ckpts_paths.append(ckpt_dir)
return ckpts_paths
else:
assert len(ckpt_steps) == 1
step = ckpt_steps[0]
ckpt_fn = "{}-step-{}:latest".format(run_token, step)
artifact = api.artifact(ckpt_fn)
_ = artifact.download("/tmp")
_ = artifact.download(all_ckpt_dir)
ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, step)
shutil.move("/tmp/ckpt.pt", ckpt_dir)
shutil.move("{}/ckpt.pt".format(all_ckpt_dir), ckpt_dir)
return ckpt_dir

0 comments on commit 52dccc8

Please sign in to comment.