Skip to content

Commit

Permalink
Merge pull request #101 from edbeeching/load_onnx
Browse files Browse the repository at this point in the history
Adds sb3 onnx export
  • Loading branch information
edbeeching authored Apr 28, 2023
2 parents 312fc0d + b7b8266 commit 448f50a
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 2 deletions.
Binary file added BallChase.zip
Binary file not shown.
13 changes: 11 additions & 2 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse

from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from stable_baselines3 import PPO

# To download the env source and binary:
Expand All @@ -11,7 +12,12 @@
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
"--env_path",
# default="envs/example_envs/builds/JumperHard/jumper_hard.x86_64",
default=None,
type=str,
help="The Godot binary to use, do not include for in editor training",
)
parser.add_argument(
"--onnx_export_path",
default=None,
type=str,
help="The Godot binary to use, do not include for in editor training",
Expand All @@ -22,10 +28,13 @@
args, extras = parser.parse_known_args()


env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, speedup=args.speedup, convert_action_space=True)
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, speedup=args.speedup)

model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log="logs/log")
model.learn(1000000)

print("closing env")
env.close()

if args.onnx_export_path is not None:
export_ppo_model_as_onnx(model, args.onnx_export_path)
Empty file.
77 changes: 77 additions & 0 deletions godot_rl/wrappers/onnx/stable_baselines_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
from stable_baselines3 import PPO


class OnnxableMultiInputPolicy(torch.nn.Module):
def __init__(self, obs_keys, features_extractor, mlp_extractor, action_net, value_net):
super().__init__()
self.obs_keys = obs_keys
self.features_extractor = features_extractor
self.mlp_extractor = mlp_extractor
self.action_net = action_net
self.value_net = value_net

def forward(self, obs, state_ins):
obs_dict = {k: v for k, v in zip(self.obs_keys, obs)}
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
features = self.features_extractor(obs_dict)
action_hidden, value_hidden = self.mlp_extractor(features)
return self.action_net(action_hidden), state_ins


def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str):
ppo_policy = ppo.policy.to("cpu")
onnxable_model = OnnxableMultiInputPolicy(
["obs"],
ppo_policy.features_extractor,
ppo_policy.mlp_extractor,
ppo_policy.action_net,
ppo_policy.value_net,
)
dummy_input = dict(ppo.observation_space.sample())
for k, v in dummy_input.items():
dummy_input[k] = torch.from_numpy(v).unsqueeze(0)

dummy_input = [v for v in dummy_input.values()]
torch.onnx.export(
onnxable_model,
args=(dummy_input, torch.zeros(1).float()),
f=onnx_model_path,
opset_version=9,
input_names=["obs", "state_ins"],
output_names=["output", "state_outs"],
dynamic_axes={'obs' : {0 : 'batch_size'},
'state_ins' : {0 : 'batch_size'}, # variable length axes
'output' : {0 : 'batch_size'},
'state_outs' : {0 : 'batch_size'}}

)
verify_onnx_export(ppo, onnx_model_path)


def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10):
import numpy as np
import onnx
import onnxruntime as ort

onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)

sb3_model = ppo.policy.to("cpu")
ort_sess = ort.InferenceSession(onnx_model_path)

for i in range(num_tests):
obs = dict(ppo.observation_space.sample())

obs2 = {}
for k, v in obs.items():
obs2[k] = torch.from_numpy(v).unsqueeze(0)

with torch.no_grad():
action_sb3, _, _ = sb3_model(obs2, deterministic=True)

obs = [v for v in obs.values()]
action_onnx, state_outs = ort_sess.run(None, {"obs": obs, "state_ins": np.array([0.0], dtype=np.float32)})
assert np.allclose(action_sb3, action_onnx, atol=1e-5), "Mismatch in action output"
assert np.allclose(state_outs, np.array([0.0]), atol=1e-5), "Mismatch in state_outs output"
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ install_requires =
gym==0.26.2
stable-baselines3
huggingface_sb3
onnx
onnxruntime

python_requires = >=3.8
zip_safe = no
Expand Down
36 changes: 36 additions & 0 deletions tests/test_sb3_onnx_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os

import pytest
from stable_baselines3 import PPO

from godot_rl.wrappers.onnx.stable_baselines_export import (
export_ppo_model_as_onnx, verify_onnx_export)
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv


@pytest.mark.parametrize(
"env_name,port",
[
("BallChase", 12008),
("FPS", 12009),
("JumperHard", 12010),
("Racer", 12011),
("FlyBy", 12012),
],
)
def test_pytorch_vs_onnx(env_name, port):
env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"
env = StableBaselinesGodotEnv(env_path, port=port)

ppo = PPO(
"MultiInputPolicy",
env,
ent_coef=0.0001,
verbose=2,
n_steps=32,
tensorboard_log="logs/log",
)

export_ppo_model_as_onnx(ppo, f"{env_name}_tmp.onnx")
verify_onnx_export(ppo, f"{env_name}_tmp.onnx")
os.remove(f"{env_name}_tmp.onnx")

0 comments on commit 448f50a

Please sign in to comment.