diff --git a/examples/clean_rl_example.py b/examples/clean_rl_example.py
index def9a64a..8b061fb8 100644
--- a/examples/clean_rl_example.py
+++ b/examples/clean_rl_example.py
@@ -167,7 +167,8 @@ def get_action_and_value(self, x, action=None):
 
     # env setup
-    envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel)
+    envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed,
+                                 n_parallel=args.n_parallel)
     args.num_envs = envs.num_envs
     args.batch_size = int(args.num_envs * args.num_steps)
     args.minibatch_size = int(args.batch_size // args.num_minibatches)
@@ -333,6 +334,7 @@ def get_action_and_value(self, x, action=None):
 
     agent.eval().to("cpu")
 
+
     class OnnxPolicy(torch.nn.Module):
         def __init__(self, actor_mean):
             super().__init__()
@@ -342,6 +344,7 @@ def forward(self, obs, state_ins):
             action_mean = self.actor_mean(obs)
             return action_mean, state_ins
 
+
     onnx_policy = OnnxPolicy(agent.actor_mean)
     dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)
 
@@ -352,9 +355,9 @@ def forward(self, obs, state_ins):
         opset_version=15,
         input_names=["obs", "state_ins"],
         output_names=["output", "state_outs"],
-        dynamic_axes={'obs' : {0 : 'batch_size'},
-                      'state_ins' : {0 : 'batch_size'}, # variable length axes
-                      'output' : {0 : 'batch_size'},
-                      'state_outs' : {0 : 'batch_size'}}
+        dynamic_axes={'obs': {0: 'batch_size'},
+                      'state_ins': {0: 'batch_size'},  # variable length axes
+                      'output': {0: 'batch_size'},
+                      'state_outs': {0: 'batch_size'}}
-    )
\ No newline at end of file
+    )
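
For context (not part of the diff above): a minimal sketch of how the exported policy could be smoke-tested with onnxruntime, assuming onnxruntime is installed. The path "model.onnx" is a placeholder (the actual export target is whatever path the script passes to torch.onnx.export), and the shape of state_ins is assumed from the dummy export inputs; the input/output names ("obs", "state_ins", "output", "state_outs") and the dynamic batch_size axis are taken from the export call shown in the diff.

# Hedged sketch, not part of the diff: smoke-test the exported ONNX policy.
# "model.onnx" is a placeholder; the real export path is set in the script.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
obs_dim = sess.get_inputs()[0].shape[-1]  # last dim of "obs"; the batch dim is dynamic

# Run two different batch sizes to confirm the dynamic 'batch_size' axes work.
for batch in (1, 8):
    obs = np.zeros((batch, obs_dim), dtype=np.float32)
    # state_ins is a dummy recurrent-state placeholder in this example (assumed 1-D float32)
    state_ins = np.zeros((batch,), dtype=np.float32)
    action_mean, state_outs = sess.run(["output", "state_outs"],
                                       {"obs": obs, "state_ins": state_ins})
    assert action_mean.shape[0] == batch

If both batch sizes run and the assertions pass, the dynamic_axes mapping in the export call is behaving as intended.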