Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sb3 sac onnx export #198

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.onnx.stable_baselines_export import export_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

# To download the env source and binary:
Expand Down Expand Up @@ -115,7 +115,7 @@ def handle_onnx_export():
if args.onnx_export_path is not None:
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
print("Exporting onnx to: " + os.path.abspath(path_onnx))
export_ppo_model_as_onnx(model, str(path_onnx))
export_model_as_onnx(model, str(path_onnx))


def handle_model_save():
Expand Down
92 changes: 57 additions & 35 deletions godot_rl/wrappers/onnx/stable_baselines_export.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import torch
from gymnasium.vector.utils import spaces
from stable_baselines3 import PPO
from stable_baselines3 import PPO, SAC


class OnnxableMultiInputPolicy(torch.nn.Module):
class OnnxablePolicy(torch.nn.Module):
def __init__(
self,
obs_keys,
features_extractor,
mlp_extractor,
action_net,
value_net,
use_obs_array,
obs_keys=None,
features_extractor=None,
mlp_extractor=None,
action_net=None,
value_net=None,
use_obs_array=None,
actor=None,
):
super().__init__()
self.obs_keys = obs_keys
Expand All @@ -20,10 +21,12 @@ def __init__(
self.action_net = action_net
self.value_net = value_net
self.use_obs_array = use_obs_array
self.actor = actor

def forward(self, obs, state_ins):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
def forward_sac(self, observation: torch.Tensor, state_ins):
    """Forward pass used when an SAC actor was supplied.

    Calls ``self.actor`` on ``observation`` with ``deterministic=True`` and
    returns a 2-tuple of that result and ``state_ins``, which is passed
    through unchanged.
    """
    return self.actor(observation, deterministic=True), state_ins

def forward_ppo(self, obs, state_ins):
features = None

if self.use_obs_array:
Expand All @@ -35,31 +38,47 @@ def forward(self, obs, state_ins):
action_hidden, value_hidden = self.mlp_extractor(features)
return self.action_net(action_hidden), state_ins

def forward(self, obs, state_ins):
    """Dispatch to the SAC or the PPO forward pass.

    Args:
        obs: Observation input, forwarded untouched to the selected pass.
        state_ins: State input, forwarded untouched to the selected pass.

    Returns:
        Whatever ``forward_sac`` returns when an actor was provided,
        otherwise whatever ``forward_ppo`` returns.
    """
    # Compare against None rather than relying on truthiness: module
    # containers that define ``__len__`` evaluate as falsy when empty, which
    # would wrongly route an SAC export through the PPO path.
    if self.actor is not None:
        return self.forward_sac(obs, state_ins)
    return self.forward_ppo(obs, state_ins)


def export_model_as_onnx(model, onnx_model_path: str, use_obs_array: bool = False):
policy = model.policy.to("cpu")
dummy_input = None
onnxable_model = None

if isinstance(model, SAC):
assert use_obs_array, "SAC ONNX export works with use_obs_array=True, MLPPolicy and SBGSingleObsEnv only."

if isinstance(model, PPO):
onnxable_model = OnnxablePolicy(
["obs"],
policy.features_extractor,
policy.mlp_extractor,
policy.action_net,
policy.value_net,
use_obs_array,
)
if use_obs_array:
dummy_input = torch.unsqueeze(torch.tensor(model.observation_space.sample()), 0)
else:
dummy_input = dict(model.observation_space.sample())
for k, v in dummy_input.items():
dummy_input[k] = torch.from_numpy(v).unsqueeze(0)
dummy_input = [v for v in dummy_input.values()]

def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool = False):
ppo_policy = ppo.policy.to("cpu")
onnxable_model = OnnxableMultiInputPolicy(
["obs"],
ppo_policy.features_extractor,
ppo_policy.mlp_extractor,
ppo_policy.action_net,
ppo_policy.value_net,
use_obs_array,
)

if use_obs_array:
dummy_input = torch.unsqueeze(torch.tensor(ppo.observation_space.sample()), 0)
else:
dummy_input = dict(ppo.observation_space.sample())
for k, v in dummy_input.items():
dummy_input[k] = torch.from_numpy(v).unsqueeze(0)
dummy_input = [v for v in dummy_input.values()]
elif isinstance(model, SAC):
onnxable_model = OnnxablePolicy(actor=model.policy.actor)
dummy_input = torch.randn(1, *model.observation_space.shape)

torch.onnx.export(
onnxable_model,
args=(dummy_input, torch.zeros(1).float()),
f=onnx_model_path,
opset_version=9,
opset_version=17,
input_names=["obs", "state_ins"],
output_names=["output", "state_outs"],
dynamic_axes={
Expand All @@ -70,11 +89,14 @@ def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool
},
)

# If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch
# (the output from onnx will be the action logits for each discrete action,
# while the output from sb3 will be a single int)
if not isinstance(ppo.action_space, spaces.MultiDiscrete):
verify_onnx_export(ppo, onnx_model_path, use_obs_array=use_obs_array)
# We only verify with PPO currently due to different output shape with SAC
# (this can be updated in the future)
if isinstance(model, PPO):
# If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch
# (the output from onnx will be the action logits for each discrete action,
# while the output from sb3 will be a single int)
if not isinstance(model.action_space, spaces.MultiDiscrete):
verify_onnx_export(model, onnx_model_path, use_obs_array=use_obs_array)


def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_array: bool = False):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sb3_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def test_pytorch_vs_onnx(env_name, port):
from stable_baselines3 import PPO

from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx, verify_onnx_export
from godot_rl.wrappers.onnx.stable_baselines_export import export_model_as_onnx, verify_onnx_export
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"
Expand All @@ -39,6 +39,6 @@ def test_pytorch_vs_onnx(env_name, port):
tensorboard_log="logs/log",
)

export_ppo_model_as_onnx(ppo, f"{env_name}_tmp.onnx")
export_model_as_onnx(ppo, f"{env_name}_tmp.onnx")
verify_onnx_export(ppo, f"{env_name}_tmp.onnx")
os.remove(f"{env_name}_tmp.onnx")
Loading