Adds saving model on ctrl + c in Python during training #160

Merged · 3 commits · Dec 19, 2023
2 changes: 2 additions & 0 deletions docs/ADV_STABLE_BASELINES_3.md
@@ -84,6 +84,8 @@ The exported .onnx model can be used by the Godot sync node to run inference fro
```bash
python stable_baselines3_example.py --timesteps=100_000 --onnx_export_path=model.onnx --save_model_path=model.zip
```
Note: If you interrupt training with `ctrl + c`, the script will still save and/or export the model before exiting, but only if the corresponding arguments mentioned above were provided. Using checkpoints (see below) is a safer way to preserve progress.


### Resume training from a saved .zip model:
This will load the previously saved model.zip, and resume training for another 100 000 steps, so the saved model will have been trained for 200 000 steps in total.
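The behaviour described in the note above amounts to catching `KeyboardInterrupt` around `learn()` and saving afterwards. A minimal sketch of the idea (not the repository's exact code; the environment id, timestep count, and file name are placeholders):

```python
from stable_baselines3 import PPO

# Minimal sketch: train a model and still save it if ctrl + c is pressed mid-run.
# "CartPole-v1", 100_000 steps, and "model.zip" are placeholder values.
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
try:
    model.learn(total_timesteps=100_000)
except KeyboardInterrupt:
    print("Training interrupted, saving the partially trained model.")
model.save("model.zip")
```

So running the example command above and pressing `ctrl + c` part-way through should still leave `model.zip` and `model.onnx` on disk.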
55 changes: 36 additions & 19 deletions examples/stable_baselines3_example.py
@@ -106,6 +106,31 @@
"launch - requires --env_path to be set if > 1.")
args, extras = parser.parse_known_args()


def handle_onnx_export():
# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name
# and extension used for both
if args.onnx_export_path is not None:
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
print("Exporting onnx to: " + os.path.abspath(path_onnx))
export_ppo_model_as_onnx(model, str(path_onnx))


def handle_model_save():
if args.save_model_path is not None:
zip_save_path = pathlib.Path(args.save_model_path).with_suffix(".zip")
print("Saving model to: " + os.path.abspath(zip_save_path))
model.save(zip_save_path)


def close_env():
try:
print("closing env")
env.close()
except Exception as e:
print("Exception while closing env: ", e)


path_checkpoint = os.path.join(args.experiment_dir, args.experiment_name + "_checkpoints")
abs_path_checkpoint = os.path.abspath(path_checkpoint)
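A side note on the extension handling in `handle_onnx_export` and `handle_model_save` above: `pathlib.Path.with_suffix` replaces an existing extension rather than appending a second one, which keeps the two outputs distinct even if the same file name is passed for both arguments. A small illustration with hypothetical paths:

```python
import pathlib

# with_suffix swaps the extension instead of stacking a new one on top.
print(pathlib.Path("model").with_suffix(".onnx"))      # model.onnx
print(pathlib.Path("model.zip").with_suffix(".onnx"))  # model.onnx
print(pathlib.Path("model.onnx").with_suffix(".zip"))  # model.zip
```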

@@ -170,28 +195,20 @@ def func(progress_remaining: float) -> float:
         action, _state = model.predict(obs, deterministic=True)
         obs, reward, done, info = env.step(action)
 else:
-    if args.save_checkpoint_frequency is None:
-        model.learn(args.timesteps, tb_log_name=args.experiment_name)
-    else:
+    learn_arguments = dict(total_timesteps=args.timesteps, tb_log_name=args.experiment_name)
+    if args.save_checkpoint_frequency:
         print("Checkpoint saving enabled. Checkpoints will be saved to: " + abs_path_checkpoint)
         checkpoint_callback = CheckpointCallback(
             save_freq=(args.save_checkpoint_frequency // env.num_envs),
             save_path=path_checkpoint,
             name_prefix=args.experiment_name
         )
-        model.learn(args.timesteps, callback=checkpoint_callback, tb_log_name=args.experiment_name)
-
-print("closing env")
-env.close()
-
-# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name
-# and extension used for both
-if args.onnx_export_path is not None:
-    path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
-    print("Exporting onnx to: " + os.path.abspath(path_onnx))
-    export_ppo_model_as_onnx(model, str(path_onnx))
-
-if args.save_model_path is not None:
-    path_zip = pathlib.Path(args.save_model_path).with_suffix(".zip")
-    print("Saving model to: " + os.path.abspath(path_zip))
-    model.save(path_zip)
+        learn_arguments['callback'] = checkpoint_callback
+    try:
+        model.learn(**learn_arguments)
+    except KeyboardInterrupt:
+        print("Training interrupted by user. Will save if --save_model_path was used and/or export if --onnx_export_path was used.")
+
+close_env()
+handle_onnx_export()
+handle_model_save()
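For context on `save_freq=(args.save_checkpoint_frequency // env.num_envs)` above: Stable Baselines3's `CheckpointCallback` counts callback calls, and with a vectorized environment each call corresponds to `num_envs` timesteps, so the requested frequency in total timesteps is divided by the number of parallel environments. A minimal sketch under assumed values (4 parallel environments, a checkpoint roughly every 20 000 total steps; the path and prefix are placeholders):

```python
from stable_baselines3.common.callbacks import CheckpointCallback

n_envs = 4               # assumed number of parallel environments
requested_freq = 20_000  # desired checkpoint interval in total timesteps

# save_freq is counted in callback calls; one call advances n_envs timesteps,
# hence the division (clamped so it never drops below 1).
checkpoint_callback = CheckpointCallback(
    save_freq=max(requested_freq // n_envs, 1),
    save_path="checkpoints/",
    name_prefix="experiment",
)
# The callback is then passed to model.learn(..., callback=checkpoint_callback),
# which is what the learn_arguments dict in the diff above accomplishes.
```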