enjoy.py
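"""Loads a trained recurrent actor-critic model from disk, runs a single rendered
episode in the configured environment, and prints the episode's length and reward."""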
import numpy as np
import pickle
import torch

from docopt import docopt
from model import ActorCriticModel
from utils import create_env


def main():
    # Command line arguments via docopt
    _USAGE = """
    Usage:
        enjoy.py [options]
        enjoy.py --help

    Options:
        --model=<path>      Specifies the path to the trained model [default: ./models/minigrid.nn].
    """
    options = docopt(_USAGE)
    model_path = options["--model"]

    # Inference device
    device = torch.device("cpu")
    torch.set_default_tensor_type("torch.FloatTensor")

    # Load model and config
    state_dict, config = pickle.load(open(model_path, "rb"))

    # Instantiate environment
    env = create_env(config["environment"], render=True)

    # Initialize model and load its parameters
    model = ActorCriticModel(config, env.observation_space, (env.action_space.n,))
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Run and render episode
    done = False
    episode_rewards = []

    # Init recurrent cell
    hxs, cxs = model.init_recurrent_cell_states(1, device)
    if config["recurrence"]["layer_type"] == "gru":
        recurrent_cell = hxs
    elif config["recurrence"]["layer_type"] == "lstm":
        recurrent_cell = (hxs, cxs)

    obs = env.reset()
    while not done:
        # Render environment
        env.render()
        # Forward model
        policy, value, recurrent_cell = model(torch.tensor(np.expand_dims(obs, 0)), recurrent_cell, device, 1)
        # Sample action
        action = []
        for action_branch in policy:
            action.append(action_branch.sample().item())
        # Step environment
        obs, reward, done, info = env.step(action)
        episode_rewards.append(reward)

    # After done, render last state
    env.render()

    print("Episode length: " + str(info["length"]))
    print("Episode reward: " + str(info["reward"]))

    env.close()


if __name__ == "__main__":
    main()