Skip to content

Commit

Permalink
Reformat clean_rl_example.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
Ivan-267 committed Jan 18, 2024
1 parent e023c0f commit 034ce4f
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions examples/clean_rl_example.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -167,7 +167,8 @@ def get_action_and_value(self, x, action=None):

# env setup

-envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel)
+envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed,
+n_parallel=args.n_parallel)
args.num_envs = envs.num_envs
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
Expand Down Expand Up @@ -333,6 +334,7 @@ def get_action_and_value(self, x, action=None):

agent.eval().to("cpu")


class OnnxPolicy(torch.nn.Module):
def __init__(self, actor_mean):
super().__init__()
Expand All @@ -342,6 +344,7 @@ def forward(self, obs, state_ins):
action_mean = self.actor_mean(obs)
return action_mean, state_ins


onnx_policy = OnnxPolicy(agent.actor_mean)
dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)

Expand All @@ -352,9 +355,9 @@ def forward(self, obs, state_ins):
opset_version=15,
input_names=["obs", "state_ins"],
output_names=["output", "state_outs"],
-dynamic_axes={'obs' : {0 : 'batch_size'},
-'state_ins' : {0 : 'batch_size'}, # variable length axes
-'output' : {0 : 'batch_size'},
-'state_outs' : {0 : 'batch_size'}}
+dynamic_axes={'obs': {0: 'batch_size'},
+'state_ins': {0: 'batch_size'}, # variable length axes
+'output': {0: 'batch_size'},
+'state_outs': {0: 'batch_size'}}

-)
+)

0 comments on commit 034ce4f

Please sign in to comment.