From af92aec7eab4f772d386d7336226f60378614a60 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 21:21:04 +0100 Subject: [PATCH 1/9] adds tests for code quality / formatting --- .github/workflows/quality.yml | 29 +++++++++++++++++++++++++++++ Makefile | 15 ++++++++------- setup.cfg | 1 - 3 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/quality.yml diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml new file mode 100644 index 00000000..a790eafc --- /dev/null +++ b/.github/workflows/quality.yml @@ -0,0 +1,29 @@ +name: Quality + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + + check_code_quality: + name: Check code quality + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Setup Python environment + uses: actions/setup-python@v2 + with: + python-version: 3.10.10 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install ".[dev]" + - name: Code quality + run: | + make quality \ No newline at end of file diff --git a/Makefile b/Makefile index 0aa8753e..d5cf222d 100644 --- a/Makefile +++ b/Makefile @@ -1,15 +1,16 @@ .PHONY: quality style test unity-test -# Check that source code meets quality standards -quality: - black --check --line-length 119 --target-version py38 tests godot_rl - isort --check-only tests godot_rl - flake8 tests godot_rl +check_dirs := src tests godot_rl # Format source code automatically style: - black --line-length 119 --target-version py38 tests godot_rl - isort tests godot_rl + python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py + python -m isort $(check_dirs) setup.py +# Check that source code meets quality standards +quality: + python -m black --check --line-length 119 --target-version py310 $(check_dirs) setup.py + python -m isort --check-only $(check_dirs) setup.py + python -m flake8 --max-line-length 119 $(check_dirs) setup.py # Run tests for the library test: diff --git a/setup.cfg b/setup.cfg index c0329fec..319bb957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,6 @@ console_scripts = test = pytest>=6.0 pytest-xdist - dev = pytest>=6.0 pytest-xdist From b1999bd59456f750c20dda57b723c44aaec59917 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 21:31:35 +0100 Subject: [PATCH 2/9] adds contrib guide --- README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/README.md b/README.md index 7ecfd696..1c92926e 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,33 @@ Godot RL Agents supports 4 different RL training frameworks, the links below det - [CleanRL](docs/ADV_CLEAN_RL.md) (Windows, Mac, Linux) - [Ray rllib](docs/ADV_RLLIB.md) (Windows, Mac, Linux) +## Contributing +We welcome new contributions to the library, such as: +- New environments made in Godot +- Improvements to the readme files +- Additions to the python codebase + +Start by forking the repo and then cloning it to your machine, creating a venv and performing an editable installation. + +``` +# If you want to PR, you should fork the lib or ask to be a contibutor +git clone git@github.com:YOUR_USERNAME/godot_rl_agents.git +cd godot_rl_agents +python -m venv venv +pip install -e ".[dev]" +# check tests run +make test +``` + +Then add your features. +Format your code with: +``` +make style +make quality +``` +Then make a PR against main on the original repo. + + ## FAQ ### Why have we developed Godot RL Agents? From 7e262eb507cbf3d7a97dd9fa21b0d6435cb1c67f Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 21:52:03 +0100 Subject: [PATCH 3/9] updates gitignore --- .vscode/settings.json | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 7c9f21cc..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "editor.formatOnSave": true, - "python.formatting.provider": "black", - "python.formatting.blackArgs": [ - "-l 120" - ], - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, -} \ No newline at end of file From 2da011fbe14145f88a4827f5171edbf3a7c540de Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 21:52:10 +0100 Subject: [PATCH 4/9] formats tests --- tests/test_action_space_preprocessor.py | 21 ++++++++++----------- tests/test_godot_env.py | 16 ++++------------ tests/test_rllib.py | 5 +++-- tests/test_sample_factory.py | 13 +++++++------ tests/test_sb3_onnx_export.py | 3 ++- tests/test_sb3_training.py | 6 ++---- 6 files changed, 28 insertions(+), 36 deletions(-) diff --git a/tests/test_action_space_preprocessor.py b/tests/test_action_space_preprocessor.py index 4496bbe5..e4f0acbf 100644 --- a/tests/test_action_space_preprocessor.py +++ b/tests/test_action_space_preprocessor.py @@ -1,26 +1,25 @@ import pytest -from gymnasium.spaces import Tuple, Dict, Box, Discrete +from gymnasium.spaces import Box, Dict, Discrete, Tuple + from godot_rl.core.godot_env import GodotEnv from godot_rl.core.utils import ActionSpaceProcessor -@pytest.mark.parametrize("action_space", + +@pytest.mark.parametrize( + "action_space", [ - Tuple([Box(-1,1, shape=[7]), Box(-1,1, shape=[11])]), - Tuple([Box(-1,1, shape=[7]), Discrete(2)]), + Tuple([Box(-1, 1, shape=[7]), Box(-1, 1, shape=[11])]), + Tuple([Box(-1, 1, shape=[7]), Discrete(2)]), Tuple([Discrete(2), Discrete(2)]), - Tuple([Discrete(2), Discrete(2), Box(-1,1, shape=[11])]), - ] - - - + Tuple([Discrete(2), Discrete(2), Box(-1, 1, shape=[11])]), + ], ) def test_action_space_preprocessor(action_space): - expected_output = 0 for space in action_space.spaces: if isinstance(space, Box): - assert len(space.shape) ==1 + assert len(space.shape) == 1 expected_output += space.shape[0] elif isinstance(space, Discrete): if space.n > 2: diff --git a/tests/test_godot_env.py b/tests/test_godot_env.py index 2b9435af..9f5532a7 100644 --- a/tests/test_godot_env.py +++ b/tests/test_godot_env.py @@ -37,12 +37,8 @@ def test_env_ij(env_name, port, n_agents): assert isinstance( reward[0], (float, int) ), f"The reward returned by 'step()' must be a float or int, and is {reward[0]} of type {type(reward[0])}" - assert isinstance( - term[0], bool - ), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean" - assert isinstance( - info[0], dict - ), "The 'info' returned by 'step()' must be a python dictionary" + assert isinstance(term[0], bool), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean" + assert isinstance(info[0], dict), "The 'info' returned by 'step()' must be a python dictionary" env.close() @@ -82,11 +78,7 @@ def test_env_ji(env_name, port, n_agents): assert isinstance( reward[0], (float, int) ), f"The reward returned by 'step()' must be a float or int, and is {reward[0]} of type {type(reward[0])}" - assert isinstance( - term[0], bool - ), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean" - assert isinstance( - info[0], dict - ), "The 'info' returned by 'step()' must be a python dictionary" + assert isinstance(term[0], bool), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean" + assert isinstance(info[0], dict), "The 'info' returned by 'step()' must be a python dictionary" env.close() diff --git a/tests/test_rllib.py b/tests/test_rllib.py index 574d2b82..3ed48a0e 100644 --- a/tests/test_rllib.py +++ b/tests/test_rllib.py @@ -2,13 +2,14 @@ from godot_rl.core.utils import cant_import + @pytest.mark.skipif(cant_import("ray"), reason="ray[rllib] is not available") def test_rllib_training(): - from godot_rl.wrappers.ray_wrapper import rllib_training from godot_rl.main import get_args + from godot_rl.wrappers.ray_wrapper import rllib_training + args, extras = get_args() args.config_file = "tests/fixtures/test_rllib.yaml" args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" - rllib_training(args, extras) diff --git a/tests/test_sample_factory.py b/tests/test_sample_factory.py index eaa9c826..86e72b4e 100644 --- a/tests/test_sample_factory.py +++ b/tests/test_sample_factory.py @@ -2,16 +2,17 @@ from godot_rl.core.utils import cant_import + @pytest.mark.skipif(cant_import("sample_factory"), reason="sample_factory is not available") def test_sample_factory_training(): - from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training from examples.sample_factory_example import get_args + from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training + args, extras = get_args() args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" extras = [] - extras.append('--env=gdrl') - extras.append('--train_for_env_steps=1000') - extras.append('--device=cpu') - + extras.append("--env=gdrl") + extras.append("--train_for_env_steps=1000") + extras.append("--device=cpu") + sample_factory_training(args, extras) - diff --git a/tests/test_sb3_onnx_export.py b/tests/test_sb3_onnx_export.py index d3160492..d452518d 100644 --- a/tests/test_sb3_onnx_export.py +++ b/tests/test_sb3_onnx_export.py @@ -22,8 +22,9 @@ ) def test_pytorch_vs_onnx(env_name, port): from stable_baselines3 import PPO - from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv + from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx, verify_onnx_export + from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64" env = StableBaselinesGodotEnv(env_path, port=port) diff --git a/tests/test_sb3_training.py b/tests/test_sb3_training.py index 4b72cbc9..e0561a54 100644 --- a/tests/test_sb3_training.py +++ b/tests/test_sb3_training.py @@ -1,7 +1,7 @@ import pytest -from godot_rl.main import get_args from godot_rl.core.utils import can_import +from godot_rl.main import get_args @pytest.mark.skipif(can_import("ray"), reason="rllib and sb3 are not compatable") @@ -30,6 +30,4 @@ def test_sb3_training(env_name, port, n_parallel): args.speedup = 8 starting_port = port + n_parallel - stable_baselines_training( - args, extras, n_steps=2, port=starting_port, n_parallel=n_parallel - ) + stable_baselines_training(args, extras, n_steps=2, port=starting_port, n_parallel=n_parallel) From 10d4fef92b93371c26680603983af7fad379f0a8 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 21:52:23 +0100 Subject: [PATCH 5/9] formats examples --- examples/clean_rl_example.py | 18 +++--- examples/sample_factory_example.py | 13 ++-- examples/stable_baselines3_example.py | 84 ++++++++++++++----------- examples/stable_baselines3_hp_tuning.py | 31 ++++++--- 4 files changed, 87 insertions(+), 59 deletions(-) diff --git a/examples/clean_rl_example.py b/examples/clean_rl_example.py index 8b061fb8..453f99ac 100644 --- a/examples/clean_rl_example.py +++ b/examples/clean_rl_example.py @@ -167,8 +167,9 @@ def get_action_and_value(self, x, action=None): # env setup - envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, - n_parallel=args.n_parallel) + envs = env = CleanRLGodotEnv( + env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel + ) args.num_envs = envs.num_envs args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) @@ -334,7 +335,6 @@ def get_action_and_value(self, x, action=None): agent.eval().to("cpu") - class OnnxPolicy(torch.nn.Module): def __init__(self, actor_mean): super().__init__() @@ -344,7 +344,6 @@ def forward(self, obs, state_ins): action_mean = self.actor_mean(obs) return action_mean, state_ins - onnx_policy = OnnxPolicy(agent.actor_mean) dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0) @@ -355,9 +354,10 @@ def forward(self, obs, state_ins): opset_version=15, input_names=["obs", "state_ins"], output_names=["output", "state_outs"], - dynamic_axes={'obs': {0: 'batch_size'}, - 'state_ins': {0: 'batch_size'}, # variable length axes - 'output': {0: 'batch_size'}, - 'state_outs': {0: 'batch_size'}} - + dynamic_axes={ + "obs": {0: "batch_size"}, + "state_ins": {0: "batch_size"}, # variable length axes + "output": {0: "batch_size"}, + "state_outs": {0: "batch_size"}, + }, ) diff --git a/examples/sample_factory_example.py b/examples/sample_factory_example.py index 2c4e10a6..9ce243cf 100644 --- a/examples/sample_factory_example.py +++ b/examples/sample_factory_example.py @@ -10,8 +10,12 @@ def get_args(): parser.add_argument("--seed", default=0, type=int, help="environment seed") parser.add_argument("--export", default=False, action="store_true", help="whether to export the model") parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process") - parser.add_argument("--experiment_dir", default="logs/sf", type=str, - help="The name of the experiment directory, in which the tensorboard logs are getting stored") + parser.add_argument( + "--experiment_dir", + default="logs/sf", + type=str, + help="The name of the experiment directory, in which the tensorboard logs are getting stored", + ) parser.add_argument( "--experiment_name", default="experiment", @@ -22,14 +26,13 @@ def get_args(): return parser.parse_known_args() - def main(): args, extras = get_args() if args.eval: sample_factory_enjoy(args, extras) else: sample_factory_training(args, extras) - - + + if __name__ == "__main__": main() diff --git a/examples/stable_baselines3_example.py b/examples/stable_baselines3_example.py index 7358f991..3d4e90d6 100644 --- a/examples/stable_baselines3_example.py +++ b/examples/stable_baselines3_example.py @@ -28,42 +28,39 @@ default="logs/sb3", type=str, help="The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are " - "getting stored." + "getting stored.", ) parser.add_argument( "--experiment_name", default="experiment", type=str, help="The name of the experiment, which will be displayed in tensorboard and " - "for checkpoint directory and name (if enabled).", -) -parser.add_argument( - "--seed", - type=int, - default=0, - help="seed of the experiment" + "for checkpoint directory and name (if enabled).", ) +parser.add_argument("--seed", type=int, default=0, help="seed of the experiment") parser.add_argument( "--resume_model_path", default=None, type=str, help="The path to a model file previously saved using --save_model_path or a checkpoint saved using " - "--save_checkpoints_frequency. Use this to resume training or infer from a saved model.", + "--save_checkpoints_frequency. Use this to resume training or infer from a saved model.", ) parser.add_argument( "--save_model_path", default=None, type=str, help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later " - "to resume training. Extension will be set to .zip", + "to resume training. Extension will be set to .zip", ) parser.add_argument( "--save_checkpoint_frequency", default=None, type=int, - help=("If set, will save checkpoints every 'frequency' environment steps. " - "Requires a unique --experiment_name or --experiment_dir for each run. " - "Does not need --save_model_path to be set. "), + help=( + "If set, will save checkpoints every 'frequency' environment steps. " + "Requires a unique --experiment_name or --experiment_dir for each run. " + "Does not need --save_model_path to be set. " + ), ) parser.add_argument( "--onnx_export_path", @@ -76,34 +73,38 @@ default=1_000_000, type=int, help="The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, " - "it will continue training for this amount of steps from the saved state without counting previously trained " - "steps", + "it will continue training for this amount of steps from the saved state without counting previously trained " + "steps", ) parser.add_argument( "--inference", default=False, action="store_true", help="Instead of training, it will run inference on a loaded model for --timesteps steps. " - "Requires --resume_model_path to be set." + "Requires --resume_model_path to be set.", ) parser.add_argument( "--linear_lr_schedule", default=False, action="store_true", help="Use a linear LR schedule for training. If set, learning rate will decrease until it reaches 0 at " - "--timesteps" - "value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used." + "--timesteps" + "value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used.", ) parser.add_argument( "--viz", action="store_true", help="If set, the simulation will be displayed in a window during training. Otherwise " - "training will run without rendering the simulation. This setting does not apply to in-editor training.", - default=False + "training will run without rendering the simulation. This setting does not apply to in-editor training.", + default=False, ) parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env") -parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to " - "launch - requires --env_path to be set if > 1.") +parser.add_argument( + "--n_parallel", + default=1, + type=int, + help="How many instances of the environment executable to " "launch - requires --env_path to be set if > 1.", +) args, extras = parser.parse_known_args() @@ -136,10 +137,12 @@ def close_env(): # Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint): - raise RuntimeError(abs_path_checkpoint + " folder already exists. " - "Use a different --experiment_dir, or --experiment_name," - "or if previous checkpoints are not needed anymore, " - "remove the folder containing the checkpoints. ") + raise RuntimeError( + abs_path_checkpoint + " folder already exists. " + "Use a different --experiment_dir, or --experiment_name," + "or if previous checkpoints are not needed anymore, " + "remove the folder containing the checkpoints. " + ) if args.inference and args.resume_model_path is None: raise parser.error("Using --inference requires --resume_model_path to be set.") @@ -147,8 +150,9 @@ def close_env(): if args.env_path is None and args.viz: print("Info: Using --viz without --env_path set has no effect, in-editor training will always render.") -env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, seed=args.seed, n_parallel=args.n_parallel, - speedup=args.speedup) +env = StableBaselinesGodotEnv( + env_path=args.env_path, show_window=args.viz, seed=args.seed, n_parallel=args.n_parallel, speedup=args.speedup +) env = VecMonitor(env) @@ -177,13 +181,15 @@ def func(progress_remaining: float) -> float: if args.resume_model_path is None: learning_rate = 0.0003 if not args.linear_lr_schedule else linear_schedule(0.0003) - model: PPO = PPO("MultiInputPolicy", - env, - ent_coef=0.0001, - verbose=2, - n_steps=32, - tensorboard_log=args.experiment_dir, - learning_rate=learning_rate) + model: PPO = PPO( + "MultiInputPolicy", + env, + ent_coef=0.0001, + verbose=2, + n_steps=32, + tensorboard_log=args.experiment_dir, + learning_rate=learning_rate, + ) else: path_zip = pathlib.Path(args.resume_model_path) print("Loading model: " + os.path.abspath(path_zip)) @@ -201,13 +207,15 @@ def func(progress_remaining: float) -> float: checkpoint_callback = CheckpointCallback( save_freq=(args.save_checkpoint_frequency // env.num_envs), save_path=path_checkpoint, - name_prefix=args.experiment_name + name_prefix=args.experiment_name, ) - learn_arguments['callback'] = checkpoint_callback + learn_arguments["callback"] = checkpoint_callback try: model.learn(**learn_arguments) except KeyboardInterrupt: - print("Training interrupted by user. Will save if --save_model_path was used and/or export if --onnx_export_path was used.") + print( + "Training interrupted by user. Will save if --save_model_path was used and/or export if --onnx_export_path was used." + ) close_env() handle_onnx_export() diff --git a/examples/stable_baselines3_hp_tuning.py b/examples/stable_baselines3_hp_tuning.py index 7e280f8c..0b1abd4c 100644 --- a/examples/stable_baselines3_hp_tuning.py +++ b/examples/stable_baselines3_hp_tuning.py @@ -38,9 +38,16 @@ import argparse parser = argparse.ArgumentParser(allow_abbrev=False) -parser.add_argument("--env_path", default=None, type=str, help="The Godot binary to use, do not include for in editor training") +parser.add_argument( + "--env_path", default=None, type=str, help="The Godot binary to use, do not include for in editor training" +) parser.add_argument("--speedup", default=8, type=int, help="whether to speed up the physics in the env") -parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to launch - requires --env_path to be set if > 1.") +parser.add_argument( + "--n_parallel", + default=1, + type=int, + help="How many instances of the environment executable to launch - requires --env_path to be set if > 1.", +) args, extras = parser.parse_known_args() @@ -61,6 +68,7 @@ "ent_coef": 0.005, } + def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: """Sampler for PPO hyperparameters.""" learning_rate = trial.suggest_loguniform("learning_rate", 0.0003, 0.003) @@ -118,10 +126,19 @@ def objective(trial: optuna.Trial) -> float: print("args:", kwargs) # Create the RL model. training_port = GodotEnv.DEFAULT_PORT + 1 - model = PPO("MultiInputPolicy", VecMonitor(StableBaselinesGodotEnv(env_path=args.env_path, speedup=args.speedup, n_parallel=args.n_parallel, port=training_port)), tensorboard_log="logs/optuna", **kwargs) + model = PPO( + "MultiInputPolicy", + VecMonitor( + StableBaselinesGodotEnv( + env_path=args.env_path, speedup=args.speedup, n_parallel=args.n_parallel, port=training_port + ) + ), + tensorboard_log="logs/optuna", + **kwargs, + ) # Create env used for evaluation. eval_env = VecMonitor(StableBaselinesGodotEnv(env_path=args.env_path, speedup=args.speedup)) - + # Create the callback that will periodically evaluate and report the performance. eval_callback = TrialEvalCallback( eval_env, trial, n_eval_episodes=N_EVAL_EPISODES, eval_freq=EVAL_FREQ, deterministic=True @@ -142,12 +159,12 @@ def objective(trial: optuna.Trial) -> float: # Tell the optimizer that the trial failed. if nan_encountered: - #return 0 + # return 0 return float("nan") if eval_callback.is_pruned: raise optuna.exceptions.TrialPruned() - + return eval_callback.last_mean_reward @@ -178,4 +195,4 @@ def objective(trial: optuna.Trial) -> float: print(" User attrs:") for key, value in trial.user_attrs.items(): - print(" {}: {}".format(key, value)) \ No newline at end of file + print(" {}: {}".format(key, value)) From b9a804b6c6334121581704386def6f54b3b1e314 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 21:52:35 +0100 Subject: [PATCH 6/9] formats core codebase --- godot_rl/core/godot_env.py | 37 +++-- godot_rl/core/utils.py | 18 ++- godot_rl/custom_models/attention_model.py | 9 +- godot_rl/download_utils/download_examples.py | 24 ++-- .../download_utils/download_godot_editor.py | 34 ++--- godot_rl/download_utils/from_hub.py | 1 + godot_rl/main.py | 43 ++++-- godot_rl/wrappers/clean_rl_wrapper.py | 18 +-- .../wrappers/onnx/stable_baselines_export.py | 13 +- godot_rl/wrappers/ray_wrapper.py | 128 +++++++++--------- godot_rl/wrappers/sample_factory_wrapper.py | 19 +-- godot_rl/wrappers/sbg_single_obs_wrapper.py | 3 +- godot_rl/wrappers/stable_baselines_wrapper.py | 19 ++- 13 files changed, 197 insertions(+), 169 deletions(-) diff --git a/godot_rl/core/godot_env.py b/godot_rl/core/godot_env.py index 8c37ccec..7b6be772 100644 --- a/godot_rl/core/godot_env.py +++ b/godot_rl/core/godot_env.py @@ -10,10 +10,10 @@ from typing import Optional import numpy as np -from godot_rl.core.utils import ActionSpaceProcessor, convert_macos_path from gymnasium import spaces -from collections import OrderedDict +from godot_rl.core.utils import ActionSpaceProcessor, convert_macos_path + class GodotEnv: MAJOR_VERSION = "0" # Versioning for the environment @@ -22,15 +22,15 @@ class GodotEnv: DEFAULT_TIMEOUT = 60 # Default socket timeout TODO def __init__( - self, - env_path: str = None, - port: int = DEFAULT_PORT, - show_window: bool = False, - seed: int = 0, - framerate: Optional[int] = None, - action_repeat: Optional[int] = None, - speedup: Optional[int] = None, - convert_action_space: bool = False, + self, + env_path: str = None, + port: int = DEFAULT_PORT, + show_window: bool = False, + seed: int = 0, + framerate: Optional[int] = None, + action_repeat: Optional[int] = None, + speedup: Optional[int] = None, + convert_action_space: bool = False, ): """ Initialize a new instance of GodotEnv @@ -98,17 +98,17 @@ def check_platform(self, filename: str): if platform == "linux" or platform == "linux2": # Linux assert ( - pathlib.Path(filename).suffix == ".x86_64" + pathlib.Path(filename).suffix == ".x86_64" ), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .x86_64 file" elif platform == "darwin": # OSX assert ( - pathlib.Path(filename).suffix == ".app" + pathlib.Path(filename).suffix == ".app" ), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .app file" elif platform == "win32": # Windows... assert ( - pathlib.Path(filename).suffix == ".exe" + pathlib.Path(filename).suffix == ".exe" ), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .exe file" else: assert 0, f"unknown filetype {pathlib.Path(filename).suffix}" @@ -359,13 +359,10 @@ def _get_env_info(self): @staticmethod def _decode_2d_obs_from_string( - hex_string, - shape, + hex_string, + shape, ): - return ( - np.frombuffer(bytes.fromhex(hex_string), dtype=np.uint8) - .reshape(shape) - ) + return np.frombuffer(bytes.fromhex(hex_string), dtype=np.uint8).reshape(shape) def _send_as_json(self, dictionary): message_json = json.dumps(dictionary) diff --git a/godot_rl/core/utils.py b/godot_rl/core/utils.py index edca98d7..e313ef5a 100644 --- a/godot_rl/core/utils.py +++ b/godot_rl/core/utils.py @@ -5,7 +5,6 @@ import numpy as np - def lod_to_dol(lod): return {k: [dic[k] for dic in lod] for k in lod[0]} @@ -13,6 +12,7 @@ def lod_to_dol(lod): def dol_to_lod(dol): return [dict(zip(dol, t)) for t in zip(*dol.values())] + def convert_macos_path(env_path): """ On MacOs the user is supposed to provide a application.app file to env_path. @@ -23,12 +23,11 @@ def convert_macos_path(env_path): Example output: ./Demo.app/Contents/Macos/Demo """ - filenames = re.findall(r'[^\/]+(?=\.)', env_path) - assert ( - len(filenames) == 1 - ), f"An error occured while converting the env path for MacOS." + filenames = re.findall(r"[^\/]+(?=\.)", env_path) + assert len(filenames) == 1, f"An error occured while converting the env path for MacOS." return env_path + "/Contents/MacOS/" + filenames[0] + class ActionSpaceProcessor: # can convert tuple action dists to a single continuous action distribution # eg (Box(a), Box(b)) -> Box(a+b) @@ -36,7 +35,6 @@ class ActionSpaceProcessor: # etc # does not yet work with discrete dists of n>2 def __init__(self, action_space: gym.spaces.Tuple, convert) -> None: - self._original_action_space = action_space self._convert = convert @@ -46,7 +44,6 @@ def __init__(self, action_space: gym.spaces.Tuple, convert) -> None: use_multi_discrete_spaces = False multi_discrete_spaces = np.array([]) if isinstance(action_space, gym.spaces.Tuple): - if all(isinstance(space, gym.spaces.Discrete) for space in action_space.spaces): use_multi_discrete_spaces = True for space in action_space.spaces: @@ -58,7 +55,7 @@ def __init__(self, action_space: gym.spaces.Tuple, convert) -> None: space_size += space.shape[0] elif isinstance(space, gym.spaces.Discrete): if space.n > 2: - #for now only binary actions are supported if you mix different spaces + # for now only binary actions are supported if you mix different spaces # need to add support for the n>2 case raise NotImplementedError space_size += 1 @@ -96,7 +93,6 @@ def to_original_dist(self, action): counter += space.shape[0] elif isinstance(space, gym.spaces.Discrete): - discrete_actions = np.greater(action[:, counter], 0.0) discrete_actions = discrete_actions.astype(np.float32) original_action.append(discrete_actions) @@ -107,12 +103,14 @@ def to_original_dist(self, action): return original_action + def can_import(module_name): return not cant_import(module_name) + def cant_import(module_name): try: importlib.import_module(module_name) return False except ImportError: - return True \ No newline at end of file + return True diff --git a/godot_rl/custom_models/attention_model.py b/godot_rl/custom_models/attention_model.py index e5a101a2..ea3c47bb 100644 --- a/godot_rl/custom_models/attention_model.py +++ b/godot_rl/custom_models/attention_model.py @@ -4,11 +4,8 @@ import numpy as np from gym.spaces import Box, Discrete, MultiDiscrete from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet -from ray.rllib.models.torch.misc import (AppendBiasLayer, SlimFC, - normc_initializer) -from ray.rllib.models.torch.modules import (GRUGate, - RelativeMultiHeadAttention, - SkipConnection) +from ray.rllib.models.torch.misc import AppendBiasLayer, SlimFC, normc_initializer +from ray.rllib.models.torch.modules import GRUGate, RelativeMultiHeadAttention, SkipConnection from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement @@ -35,7 +32,6 @@ def __init__( model_config: ModelConfigDict, name: str, ): - TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name) nn.Module.__init__(self) @@ -77,7 +73,6 @@ def forward( state: List[TensorType], seq_lens: TensorType, ) -> (TensorType, List[TensorType]): - observations = input_dict[SampleBatch.OBS] # print("unbatch", input_dict["obs"]["obs"].unbatch_all()[0]) # print(input_dict["obs"]) diff --git a/godot_rl/download_utils/download_examples.py b/godot_rl/download_utils/download_examples.py index aaa8d791..16491f36 100644 --- a/godot_rl/download_utils/download_examples.py +++ b/godot_rl/download_utils/download_examples.py @@ -3,35 +3,35 @@ import os import shutil from sys import platform -import wget from zipfile import ZipFile -BANCHES = {"4" : "main", - "3" : "godot3.5"} +import wget + +BANCHES = {"4": "main", "3": "godot3.5"} + +BASE_URL = "https://github.com/edbeeching/godot_rl_agents_examples" -BASE_URL="https://github.com/edbeeching/godot_rl_agents_examples" def download_examples(): - #select branch + # select branch print("Select Godot version:") for key in BANCHES.keys(): print(f"{key} : {BANCHES[key]}") - + branch = input("Enter your choice: ") - BRANCH = BANCHES[branch] + BRANCH = BANCHES[branch] os.makedirs("examples", exist_ok=True) - URL=f"{BASE_URL}/archive/refs/heads/{BRANCH}.zip" + URL = f"{BASE_URL}/archive/refs/heads/{BRANCH}.zip" print(f"downloading examples from {URL}") wget.download(URL, out="") print() print(f"unzipping") - with ZipFile(f"{BRANCH}.zip", 'r') as zipObj: - # Extract all the contents of zip file in different directory - zipObj.extractall('examples/') + with ZipFile(f"{BRANCH}.zip", "r") as zipObj: + # Extract all the contents of zip file in different directory + zipObj.extractall("examples/") print(f"cleaning up") os.remove(f"{BRANCH}.zip") print(f"moving files") for file in os.listdir(f"examples/godot_rl_agents_examples-{BRANCH}"): shutil.move(f"examples/godot_rl_agents_examples-{BRANCH}/{file}", "examples") os.rmdir(f"examples/godot_rl_agents_examples-{BRANCH}") - \ No newline at end of file diff --git a/godot_rl/download_utils/download_godot_editor.py b/godot_rl/download_utils/download_godot_editor.py index 170e6c8e..e6310c34 100644 --- a/godot_rl/download_utils/download_godot_editor.py +++ b/godot_rl/download_utils/download_godot_editor.py @@ -1,17 +1,16 @@ import os import shutil from sys import platform -import wget from zipfile import ZipFile -BASE_URL="https://downloads.tuxfamily.org/godotengine/" -VERSIONS = { - "3": "3.5.1", - "4": "4.0" -} +import wget + +BASE_URL = "https://downloads.tuxfamily.org/godotengine/" +VERSIONS = {"3": "3.5.1", "4": "4.0"} MOST_RECENT_VERSION = "rc5" + def get_version(): while True: version = input("Which Godot version do you want to download (3 or 4)? ") @@ -19,6 +18,7 @@ def get_version(): return version print("Invalid version. Please enter 3 or 4.") + def download_editor(): version = get_version() VERSION = VERSIONS[version] @@ -27,17 +27,17 @@ def download_editor(): if VERSION == "4.0": NEW_BASE_URL = f"{BASE_URL}{VERSION}/{MOST_RECENT_VERSION}/" NAME = MOST_RECENT_VERSION - LINUX_FILENAME=f"Godot_v{VERSION}-{NAME}_linux.x86_64.zip" + LINUX_FILENAME = f"Godot_v{VERSION}-{NAME}_linux.x86_64.zip" if VERSION == "4.0": - MAC_FILENAME=f"Godot_v{VERSION}-{NAME}_macos.universal.zip" + MAC_FILENAME = f"Godot_v{VERSION}-{NAME}_macos.universal.zip" else: - MAC_FILENAME=f"Godot_v{VERSION}-{NAME}_osx.universal.64.zip" - WINDOWS_FILENAME=f"Godot_v{VERSION}-{NAME}_win64.exe.zip" + MAC_FILENAME = f"Godot_v{VERSION}-{NAME}_osx.universal.64.zip" + WINDOWS_FILENAME = f"Godot_v{VERSION}-{NAME}_win64.exe.zip" os.makedirs("editor", exist_ok=True) - FILENAME="" + FILENAME = "" if platform == "linux" or platform == "linux2": - FILENAME = LINUX_FILENAME + FILENAME = LINUX_FILENAME elif platform == "darwin": FILENAME = MAC_FILENAME elif platform == "win32" or platform == "win64": @@ -45,14 +45,14 @@ def download_editor(): else: raise NotImplementedError - URL=f"{NEW_BASE_URL}{FILENAME}" + URL = f"{NEW_BASE_URL}{FILENAME}" print(f"downloading editor {FILENAME} for platform: {platform}") wget.download(URL, out="") print() print(f"unzipping") - with ZipFile(FILENAME, 'r') as zipObj: - # Extract all the contents of zip file in different directory - zipObj.extractall('editor/') + with ZipFile(FILENAME, "r") as zipObj: + # Extract all the contents of zip file in different directory + zipObj.extractall("editor/") print(f"cleaning up") - os.remove(FILENAME) \ No newline at end of file + os.remove(FILENAME) diff --git a/godot_rl/download_utils/from_hub.py b/godot_rl/download_utils/from_hub.py index b60a773b..51f49037 100644 --- a/godot_rl/download_utils/from_hub.py +++ b/godot_rl/download_utils/from_hub.py @@ -1,5 +1,6 @@ import argparse import os + from huggingface_hub import Repository diff --git a/godot_rl/main.py b/godot_rl/main.py index f856e673..56f94c05 100644 --- a/godot_rl/main.py +++ b/godot_rl/main.py @@ -25,23 +25,29 @@ try: from godot_rl.wrappers.ray_wrapper import rllib_training except ImportError as e: + def rllib_training(args, extras): - print("Import error when trying to use rllib. If you have not installed the package, try: pip install godot-rl[rllib]") + print( + "Import error when trying to use rllib. If you have not installed the package, try: pip install godot-rl[rllib]" + ) print("Otherwise try fixing the error above.") try: from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training except ImportError as e: + def stable_baselines_training(args, extras): print( "Import error when trying to use sb3. If you have not installed the package, try: pip install godot-rl[sb3]" ) print("Otherwise try fixing the error above.") + try: - from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training, sample_factory_enjoy + from godot_rl.wrappers.sample_factory_wrapper import sample_factory_enjoy, sample_factory_training except ImportError as e: + def sample_factory_training(args, extras): print( "Import error when trying to use sample-factory If you have not installed the package, try: pip install godot-rl[sf]" @@ -51,26 +57,41 @@ def sample_factory_training(args, extras): def get_args(): parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument("--trainer", default="sb3", choices=["sb3", "sf", "rllib"], type=str, help="framework to use (rllib, sf, sb3)") + parser.add_argument( + "--trainer", default="sb3", choices=["sb3", "sf", "rllib"], type=str, help="framework to use (rllib, sf, sb3)" + ) parser.add_argument("--env_path", default=None, type=str, help="Godot binary to use") - parser.add_argument("--config_file", default="ppo_test.yaml", type=str, help="The yaml config file [only for rllib]") + parser.add_argument( + "--config_file", default="ppo_test.yaml", type=str, help="The yaml config file [only for rllib]" + ) parser.add_argument("--restore", default=None, type=str, help="the location of a checkpoint to restore from") parser.add_argument("--eval", default=False, action="store_true", help="whether to eval the model") parser.add_argument("--speedup", default=1, type=int, help="whether to speed up the physics in the env") parser.add_argument("--export", default=False, action="store_true", help="wheter to export the model") parser.add_argument("--num_gpus", default=None, type=int, help="Number of GPUs to use [only for rllib]") - parser.add_argument("--experiment_dir", default=None, type=str, help="The name of the the experiment directory, in which the tensorboard logs are getting stored") - parser.add_argument("--experiment_name", default="experiment", type=str, help="The name of the the experiment, which will be displayed in tensborboard") + parser.add_argument( + "--experiment_dir", + default=None, + type=str, + help="The name of the the experiment directory, in which the tensorboard logs are getting stored", + ) + parser.add_argument( + "--experiment_name", + default="experiment", + type=str, + help="The name of the the experiment, which will be displayed in tensborboard", + ) parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process") parser.add_argument("--seed", default=0, type=int, help="seed of the experiment") - - args, extras = parser.parse_known_args() + + args, extras = parser.parse_known_args() if args.experiment_dir is None: args.experiment_dir = f"logs/{args.trainer}" - + if args.trainer == "sf" and args.env_path is None: - print("WARNING: the sample-factory intergration is not designed to run in interactive mode, please export you game to use this trainer") - + print( + "WARNING: the sample-factory intergration is not designed to run in interactive mode, please export you game to use this trainer" + ) return args, extras diff --git a/godot_rl/wrappers/clean_rl_wrapper.py b/godot_rl/wrappers/clean_rl_wrapper.py index 0c13fdfe..60ad2be7 100644 --- a/godot_rl/wrappers/clean_rl_wrapper.py +++ b/godot_rl/wrappers/clean_rl_wrapper.py @@ -1,15 +1,15 @@ -import numpy as np +from typing import Any, Dict, List, Optional, Tuple, Union + import gymnasium as gym +import numpy as np from numpy import ndarray -from godot_rl.core.utils import lod_to_dol from godot_rl.core.godot_env import GodotEnv -from typing import Any, Dict, List, Optional, Tuple, Union +from godot_rl.core.utils import lod_to_dol class CleanRLGodotEnv: def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, seed: int = 0, **kwargs: object) -> None: - # If we are doing editor training, n_parallel must be 1 if env_path is None and n_parallel > 1: raise ValueError("You must provide the path to a exported game executable if n_parallel > 1") @@ -18,8 +18,10 @@ def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, seed: in port = kwargs.pop("port", GodotEnv.DEFAULT_PORT) # Create a list of GodotEnv instances - self.envs = [GodotEnv(env_path=env_path, convert_action_space=True, port=port + p, seed=seed + p, **kwargs) for - p in range(n_parallel)] + self.envs = [ + GodotEnv(env_path=env_path, convert_action_space=True, port=port + p, seed=seed + p, **kwargs) + for p in range(n_parallel) + ] # Store the number of parallel environments self.n_parallel = n_parallel @@ -29,7 +31,7 @@ def _check_valid_action_space(self) -> None: action_space = self.envs[0].action_space if isinstance(action_space, gym.spaces.Tuple): assert ( - len(action_space.spaces) == 1 + len(action_space.spaces) == 1 ), f"sb3 supports a single action space, this env contains multiple spaces {action_space}" def step(self, action: np.ndarray) -> tuple[ndarray, list[Any], list[Any], list[Any], list[Any]]: @@ -45,7 +47,7 @@ def step(self, action: np.ndarray) -> tuple[ndarray, list[Any], list[Any], list[ # Send actions to each environment for i in range(self.n_parallel): - self.envs[i].step_send(action[i * num_envs:(i + 1) * num_envs]) + self.envs[i].step_send(action[i * num_envs : (i + 1) * num_envs]) # Receive results from each environment for i in range(self.n_parallel): diff --git a/godot_rl/wrappers/onnx/stable_baselines_export.py b/godot_rl/wrappers/onnx/stable_baselines_export.py index f39d0b32..19f679f9 100644 --- a/godot_rl/wrappers/onnx/stable_baselines_export.py +++ b/godot_rl/wrappers/onnx/stable_baselines_export.py @@ -41,11 +41,12 @@ def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str): opset_version=9, input_names=["obs", "state_ins"], output_names=["output", "state_outs"], - dynamic_axes={'obs' : {0 : 'batch_size'}, - 'state_ins' : {0 : 'batch_size'}, # variable length axes - 'output' : {0 : 'batch_size'}, - 'state_outs' : {0 : 'batch_size'}} - + dynamic_axes={ + "obs": {0: "batch_size"}, + "state_ins": {0: "batch_size"}, # variable length axes + "output": {0: "batch_size"}, + "state_outs": {0: "batch_size"}, + }, ) verify_onnx_export(ppo, onnx_model_path) @@ -59,7 +60,7 @@ def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10): onnx.checker.check_model(onnx_model) sb3_model = ppo.policy.to("cpu") - ort_sess = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider']) + ort_sess = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) for i in range(num_tests): obs = dict(ppo.observation_space.sample()) diff --git a/godot_rl/wrappers/ray_wrapper.py b/godot_rl/wrappers/ray_wrapper.py index e163fdd4..38219ac7 100644 --- a/godot_rl/wrappers/ray_wrapper.py +++ b/godot_rl/wrappers/ray_wrapper.py @@ -25,7 +25,6 @@ def __init__( timeout_wait=60, config=None, ) -> None: - self._env = GodotEnv( env_path=env_path, port=port, @@ -33,7 +32,7 @@ def __init__( show_window=show_window, framerate=framerate, action_repeat=action_repeat, - speedup=speedup + speedup=speedup, ) super().__init__( observation_space=self._env.observation_space, @@ -41,9 +40,11 @@ def __init__( num_envs=self._env.num_envs, ) - def vector_reset(self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None) -> List[EnvObsType]: + def vector_reset( + self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None + ) -> List[EnvObsType]: self.obs, info = self._env.reset() - return self.obs, info + return self.obs, info def vector_step( self, actions: List[EnvActionType] @@ -55,12 +56,13 @@ def vector_step( def get_unwrapped(self): return [self._env] - def reset_at(self, - index: Optional[int] = None, - *, - seed: Optional[int] = None, - options: Optional[dict] = None, - ) -> EnvObsType: + def reset_at( + self, + index: Optional[int] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> EnvObsType: # the env is reset automatically, no need to reset it return self.obs[index], {} @@ -82,43 +84,45 @@ def register_env(): def rllib_export(model_path): - #get path from the config file and remove the file name - path = model_path #full path with file name - path = path.split("/") #split the path into a list - path = path[:-1] #remove the file name from the list - #duplicate the path for the export - export_path = path.copy() - export_path.append("onnx") - export_path = "/".join(export_path) #join the list into a string - #duplicate the last element of the list - path.append(path[-1]) - #change format from checkpoint_000500 to checkpoint-500 - temp = path[-1].split("_") - temp = temp[-1] - #parse the number - temp = int(temp) - #back to string - temp = str(temp) - #join the string with the new format - path[-1] = "checkpoint-" + temp - path = "/".join(path) #join the list into a string - #best_checkpoint = results.get_best_checkpoint(results.trials[0], mode="max") - #print(f".. best checkpoint was: {best_checkpoint}") - - #From here on, the relevant part to exporting the model - new_trainer = PPOTrainer(config=exp["config"]) - new_trainer.restore(path) - #policy = new_trainer.get_policy() - new_trainer.export_policy_model(export_dir=export_path, onnx = 9) #This works for version 1.11.X - #Running with: gdrl --env_path envs/builds/JumperHard/jumper_hard.exe --export --restore envs/checkpoints/jumper_hard/checkpoint_000500/checkpoint-500 - #model = policy.model - #export the model to onnx using torch.onnx.export - #dummy_input = torch.randn(1, 3, 84, 84) - #input is dictionary with key "obs" and value is a tensor of shape [...,8] - #tensor = torch.randn([1, 2, 4, 6, 8, 10, 12, 14]) - #dummy_input = {"obs": tensor} - #torch.onnx.export(model, dummy_input, "model.onnx", verbose=True, - #dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}) + # get path from the config file and remove the file name + path = model_path # full path with file name + path = path.split("/") # split the path into a list + path = path[:-1] # remove the file name from the list + # duplicate the path for the export + export_path = path.copy() + export_path.append("onnx") + export_path = "/".join(export_path) # join the list into a string + # duplicate the last element of the list + path.append(path[-1]) + # change format from checkpoint_000500 to checkpoint-500 + temp = path[-1].split("_") + temp = temp[-1] + # parse the number + temp = int(temp) + # back to string + temp = str(temp) + # join the string with the new format + path[-1] = "checkpoint-" + temp + path = "/".join(path) # join the list into a string + # best_checkpoint = results.get_best_checkpoint(results.trials[0], mode="max") + # print(f".. best checkpoint was: {best_checkpoint}") + + # From here on, the relevant part to exporting the model + new_trainer = PPOTrainer(config=exp["config"]) + new_trainer.restore(path) + # policy = new_trainer.get_policy() + new_trainer.export_policy_model(export_dir=export_path, onnx=9) # This works for version 1.11.X + + +# Running with: gdrl --env_path envs/builds/JumperHard/jumper_hard.exe --export --restore envs/checkpoints/jumper_hard/checkpoint_000500/checkpoint-500 +# model = policy.model +# export the model to onnx using torch.onnx.export +# dummy_input = torch.randn(1, 3, 84, 84) +# input is dictionary with key "obs" and value is a tensor of shape [...,8] +# tensor = torch.randn([1, 2, 4, 6, 8, 10, 12, 14]) +# dummy_input = {"obs": tensor} +# torch.onnx.export(model, dummy_input, "model.onnx", verbose=True, +# dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}) def rllib_training(args, extras): @@ -144,10 +148,10 @@ def rllib_training(args, extras): checkpoint_freq = 10 checkpoint_at_end = True - + exp["config"]["env_config"]["show_window"] = args.viz exp["config"]["env_config"]["speedup"] = args.speedup - + if args.eval or args.export: checkpoint_freq = 0 exp["config"]["env_config"]["show_window"] = True @@ -167,20 +171,20 @@ def rllib_training(args, extras): if not args.export: results = tune.run( - exp["algorithm"], - name=run_name, - config=exp["config"], - stop=exp["stop"], - verbose=3, - checkpoint_freq=checkpoint_freq, - checkpoint_at_end=not args.eval, - restore=args.restore, - local_dir=os.path.abspath(args.experiment_dir) or os.path.abspath("logs/rllib"), - trial_name_creator=lambda trial: f"{args.experiment_name}" if args.experiment_name else f"{trial.trainable_name}_{trial.trial_id}" - ) + exp["algorithm"], + name=run_name, + config=exp["config"], + stop=exp["stop"], + verbose=3, + checkpoint_freq=checkpoint_freq, + checkpoint_at_end=not args.eval, + restore=args.restore, + local_dir=os.path.abspath(args.experiment_dir) or os.path.abspath("logs/rllib"), + trial_name_creator=lambda trial: f"{args.experiment_name}" + if args.experiment_name + else f"{trial.trainable_name}_{trial.trial_id}", + ) if args.export: rllib_export(args.restore) ray.shutdown() - - diff --git a/godot_rl/wrappers/sample_factory_wrapper.py b/godot_rl/wrappers/sample_factory_wrapper.py index c38f8b44..ea567e45 100644 --- a/godot_rl/wrappers/sample_factory_wrapper.py +++ b/godot_rl/wrappers/sample_factory_wrapper.py @@ -1,15 +1,17 @@ import argparse -from functools import partial import random +from functools import partial + import numpy as np +from gymnasium import Env from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args +from sample_factory.enjoy import enjoy from sample_factory.envs.env_utils import register_env from sample_factory.train import run_rl -from sample_factory.enjoy import enjoy from godot_rl.core.godot_env import GodotEnv from godot_rl.core.utils import lod_to_dol -from gymnasium import Env + class SampleFactoryEnvWrapperBatched(GodotEnv, Env): @property @@ -32,7 +34,6 @@ def step(self, action): @staticmethod def to_numpy(lod): - for d in lod: for k, v in d.items(): d[k] = np.array(v) @@ -51,6 +52,7 @@ def unwrapped(self): @property def num_agents(self): return self.num_envs + def reset(self, seed=None, options=None): obs, info = super().reset(seed=seed) return self.to_numpy(obs), info @@ -61,7 +63,6 @@ def step(self, action): @staticmethod def to_numpy(lod): - for d in lod: for k, v in d.items(): d[k] = np.array(v) @@ -72,7 +73,9 @@ def render(): return -def make_godot_env_func(env_path, full_env_name, cfg=None, env_config=None, render_mode=None, seed=0, speedup=1, viz=False): +def make_godot_env_func( + env_path, full_env_name, cfg=None, env_config=None, render_mode=None, seed=0, speedup=1, viz=False +): port = cfg.base_port print("BASE PORT ", cfg.base_port) show_window = False @@ -168,7 +171,7 @@ def parse_gdrl_args(args, argv=None, evaluation=False): add_gdrl_env_args(partial_cfg.env, parser, evaluation=evaluation) gdrl_override_defaults(partial_cfg.env, parser) final_cfg = parse_full_cfg(parser, argv) - + final_cfg.train_dir = args.experiment_dir or "logs/sf" final_cfg.experiment = args.experiment_name or final_cfg.experiment return final_cfg @@ -177,7 +180,7 @@ def parse_gdrl_args(args, argv=None, evaluation=False): def sample_factory_training(args, extras): register_gdrl_env(args) cfg = parse_gdrl_args(args=args, argv=extras, evaluation=args.eval) - #cfg.base_port = random.randint(20000, 22000) + # cfg.base_port = random.randint(20000, 22000) status = run_rl(cfg) return status diff --git a/godot_rl/wrappers/sbg_single_obs_wrapper.py b/godot_rl/wrappers/sbg_single_obs_wrapper.py index f4840cd1..560c01ef 100644 --- a/godot_rl/wrappers/sbg_single_obs_wrapper.py +++ b/godot_rl/wrappers/sbg_single_obs_wrapper.py @@ -2,12 +2,13 @@ import gymnasium as gym import numpy as np -from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv +from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv # A variant of the Stable Baselines Godot Env that only supports a single obs space from the dictionary - obs["obs"] by default. # This provides some basic support for using envs that have a single obs space with policies other than MultiInputPolicy. + class SBGSingleObsEnv(StableBaselinesGodotEnv): def __init__(self, obs_key="obs", *args, **kwargs) -> None: self.obs_key = obs_key diff --git a/godot_rl/wrappers/stable_baselines_wrapper.py b/godot_rl/wrappers/stable_baselines_wrapper.py index fb723e3d..bccf1434 100644 --- a/godot_rl/wrappers/stable_baselines_wrapper.py +++ b/godot_rl/wrappers/stable_baselines_wrapper.py @@ -1,9 +1,10 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + import gymnasium as gym import numpy as np from stable_baselines3 import PPO from stable_baselines3.common.vec_env.base_vec_env import VecEnv from stable_baselines3.common.vec_env.vec_monitor import VecMonitor -from typing import Any, Dict, List, Optional, Tuple, Union from godot_rl.core.godot_env import GodotEnv from godot_rl.core.utils import can_import, lod_to_dol @@ -14,13 +15,16 @@ def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, seed: in # If we are doing editor training, n_parallel must be 1 if env_path is None and n_parallel > 1: raise ValueError("You must provide the path to a exported game executable if n_parallel > 1") - + # Define the default port port = kwargs.pop("port", GodotEnv.DEFAULT_PORT) # Create a list of GodotEnv instances - self.envs = [GodotEnv(env_path=env_path, convert_action_space=True, port=port+p, seed=seed+p, **kwargs) for p in range(n_parallel)] - + self.envs = [ + GodotEnv(env_path=env_path, convert_action_space=True, port=port + p, seed=seed + p, **kwargs) + for p in range(n_parallel) + ] + # Store the number of parallel environments self.n_parallel = n_parallel @@ -51,7 +55,7 @@ def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], np.ndarray, n # Send actions to each environment for i in range(self.n_parallel): - self.envs[i].step_send(action[i*num_envs:(i+1)*num_envs]) + self.envs[i].step_send(action[i * num_envs : (i + 1) * num_envs]) # Receive results from each environment for i in range(self.n_parallel): @@ -109,12 +113,12 @@ def env_is_wrapped(self, wrapper_class: type, indices: Optional[List[int]] = Non def env_method(self): raise NotImplementedError() - def get_attr(self, attr_name: str, indices = None) -> List[Any]: + def get_attr(self, attr_name: str, indices=None) -> List[Any]: if attr_name == "render_mode": return [None for _ in range(self.num_envs)] raise AttributeError("get attr not fully implemented in godot-rl StableBaselinesWrapper") - def seed(self, seed = None): + def seed(self, seed=None): raise NotImplementedError() def set_attr(self): @@ -128,6 +132,7 @@ def step_wait(self) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List # Wait for the results from the asynchronous step return self.results + def stable_baselines_training(args, extras, n_steps: int = 200000, **kwargs) -> None: if can_import("ray"): print("WARNING, stable baselines and ray[rllib] are not compatable") From 374eafe0fe87348e656b863268e63c49cd57fd94 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 21:52:47 +0100 Subject: [PATCH 7/9] makefile and deps --- Makefile | 12 ++++++------ setup.cfg | 6 +++--- tests/benchmark_env.py | 2 -- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index d5cf222d..42e1f664 100644 --- a/Makefile +++ b/Makefile @@ -1,16 +1,16 @@ .PHONY: quality style test unity-test -check_dirs := src tests godot_rl +check_dirs := tests godot_rl # Format source code automatically style: - python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py - python -m isort $(check_dirs) setup.py + black --line-length 120 --target-version py310 tests godot_rl + isort -w 120 tests godot_rl # Check that source code meets quality standards quality: - python -m black --check --line-length 119 --target-version py310 $(check_dirs) setup.py - python -m isort --check-only $(check_dirs) setup.py - python -m flake8 --max-line-length 119 $(check_dirs) setup.py + black --check --line-length 120 --target-version py310 tests godot_rl + isort -w 120 --check-only tests godot_rl + flake8 --max-line-length 120 tests godot_rl # Run tests for the library test: diff --git a/setup.cfg b/setup.cfg index 319bb957..985fcd20 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,9 +38,9 @@ test = dev = pytest>=6.0 pytest-xdist - black[jupyter]~=22.0 - flake8>=3.8.3 - isort>=5.0.0 + black + flake8 + isort pyyaml>=5.3.1 sf = diff --git a/tests/benchmark_env.py b/tests/benchmark_env.py index 0043708a..d5cc4f6a 100644 --- a/tests/benchmark_env.py +++ b/tests/benchmark_env.py @@ -31,7 +31,6 @@ results = {} for framerate, port in zip(framerates, ports): - env = GodotEnv( env_path=env_path, port=port, @@ -44,7 +43,6 @@ action_space = env.action_space start = time.time() for i in range(N_STEPS): - actions = [action_space.sample() for _ in range(n_envs)] _ = env.step(actions) From 39f677a7b71c019bafb78f69701327997a6a235d Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 22:02:54 +0100 Subject: [PATCH 8/9] removes dead custom models --- godot_rl/custom_models/__init__.py | 0 godot_rl/custom_models/attention_model.py | 96 ----------------------- 2 files changed, 96 deletions(-) delete mode 100644 godot_rl/custom_models/__init__.py delete mode 100644 godot_rl/custom_models/attention_model.py diff --git a/godot_rl/custom_models/__init__.py b/godot_rl/custom_models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/godot_rl/custom_models/attention_model.py b/godot_rl/custom_models/attention_model.py deleted file mode 100644 index ea3c47bb..00000000 --- a/godot_rl/custom_models/attention_model.py +++ /dev/null @@ -1,96 +0,0 @@ -import logging - -import gym -import numpy as np -from gym.spaces import Box, Discrete, MultiDiscrete -from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet -from ray.rllib.models.torch.misc import AppendBiasLayer, SlimFC, normc_initializer -from ray.rllib.models.torch.modules import GRUGate, RelativeMultiHeadAttention, SkipConnection -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.view_requirement import ViewRequirement -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import Dict, List, ModelConfigDict, TensorType - -torch, nn = try_import_torch() -logger = logging.getLogger(__name__) - - -# defines the attention model used in the bullet hell environment -# first a feed forward to test that observations are being handled correctly - - -class MyAttentionModel(TorchModelV2, nn.Module): - """Generic fully connected network.""" - - def __init__( - self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: int, - model_config: ModelConfigDict, - name: str, - ): - TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name) - - nn.Module.__init__(self) - # simple baseline, fc all inputs and sum then value and policy head - - # if isinstance(action_space, Discrete): - # self.action_dim = action_space.n - # elif isinstance(action_space, MultiDiscrete): - # self.action_dim = np.product(action_space.nvec) - # elif action_space.shape is not None: - # self.action_dim = int(np.product(action_space.shape)) - # else: - # self.action_dim = int(len(action_space)) - # print("action space", action_space, self.action_dim, num_outputs) - prev_layer_size = 3 # int(np.product(obs_space.shape)) - # obs_space["obs"]["max_length"] = 1 - self.model = TorchFCNet(obs_space, action_space, num_outputs, model_config, name) - print(self.model) - - print(obs_space, prev_layer_size, self.num_outputs) - self._logits_branch = SlimFC( - in_size=prev_layer_size, - out_size=self.num_outputs, - activation_fn=None, - initializer=torch.nn.init.xavier_uniform_, - ) - self._value_branch = SlimFC( - in_size=prev_layer_size, - out_size=1, - activation_fn=None, - initializer=torch.nn.init.xavier_uniform_, - ) - # torch.set_printoptions(profile="full") - - @override(TorchModelV2) - def forward( - self, - input_dict: Dict[str, TensorType], - state: List[TensorType], - seq_lens: TensorType, - ) -> (TensorType, List[TensorType]): - observations = input_dict[SampleBatch.OBS] - # print("unbatch", input_dict["obs"]["obs"].unbatch_all()[0]) - # print(input_dict["obs"]) - self._debug_batch_size = len(input_dict["obs"]["obs"].unbatch_all()) - if not input_dict["obs"]["obs"].unbatch_all()[0]: - return ( - np.zeros((self._debug_batch_size, 4)), - [], - ) - - results = [] - for obs in input_dict["obs"]["obs"].unbatch_all(): - batch = torch.cat(obs) - out = self.model({"obs": batch}) - print(out.size()) - - return np.zeros((self._debug_batch_size, 4)), state - - @override(TorchModelV2) - def value_function(self) -> TensorType: - return torch.zeros(self._debug_batch_size) From ed8407f1323f9eabff7beef8a2ead5c923d1d6a3 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 23 Jan 2024 22:23:38 +0100 Subject: [PATCH 9/9] more code quality fixes --- Makefile | 12 ++- examples/clean_rl_example.py | 4 +- examples/sample_factory_example.py | 3 +- examples/stable_baselines3_example.py | 10 ++- examples/stable_baselines3_hp_tuning.py | 20 ++--- godot_rl/core/godot_env.py | 12 +-- godot_rl/core/utils.py | 2 +- godot_rl/download_utils/download_examples.py | 15 ++-- .../download_utils/download_godot_editor.py | 5 +- godot_rl/download_utils/from_hub.py | 2 +- godot_rl/main.py | 27 +++---- godot_rl/wrappers/clean_rl_wrapper.py | 2 +- godot_rl/wrappers/ray_wrapper.py | 76 ++++++++++--------- godot_rl/wrappers/sample_factory_wrapper.py | 1 - godot_rl/wrappers/sbg_single_obs_wrapper.py | 7 +- godot_rl/wrappers/stable_baselines_wrapper.py | 2 +- setup.cfg | 2 + tests/test_action_space_preprocessor.py | 3 +- tests/test_call_method.py | 2 - tests/test_godot_env.py | 2 - 20 files changed, 102 insertions(+), 107 deletions(-) diff --git a/Makefile b/Makefile index 42e1f664..b939b2aa 100644 --- a/Makefile +++ b/Makefile @@ -1,16 +1,14 @@ .PHONY: quality style test unity-test -check_dirs := tests godot_rl - # Format source code automatically style: - black --line-length 120 --target-version py310 tests godot_rl - isort -w 120 tests godot_rl + black --line-length 120 --target-version py310 tests godot_rl examples + isort -w 120 tests godot_rl examples # Check that source code meets quality standards quality: - black --check --line-length 120 --target-version py310 tests godot_rl - isort -w 120 --check-only tests godot_rl - flake8 --max-line-length 120 tests godot_rl + black --check --line-length 120 --target-version py310 tests godot_rl examples + isort -w 120 --check-only tests godot_rl examples + flake8 --max-line-length 120 tests godot_rl examples # Run tests for the library test: diff --git a/examples/clean_rl_example.py b/examples/clean_rl_example.py index 453f99ac..b95c2ed0 100644 --- a/examples/clean_rl_example.py +++ b/examples/clean_rl_example.py @@ -4,14 +4,16 @@ import pathlib import random import time -from distutils.util import strtobool from collections import deque +from distutils.util import strtobool + import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.distributions.normal import Normal from torch.utils.tensorboard import SummaryWriter + from godot_rl.wrappers.clean_rl_wrapper import CleanRLGodotEnv diff --git a/examples/sample_factory_example.py b/examples/sample_factory_example.py index 9ce243cf..bd86aaaf 100644 --- a/examples/sample_factory_example.py +++ b/examples/sample_factory_example.py @@ -1,5 +1,6 @@ import argparse -from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training, sample_factory_enjoy + +from godot_rl.wrappers.sample_factory_wrapper import sample_factory_enjoy, sample_factory_training def get_args(): diff --git a/examples/stable_baselines3_example.py b/examples/stable_baselines3_example.py index 3d4e90d6..b19b0a9d 100644 --- a/examples/stable_baselines3_example.py +++ b/examples/stable_baselines3_example.py @@ -3,12 +3,13 @@ import pathlib from typing import Callable +from stable_baselines3 import PPO from stable_baselines3.common.callbacks import CheckpointCallback +from stable_baselines3.common.vec_env.vec_monitor import VecMonitor + from godot_rl.core.utils import can_import -from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx -from stable_baselines3 import PPO -from stable_baselines3.common.vec_env.vec_monitor import VecMonitor +from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv # To download the env source and binary: # 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase @@ -214,7 +215,8 @@ def func(progress_remaining: float) -> float: model.learn(**learn_arguments) except KeyboardInterrupt: print( - "Training interrupted by user. Will save if --save_model_path was used and/or export if --onnx_export_path was used." + """Training interrupted by user. Will save if --save_model_path was + used and/or export if --onnx_export_path was used.""" ) close_env() diff --git a/examples/stable_baselines3_hp_tuning.py b/examples/stable_baselines3_hp_tuning.py index 0b1abd4c..5e8754bb 100644 --- a/examples/stable_baselines3_hp_tuning.py +++ b/examples/stable_baselines3_hp_tuning.py @@ -8,8 +8,8 @@ You can run this example as follows: $ python examples/stable_baselines3_hp_tuning.py --env_path= --speedup=8 --n_parallel=1 - -Feel free to copy this script and update, add or remove the hp values to your liking. + +Feel free to copy this script and update, add or remove the hp values to your liking. """ try: @@ -17,25 +17,21 @@ from optuna.pruners import MedianPruner from optuna.samplers import TPESampler except ImportError as e: + print(e) print("You need to install optuna to use the hyperparameter tuning script. Try: pip install optuna") exit() -from typing import Any -from typing import Dict +import argparse +from typing import Any, Dict import gymnasium as gym - -from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv -from godot_rl.core.godot_env import GodotEnv - +import torch from stable_baselines3 import PPO from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.vec_env.vec_monitor import VecMonitor -import torch -import torch.nn as nn - -import argparse +from godot_rl.core.godot_env import GodotEnv +from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument( diff --git a/godot_rl/core/godot_env.py b/godot_rl/core/godot_env.py index 7b6be772..9f0d1c83 100644 --- a/godot_rl/core/godot_env.py +++ b/godot_rl/core/godot_env.py @@ -99,17 +99,17 @@ def check_platform(self, filename: str): # Linux assert ( pathlib.Path(filename).suffix == ".x86_64" - ), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .x86_64 file" + ), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .x86_64 file" elif platform == "darwin": # OSX assert ( pathlib.Path(filename).suffix == ".app" - ), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .app file" + ), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .app file" elif platform == "win32": # Windows... assert ( pathlib.Path(filename).suffix == ".exe" - ), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .exe file" + ), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .exe file" else: assert 0, f"unknown filetype {pathlib.Path(filename).suffix}" @@ -132,7 +132,7 @@ def from_numpy(self, action, order_ij=False): env_action = {} for j, k in enumerate(self._action_space.keys()): - if order_ij == True: + if order_ij is True: v = action[i][j] else: v = action[j][i] @@ -263,7 +263,7 @@ def _launch_env(self, env_path, port, show_window, framerate, seed, action_repea launch_cmd = f"{path} --port={port} --env_seed={seed}" - if show_window == False: + if show_window is False: launch_cmd += " --disable-render-loop --headless" if framerate is not None: launch_cmd += f" --fixed-fps {framerate}" @@ -382,7 +382,7 @@ def _clear_socket(self): data = self.connection.recv(4) if not data: break - except BlockingIOError as e: + except BlockingIOError: pass self.connection.setblocking(True) diff --git a/godot_rl/core/utils.py b/godot_rl/core/utils.py index e313ef5a..ed8d5a1f 100644 --- a/godot_rl/core/utils.py +++ b/godot_rl/core/utils.py @@ -24,7 +24,7 @@ def convert_macos_path(env_path): """ filenames = re.findall(r"[^\/]+(?=\.)", env_path) - assert len(filenames) == 1, f"An error occured while converting the env path for MacOS." + assert len(filenames) == 1, "An error occured while converting the env path for MacOS." return env_path + "/Contents/MacOS/" + filenames[0] diff --git a/godot_rl/download_utils/download_examples.py b/godot_rl/download_utils/download_examples.py index 16491f36..09cd5c94 100644 --- a/godot_rl/download_utils/download_examples.py +++ b/godot_rl/download_utils/download_examples.py @@ -2,12 +2,11 @@ import os import shutil -from sys import platform from zipfile import ZipFile import wget -BANCHES = {"4": "main", "3": "godot3.5"} +BRANCHES = {"4": "main", "3": "godot3.5"} BASE_URL = "https://github.com/edbeeching/godot_rl_agents_examples" @@ -15,23 +14,23 @@ def download_examples(): # select branch print("Select Godot version:") - for key in BANCHES.keys(): - print(f"{key} : {BANCHES[key]}") + for key in BRANCHES.keys(): + print(f"{key} : {BRANCHES[key]}") branch = input("Enter your choice: ") - BRANCH = BANCHES[branch] + BRANCH = BRANCHES[branch] os.makedirs("examples", exist_ok=True) URL = f"{BASE_URL}/archive/refs/heads/{BRANCH}.zip" print(f"downloading examples from {URL}") wget.download(URL, out="") print() - print(f"unzipping") + print("unzipping") with ZipFile(f"{BRANCH}.zip", "r") as zipObj: # Extract all the contents of zip file in different directory zipObj.extractall("examples/") - print(f"cleaning up") + print("cleaning up") os.remove(f"{BRANCH}.zip") - print(f"moving files") + print("moving files") for file in os.listdir(f"examples/godot_rl_agents_examples-{BRANCH}"): shutil.move(f"examples/godot_rl_agents_examples-{BRANCH}/{file}", "examples") os.rmdir(f"examples/godot_rl_agents_examples-{BRANCH}") diff --git a/godot_rl/download_utils/download_godot_editor.py b/godot_rl/download_utils/download_godot_editor.py index e6310c34..c0bd0ae9 100644 --- a/godot_rl/download_utils/download_godot_editor.py +++ b/godot_rl/download_utils/download_godot_editor.py @@ -1,5 +1,4 @@ import os -import shutil from sys import platform from zipfile import ZipFile @@ -50,9 +49,9 @@ def download_editor(): print(f"downloading editor {FILENAME} for platform: {platform}") wget.download(URL, out="") print() - print(f"unzipping") + print("unzipping") with ZipFile(FILENAME, "r") as zipObj: # Extract all the contents of zip file in different directory zipObj.extractall("editor/") - print(f"cleaning up") + print("cleaning up") os.remove(FILENAME) diff --git a/godot_rl/download_utils/from_hub.py b/godot_rl/download_utils/from_hub.py index 51f49037..c0bb36d3 100644 --- a/godot_rl/download_utils/from_hub.py +++ b/godot_rl/download_utils/from_hub.py @@ -18,7 +18,7 @@ def main(): parser.add_argument( "-r", "--hf_repository", - help="Repo id of the dataset / environment repository from the Hugging Face Hub in the form user_name/repo_name", + help="Repo id of the dataset / environment repo from the Hugging Face Hub in the form user_name/repo_name", type=str, ) parser.add_argument( diff --git a/godot_rl/main.py b/godot_rl/main.py index 56f94c05..3de3027d 100644 --- a/godot_rl/main.py +++ b/godot_rl/main.py @@ -1,14 +1,14 @@ """ This is the main entrypoint to the Godot RL Agents interface -Example usage is best found in the documentation: +Example usage is best found in the documentation: https://github.com/edbeeching/godot_rl_agents/blob/main/docs/EXAMPLE_ENVIRONMENTS.md Hyperparameters and training algorithm can be defined in a .yaml file, see ppo_test.yaml as an example. Interactive Training: -With the Godot editor open, type gdrl in the terminal to launch training and +With the Godot editor open, type gdrl in the terminal to launch training and then press PLAY in the Godot editor. Training can be stopped with CTRL+C or by pressing STOP in the editor. @@ -25,34 +25,33 @@ try: from godot_rl.wrappers.ray_wrapper import rllib_training except ImportError as e: + error_message = str(e) def rllib_training(args, extras): - print( - "Import error when trying to use rllib. If you have not installed the package, try: pip install godot-rl[rllib]" - ) - print("Otherwise try fixing the error above.") + print("Import error importing rllib. If you have not installed the package, try: pip install godot-rl[rllib]") + print("Otherwise try fixing the error.", error_message) try: from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training except ImportError as e: + error_message = str(e) def stable_baselines_training(args, extras): - print( - "Import error when trying to use sb3. If you have not installed the package, try: pip install godot-rl[sb3]" - ) - print("Otherwise try fixing the error above.") + print("Import error importing sb3. If you have not installed the package, try: pip install godot-rl[sb3]") + print("Otherwise try fixing the error.", error_message) try: from godot_rl.wrappers.sample_factory_wrapper import sample_factory_enjoy, sample_factory_training except ImportError as e: + error_message = str(e) def sample_factory_training(args, extras): print( - "Import error when trying to use sample-factory If you have not installed the package, try: pip install godot-rl[sf]" + "Import error importing sample-factory If you have not installed the package, try: pip install godot-rl[sf]" ) - print("Otherwise try fixing the error above.") + print("Otherwise try fixing the error.", error_message) def get_args(): @@ -89,9 +88,7 @@ def get_args(): args.experiment_dir = f"logs/{args.trainer}" if args.trainer == "sf" and args.env_path is None: - print( - "WARNING: the sample-factory intergration is not designed to run in interactive mode, please export you game to use this trainer" - ) + print("WARNING: the sample-factory intergration is not designed to run in interactive mode, export you game") return args, extras diff --git a/godot_rl/wrappers/clean_rl_wrapper.py b/godot_rl/wrappers/clean_rl_wrapper.py index 60ad2be7..0059dfc2 100644 --- a/godot_rl/wrappers/clean_rl_wrapper.py +++ b/godot_rl/wrappers/clean_rl_wrapper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional import gymnasium as gym import numpy as np diff --git a/godot_rl/wrappers/ray_wrapper.py b/godot_rl/wrappers/ray_wrapper.py index 38219ac7..d1fc68a1 100644 --- a/godot_rl/wrappers/ray_wrapper.py +++ b/godot_rl/wrappers/ray_wrapper.py @@ -1,6 +1,6 @@ import os import pathlib -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import numpy as np import ray @@ -83,38 +83,40 @@ def register_env(): ) -def rllib_export(model_path): - # get path from the config file and remove the file name - path = model_path # full path with file name - path = path.split("/") # split the path into a list - path = path[:-1] # remove the file name from the list - # duplicate the path for the export - export_path = path.copy() - export_path.append("onnx") - export_path = "/".join(export_path) # join the list into a string - # duplicate the last element of the list - path.append(path[-1]) - # change format from checkpoint_000500 to checkpoint-500 - temp = path[-1].split("_") - temp = temp[-1] - # parse the number - temp = int(temp) - # back to string - temp = str(temp) - # join the string with the new format - path[-1] = "checkpoint-" + temp - path = "/".join(path) # join the list into a string - # best_checkpoint = results.get_best_checkpoint(results.trials[0], mode="max") - # print(f".. best checkpoint was: {best_checkpoint}") - - # From here on, the relevant part to exporting the model - new_trainer = PPOTrainer(config=exp["config"]) - new_trainer.restore(path) - # policy = new_trainer.get_policy() - new_trainer.export_policy_model(export_dir=export_path, onnx=9) # This works for version 1.11.X - - -# Running with: gdrl --env_path envs/builds/JumperHard/jumper_hard.exe --export --restore envs/checkpoints/jumper_hard/checkpoint_000500/checkpoint-500 +# TODO: fix this implementation +# def rllib_export(model_path): +# # get path from the config file and remove the file name +# path = model_path # full path with file name +# path = path.split("/") # split the path into a list +# path = path[:-1] # remove the file name from the list +# # duplicate the path for the export +# export_path = path.copy() +# export_path.append("onnx") +# export_path = "/".join(export_path) # join the list into a string +# # duplicate the last element of the list +# path.append(path[-1]) +# # change format from checkpoint_000500 to checkpoint-500 +# temp = path[-1].split("_") +# temp = temp[-1] +# # parse the number +# temp = int(temp) +# # back to string +# temp = str(temp) +# # join the string with the new format +# path[-1] = "checkpoint-" + temp +# path = "/".join(path) # join the list into a string +# # best_checkpoint = results.get_best_checkpoint(results.trials[0], mode="max") +# # print(f".. best checkpoint was: {best_checkpoint}") + +# # From here on, the relevant part to exporting the model +# new_trainer = PPOTrainer(config=exp["config"]) +# new_trainer.restore(path) +# # policy = new_trainer.get_policy() +# new_trainer.export_policy_model(export_dir=export_path, onnx=9) # This works for version 1.11.X + + +# Running with: gdrl --env_path envs/builds/JumperHard/jumper_hard.exe --export \ +# --restore envs/checkpoints/jumper_hard/checkpoint_000500/checkpoint-500 # model = policy.model # export the model to onnx using torch.onnx.export # dummy_input = torch.randn(1, 3, 84, 84) @@ -139,7 +141,7 @@ def rllib_training(args, extras): run_name = exp["algorithm"] + "/editor" print("run_name", run_name) - if args.num_gpus != None: + if args.num_gpus is not None: exp["config"]["num_gpus"] = args.num_gpus if args.env_path is None: @@ -147,7 +149,6 @@ def rllib_training(args, extras): exp["config"]["num_workers"] = 1 checkpoint_freq = 10 - checkpoint_at_end = True exp["config"]["env_config"]["show_window"] = args.viz exp["config"]["env_config"]["speedup"] = args.speedup @@ -170,7 +171,7 @@ def rllib_training(args, extras): ray.init(num_gpus=exp["config"]["num_gpus"] or 1) if not args.export: - results = tune.run( + tune.run( exp["algorithm"], name=run_name, config=exp["config"], @@ -185,6 +186,7 @@ def rllib_training(args, extras): else f"{trial.trainable_name}_{trial.trial_id}", ) if args.export: - rllib_export(args.restore) + raise NotImplementedError("Exporting is not (re)implemented yet") + # rllib_export(args.restore) ray.shutdown() diff --git a/godot_rl/wrappers/sample_factory_wrapper.py b/godot_rl/wrappers/sample_factory_wrapper.py index ea567e45..2c01a984 100644 --- a/godot_rl/wrappers/sample_factory_wrapper.py +++ b/godot_rl/wrappers/sample_factory_wrapper.py @@ -1,5 +1,4 @@ import argparse -import random from functools import partial import numpy as np diff --git a/godot_rl/wrappers/sbg_single_obs_wrapper.py b/godot_rl/wrappers/sbg_single_obs_wrapper.py index 560c01ef..d4136430 100644 --- a/godot_rl/wrappers/sbg_single_obs_wrapper.py +++ b/godot_rl/wrappers/sbg_single_obs_wrapper.py @@ -5,8 +5,11 @@ from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv -# A variant of the Stable Baselines Godot Env that only supports a single obs space from the dictionary - obs["obs"] by default. -# This provides some basic support for using envs that have a single obs space with policies other than MultiInputPolicy. +# A variant of the Stable Baselines Godot Env that only supports a single +# obs space from the dictionary - obs["obs"] by default. + +# This provides some basic support for using envs that have a single obs +# space with policies other than MultiInputPolicy. class SBGSingleObsEnv(StableBaselinesGodotEnv): diff --git a/godot_rl/wrappers/stable_baselines_wrapper.py b/godot_rl/wrappers/stable_baselines_wrapper.py index bccf1434..8d7493cf 100644 --- a/godot_rl/wrappers/stable_baselines_wrapper.py +++ b/godot_rl/wrappers/stable_baselines_wrapper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import gymnasium as gym import numpy as np diff --git a/setup.cfg b/setup.cfg index 985fcd20..b0d7eb2a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,3 +52,5 @@ rllib = cleanrl = wandb +[flake8] +ignore = E203, E501, E741, W503, W605 diff --git a/tests/test_action_space_preprocessor.py b/tests/test_action_space_preprocessor.py index e4f0acbf..a9650e2a 100644 --- a/tests/test_action_space_preprocessor.py +++ b/tests/test_action_space_preprocessor.py @@ -1,7 +1,6 @@ import pytest -from gymnasium.spaces import Box, Dict, Discrete, Tuple +from gymnasium.spaces import Box, Discrete, Tuple -from godot_rl.core.godot_env import GodotEnv from godot_rl.core.utils import ActionSpaceProcessor diff --git a/tests/test_call_method.py b/tests/test_call_method.py index 5a1d9fbb..cd0e78de 100644 --- a/tests/test_call_method.py +++ b/tests/test_call_method.py @@ -1,5 +1,3 @@ -import time - from godot_rl.core.godot_env import GodotEnv if __name__ == "__main__": diff --git a/tests/test_godot_env.py b/tests/test_godot_env.py index 9f5532a7..7d7e9361 100644 --- a/tests/test_godot_env.py +++ b/tests/test_godot_env.py @@ -18,7 +18,6 @@ def test_env_ij(env_name, port, n_agents): env = GodotEnv(env_path=env_path, port=port) action_space = env.action_space - observation_space = env.observation_space n_envs = env.num_envs for j in range(2): @@ -58,7 +57,6 @@ def test_env_ji(env_name, port, n_agents): env = GodotEnv(env_path=env_path, port=port) action_space = env.action_space - observation_space = env.observation_space n_envs = env.num_envs assert n_envs == n_agents for j in range(2):