Skip to content

Commit

Permalink
Adds saving model on ctrl + c in Python during training
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan-267 committed Dec 4, 2023
1 parent 010d05b commit b35766c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/ADV_STABLE_BASELINES_3.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ The exported .onnx model can be used by the Godot sync node to run inference fro
```bash
python stable_baselines3_example.py --timesteps=100_000 --onnx_export_path=model.onnx --save_model_path=model.zip
```
Note: If you interrupt training with `ctrl + c`, the script will try to save/export the model before exiting, but only if you passed the corresponding arguments shown above (`--save_model_path` and/or `--onnx_export_path`). Using checkpoints (see below) is a safer way to keep progress.


### Resume training from a saved .zip model:
This will load the previously saved model.zip, and resume training for another 100 000 steps, so the saved model will have been trained for 200 000 steps in total.
Expand Down
57 changes: 38 additions & 19 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,33 @@
"launch - requires --env_path to be set if > 1.")
args, extras = parser.parse_known_args()


def handle_onnx_export():
    """Export the global ``model`` to ONNX if ``args.onnx_export_path`` is set.

    The ".onnx" suffix is forced onto the target path. The ".zip" suffix is
    likewise forced onto the saved model, so using the same name (and
    extension) for both arguments cannot make the two files collide.
    """
    if args.onnx_export_path is None:
        # No export path was supplied: nothing to do.
        return

    onnx_target = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
    print("Exporting onnx to: " + os.path.abspath(onnx_target))
    export_ppo_model_as_onnx(model, str(onnx_target))


def handle_model_save():
    """Save the global ``model`` if ``args.save_model_path`` is set.

    The ".zip" suffix is forced onto the target path before saving, and the
    absolute destination is printed first.
    """
    if args.save_model_path is None:
        # No save path was supplied: nothing to do.
        return

    zip_target = pathlib.Path(args.save_model_path).with_suffix(".zip")
    print("Saving model to: " + os.path.abspath(zip_target))
    model.save(zip_target)


def close_env():
    """Close the global ``env``, reporting (not raising) any failure.

    Closing is best-effort: an exception raised while closing is printed
    and swallowed, so the caller can continue after this function returns.
    """
    try:
        print("closing env")
        env.close()
    except Exception as close_error:
        print("Exception while closing env: ", close_error)


# Checkpoint directory: <experiment_dir>/<experiment_name>_checkpoints.
path_checkpoint = os.path.join(args.experiment_dir, args.experiment_name + "_checkpoints")
# Absolute form of the same path, used when reporting the location to the user.
abs_path_checkpoint = os.path.abspath(path_checkpoint)

Expand Down Expand Up @@ -170,28 +197,20 @@ def func(progress_remaining: float) -> float:
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
else:
if args.save_checkpoint_frequency is None:
model.learn(args.timesteps, tb_log_name=args.experiment_name)
else:
learn_arguments = dict(total_timesteps=args.timesteps, tb_log_name=args.experiment_name)
if args.save_checkpoint_frequency:
print("Checkpoint saving enabled. Checkpoints will be saved to: " + abs_path_checkpoint)
checkpoint_callback = CheckpointCallback(
save_freq=(args.save_checkpoint_frequency // env.num_envs),
save_path=path_checkpoint,
name_prefix=args.experiment_name
)
model.learn(args.timesteps, callback=checkpoint_callback, tb_log_name=args.experiment_name)

print("closing env")
env.close()

# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name
# and extension used for both
if args.onnx_export_path is not None:
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
print("Exporting onnx to: " + os.path.abspath(path_onnx))
export_ppo_model_as_onnx(model, str(path_onnx))

if args.save_model_path is not None:
path_zip = pathlib.Path(args.save_model_path).with_suffix(".zip")
print("Saving model to: " + os.path.abspath(path_zip))
model.save(path_zip)
learn_arguments['callback'] = checkpoint_callback
try:
model.learn(**learn_arguments)
except KeyboardInterrupt:
print("Training interrupted by user. Will save/export if --save_model_path or --onnx_export_path was used.")

close_env()
handle_onnx_export()
handle_model_save()

0 comments on commit b35766c

Please sign in to comment.