diff --git a/examples/stable_baselines3_example.py b/examples/stable_baselines3_example.py index b19b0a9d..630f2044 100644 --- a/examples/stable_baselines3_example.py +++ b/examples/stable_baselines3_example.py @@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env.vec_monitor import VecMonitor from godot_rl.core.utils import can_import -from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx +from godot_rl.wrappers.onnx.stable_baselines_export import export_model_as_onnx from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv # To download the env source and binary: @@ -115,7 +115,7 @@ def handle_onnx_export(): if args.onnx_export_path is not None: path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx") print("Exporting onnx to: " + os.path.abspath(path_onnx)) - export_ppo_model_as_onnx(model, str(path_onnx)) + export_model_as_onnx(model, str(path_onnx)) def handle_model_save(): diff --git a/godot_rl/wrappers/onnx/stable_baselines_export.py b/godot_rl/wrappers/onnx/stable_baselines_export.py index c7225d04..4573b77e 100644 --- a/godot_rl/wrappers/onnx/stable_baselines_export.py +++ b/godot_rl/wrappers/onnx/stable_baselines_export.py @@ -1,17 +1,18 @@ import torch from gymnasium.vector.utils import spaces -from stable_baselines3 import PPO +from stable_baselines3 import PPO, SAC -class OnnxableMultiInputPolicy(torch.nn.Module): +class OnnxablePolicy(torch.nn.Module): def __init__( self, - obs_keys, - features_extractor, - mlp_extractor, - action_net, - value_net, - use_obs_array, + obs_keys=None, + features_extractor=None, + mlp_extractor=None, + action_net=None, + value_net=None, + use_obs_array=None, + actor=None, ): super().__init__() self.obs_keys = obs_keys @@ -20,10 +21,12 @@ def __init__( self.action_net = action_net self.value_net = value_net self.use_obs_array = use_obs_array + self.actor = actor - def forward(self, obs, state_ins): - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` + def forward_sac(self, observation: torch.Tensor, state_ins): + return self.actor(observation, deterministic=True), state_ins + + def forward_ppo(self, obs, state_ins): features = None if self.use_obs_array: @@ -35,31 +38,47 @@ def forward(self, obs, state_ins): action_hidden, value_hidden = self.mlp_extractor(features) return self.action_net(action_hidden), state_ins + def forward(self, obs, state_ins): + if self.actor: + return self.forward_sac(obs, state_ins) + else: + return self.forward_ppo(obs, state_ins) + + +def export_model_as_onnx(model, onnx_model_path: str, use_obs_array: bool = False): + policy = model.policy.to("cpu") + dummy_input = None + onnxable_model = None + + if isinstance(model, SAC): + assert use_obs_array, "SAC ONNX export works with use_obs_array=True, MLPPolicy and SBGSingleObsEnv only." + + if isinstance(model, PPO): + onnxable_model = OnnxablePolicy( + ["obs"], + policy.features_extractor, + policy.mlp_extractor, + policy.action_net, + policy.value_net, + use_obs_array, + ) + if use_obs_array: + dummy_input = torch.unsqueeze(torch.tensor(model.observation_space.sample()), 0) + else: + dummy_input = dict(model.observation_space.sample()) + for k, v in dummy_input.items(): + dummy_input[k] = torch.from_numpy(v).unsqueeze(0) + dummy_input = [v for v in dummy_input.values()] -def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool = False): - ppo_policy = ppo.policy.to("cpu") - onnxable_model = OnnxableMultiInputPolicy( - ["obs"], - ppo_policy.features_extractor, - ppo_policy.mlp_extractor, - ppo_policy.action_net, - ppo_policy.value_net, - use_obs_array, - ) - - if use_obs_array: - dummy_input = torch.unsqueeze(torch.tensor(ppo.observation_space.sample()), 0) - else: - dummy_input = dict(ppo.observation_space.sample()) - for k, v in dummy_input.items(): - dummy_input[k] = torch.from_numpy(v).unsqueeze(0) - dummy_input = [v for v in dummy_input.values()] + elif isinstance(model, SAC): + onnxable_model = OnnxablePolicy(actor=model.policy.actor) + dummy_input = torch.randn(1, *model.observation_space.shape) torch.onnx.export( onnxable_model, args=(dummy_input, torch.zeros(1).float()), f=onnx_model_path, - opset_version=9, + opset_version=17, input_names=["obs", "state_ins"], output_names=["output", "state_outs"], dynamic_axes={ @@ -70,11 +89,14 @@ def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool }, ) - # If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch - # (the output from onnx will be the action logits for each discrete action, - # while the output from sb3 will be a single int) - if not isinstance(ppo.action_space, spaces.MultiDiscrete): - verify_onnx_export(ppo, onnx_model_path, use_obs_array=use_obs_array) + # We only verify with PPO currently due to different output shape with SAC + # (this can be updated in the future) + if isinstance(model, PPO): + # If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch + # (the output from onnx will be the action logits for each discrete action, + # while the output from sb3 will be a single int) + if not isinstance(model.action_space, spaces.MultiDiscrete): + verify_onnx_export(model, onnx_model_path, use_obs_array=use_obs_array) def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_array: bool = False): diff --git a/tests/test_sb3_onnx_export.py b/tests/test_sb3_onnx_export.py index baeea9b9..c05845fa 100644 --- a/tests/test_sb3_onnx_export.py +++ b/tests/test_sb3_onnx_export.py @@ -24,7 +24,7 @@ def test_pytorch_vs_onnx(env_name, port): from stable_baselines3 import PPO - from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx, verify_onnx_export + from godot_rl.wrappers.onnx.stable_baselines_export import export_model_as_onnx, verify_onnx_export from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64" @@ -39,6 +39,6 @@ def test_pytorch_vs_onnx(env_name, port): tensorboard_log="logs/log", ) - export_ppo_model_as_onnx(ppo, f"{env_name}_tmp.onnx") + export_model_as_onnx(ppo, f"{env_name}_tmp.onnx") verify_onnx_export(ppo, f"{env_name}_tmp.onnx") os.remove(f"{env_name}_tmp.onnx")