Merge pull request #157 from edbeeching/sb3_example_add_lr_schedule
Adds optional linear LR schedule to the example
Ivan-267 authored Dec 4, 2023
2 parents 8227583 + 835a2e8 commit 010d05b
Showing 2 changed files with 56 additions and 7 deletions.
17 changes: 13 additions & 4 deletions docs/ADV_STABLE_BASELINES_3.md
@@ -48,6 +48,7 @@ We recommend taking the [sb3 example](https://github.com/edbeeching/godot_rl_age

The example exposes additional parameters for the user to configure, such as `--speedup` to run the environment faster than real time, and `--n_parallel` to launch several instances of the game executable in order to accelerate training (not available for in-editor training).
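For example, a run combining both options could look like the following (the `--env_path` value here is a placeholder; point it at your own exported game executable):

```bash
python stable_baselines3_example.py --env_path=path/to/game.x86_64 --speedup=8 --n_parallel=4
```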

## SB3 Example script usage:
To use the example script, first open a console/terminal in the location of the downloaded script, and then try some of the example use cases below:

### Train a model in editor:
@@ -78,20 +79,20 @@ You can optionally set an experiment directory and name to override the default.
python stable_baselines3_example.py --experiment_dir="experiments" --experiment_name="experiment1"
```

### Train a model for 100_000 steps then save and export the model:
The exported .onnx model can be used by the Godot sync node to run inference from Godot directly, while the saved .zip model can be used to resume training later or run inference from the example script by adding `--inference`.
```bash
python stable_baselines3_example.py --timesteps=100_000 --onnx_export_path=model.onnx --save_model_path=model.zip
```
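Under the hood this is standard SB3 saving plus an ONNX export of the policy. A minimal sketch of the same idea, following the ONNX export pattern from the SB3 documentation (this assumes a plain `Box` observation space; the actual script handles the Godot-specific dict observations for you):

```python
import torch as th
from stable_baselines3 import PPO


class OnnxPolicy(th.nn.Module):
    """Wraps the trained policy so torch.onnx.export can trace it."""

    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # deterministic=True returns the mean action (no exploration noise)
        return self.policy._predict(observation, deterministic=True)


model = PPO.load("model.zip", device="cpu")
th.onnx.export(
    OnnxPolicy(model.policy),
    th.randn(1, *model.observation_space.shape),
    "model.onnx",
)
```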

### Resume training from a saved .zip model:
This will load the previously saved model.zip and resume training for another 100_000 steps, so the saved model will have been trained for 200_000 steps in total.
Note that the console log will display the `total_timesteps` for the last training session only, so it will show `100000` instead of `200000`.
```bash
python stable_baselines3_example.py --timesteps=100_000 --save_model_path=model_200_000_total_steps.zip --resume_model_path=model.zip
```
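In SB3 terms, resuming boils down to `PPO.load` plus another `learn()` call. A rough sketch (env construction omitted; `env` is assumed to be the wrapped Godot environment from the example script):

```python
from stable_baselines3 import PPO

model = PPO.load("model.zip", env=env)
# learn() restarts the step counter by default (reset_num_timesteps=True),
# which is why the console reports 100_000 rather than the cumulative 200_000
model.learn(total_timesteps=100_000)
model.save("model_200_000_total_steps.zip")
```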

### Save periodic checkpoints:
You can save periodic checkpoints and later resume training from any checkpoint using the same command-line argument as above, or run inference on any checkpoint just like with the saved model.
Note that you need to use a unique `experiment_name` or `experiment_dir` for each run so that checkpoints from one run won't overwrite checkpoints from another run.
Alternatively, you can remove the folder containing checkpoints from a previous run if you don't need them anymore.
@@ -104,8 +105,16 @@ python stable_baselines3_example.py --experiment_name=experiment1 --timesteps=2_

Checkpoints will be saved to `logs\sb3\experiment1_checkpoints` in the above case; the location is determined by `--experiment_dir` and `--experiment_name`.
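Internally this maps onto SB3's `CheckpointCallback`; a rough equivalent of what the flag sets up is shown below (names and frequency are illustrative, and `model` is assumed to be the PPO instance built by the script):

```python
from stable_baselines3.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_freq=20_000,  # counted in vec-env steps; one step advances n_parallel timesteps
    save_path="logs/sb3/experiment1_checkpoints",
    name_prefix="checkpoint",
)
model.learn(total_timesteps=2_000_000, callback=checkpoint_callback)
```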

### Run inference on a saved model for 100_000 steps:
You can run inference on a model that was previously saved using either `--save_model_path` or `--save_checkpoint_frequency`.
```bash
python stable_baselines3_example.py --timesteps=100_000 --resume_model_path=model.zip --inference
```
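Inference itself is essentially a predict/step loop over the vectorized environment, roughly:

```python
from stable_baselines3 import PPO

# `env` is assumed to be the wrapped (vectorized) Godot environment
model = PPO.load("model.zip", env=env)
obs = env.reset()
for _ in range(100_000):
    # deterministic=True picks the mean action instead of sampling
    action, _state = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
```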

### Use a linear learning rate schedule:
By default, the learning rate is constant throughout training.
If you add `--linear_lr_schedule`, the learning rate will decrease linearly as training progresses, reaching 0 once `--timesteps` steps are completed.
```bash
python stable_baselines3_example.py --timesteps=1_000_000 --linear_lr_schedule
```
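As a quick sanity check of the schedule's shape (using the `linear_schedule` helper added to the script below, with SB3's convention that `progress_remaining` goes from 1.0 at the start of training to 0.0 at the end):

```python
lr = linear_schedule(0.0003)
print(lr(1.0))  # 0.0003  -> initial learning rate
print(lr(0.5))  # 0.00015 -> halfway through --timesteps
print(lr(0.0))  # 0.0     -> at the final timestep
```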
46 changes: 43 additions & 3 deletions examples/stable_baselines3_example.py
@@ -1,6 +1,7 @@
import argparse
import os
import pathlib
from typing import Callable

from stable_baselines3.common.callbacks import CheckpointCallback
from godot_rl.core.utils import can_import
@@ -13,7 +14,7 @@
# 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase
# 2. chmod +x examples/godot_rl_BallChase/bin/BallChase.x86_64
if can_import("ray"):
print("WARNING, stable baselines and ray[rllib] are not compatable")
print("WARNING, stable baselines and ray[rllib] are not compatible")

parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
Expand Down Expand Up @@ -85,11 +86,19 @@
help="Instead of training, it will run inference on a loaded model for --timesteps steps. "
"Requires --resume_model_path to be set."
)
parser.add_argument(
    "--linear_lr_schedule",
    default=False,
    action="store_true",
    help="Use a linear LR schedule for training. If set, learning rate will decrease until it reaches 0 at "
    "the --timesteps value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used."
)
parser.add_argument(
    "--viz",
    action="store_true",
    help="If set, the simulation will be displayed in a window during training. Otherwise "
    "training will run without rendering the simulation. This setting does not apply to in-editor training.",
    default=False
)
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
@@ -117,8 +126,39 @@
speedup=args.speedup)
env = VecMonitor(env)


# LR schedule code snippet from:
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """
    Linear learning rate schedule.
    :param initial_value: Initial learning rate.
    :return: schedule that computes
    current learning rate depending on remaining progress
    """

    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0.
        :param progress_remaining:
        :return: current learning rate
        """
        return progress_remaining * initial_value

    return func


if args.resume_model_path is None:
    learning_rate = 0.0003 if not args.linear_lr_schedule else linear_schedule(0.0003)
    model: PPO = PPO("MultiInputPolicy",
                     env,
                     ent_coef=0.0001,
                     verbose=2,
                     n_steps=32,
                     tensorboard_log=args.experiment_dir,
                     learning_rate=learning_rate)
else:
    path_zip = pathlib.Path(args.resume_model_path)
    print("Loading model: " + os.path.abspath(path_zip))
