-
Notifications
You must be signed in to change notification settings - Fork 464
/
a2c.py
187 lines (155 loc) · 5.68 KB
/
a2c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
import numpy as np
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
n_train_processes = 3     # number of parallel env worker processes
learning_rate = 2e-4      # Adam step size
update_interval = 5       # rollout length (steps per env) between updates
gamma = 0.98              # discount factor used by compute_target
max_train_steps = 60_000  # total synchronous env steps before stopping
PRINT_INTERVAL = 100 * update_interval  # evaluate once every 100 updates
class ActorCritic(nn.Module):
    """Actor-critic MLP: one shared ReLU hidden layer feeding a policy head and a value head.

    The defaults (4-dim state, 256 hidden units, 2 actions) reproduce the
    network used for CartPole-v1 in this script; pass other sizes to reuse
    the same architecture elsewhere.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=256):
        super(ActorCritic, self).__init__()
        # Creation order and attribute names are kept stable so parameter
        # initialization and state_dict keys match earlier checkpoints.
        self.fc1 = nn.Linear(state_dim, hidden_dim)     # shared trunk
        self.fc_pi = nn.Linear(hidden_dim, action_dim)  # policy logits head
        self.fc_v = nn.Linear(hidden_dim, 1)            # state-value head

    def _features(self, x):
        """Shared hidden representation used by both heads."""
        return F.relu(self.fc1(x))

    def pi(self, x, softmax_dim=1):
        """Return action probabilities for state(s) ``x``.

        Args:
            x: State tensor; last dimension is ``state_dim``.
            softmax_dim: Dimension to normalize over. Use 1 for a batch of
                states and 0 for a single unbatched state.
        """
        logits = self.fc_pi(self._features(x))
        return F.softmax(logits, dim=softmax_dim)

    def v(self, x):
        """Return state-value estimates; output keeps a trailing size-1 dimension."""
        return self.fc_v(self._features(x))
def worker(worker_id, master_end, worker_end):
    """Serve one CartPole-v1 environment over a pipe until told to close.

    Intended as the target of a child process. Receives ``(cmd, data)``
    tuples on ``worker_end`` and replies on the same connection:

    * ``'step'``       -> ``(ob, reward, done, info)``. When ``done`` is true
      the env is reset immediately and ``ob`` is the first observation of
      the next episode.
    * ``'reset'``      -> the observation returned by ``env.reset()``.
    * ``'reset_task'`` -> the result of ``env.reset_task()``
      (NOTE(review): plain CartPole may not define ``reset_task`` — confirm
      before relying on this command).
    * ``'get_spaces'`` -> ``(observation_space, action_space)``.
    * ``'close'``      -> no reply; the loop exits.

    Args:
        worker_id: Index of this worker; also used as the env seed.
        master_end: Parent's end of the pipe; closed here so only the parent
            uses it.
        worker_end: This worker's end of the pipe.

    Raises:
        NotImplementedError: If an unknown command is received.
    """
    master_end.close()  # Forbid worker to use the master end for messaging
    env = gym.make('CartPole-v1')
    env.seed(worker_id)
    try:
        while True:
            cmd, data = worker_end.recv()
            if cmd == 'step':
                ob, reward, done, info = env.step(data)
                if done:
                    # Auto-reset so the parent can keep stepping without a
                    # separate reset round-trip.
                    ob = env.reset()
                worker_end.send((ob, reward, done, info))
            elif cmd == 'reset':
                ob = env.reset()
                worker_end.send(ob)
            elif cmd == 'reset_task':
                ob = env.reset_task()
                worker_end.send(ob)
            elif cmd == 'close':
                break
            elif cmd == 'get_spaces':
                worker_end.send((env.observation_space, env.action_space))
            else:
                raise NotImplementedError(cmd)
    finally:
        # Release the env and the pipe on a normal 'close' and also when
        # recv()/step() raise (e.g. EOFError once the parent end is gone).
        env.close()
        worker_end.close()
class ParallelEnv:
    """Vectorized env: one ``worker`` process per environment, stepped in lockstep.

    The parent keeps one end of an ``mp.Pipe`` per worker. Commands are sent
    over those pipes, and per-env results are stacked along a leading env
    axis with ``np.stack``.
    """

    def __init__(self, n_train_processes):
        self.nenvs = n_train_processes
        self.waiting = False  # True after step_async() until step_wait() collects
        self.closed = False
        self.workers = list()
        master_ends, worker_ends = zip(*[mp.Pipe() for _ in range(self.nenvs)])
        self.master_ends, self.worker_ends = master_ends, worker_ends
        for worker_id, (master_end, worker_end) in enumerate(zip(master_ends, worker_ends)):
            p = mp.Process(target=worker,
                           args=(worker_id, master_end, worker_end))
            # Daemonic: the parent attempts to terminate it when exiting.
            p.daemon = True
            p.start()
            self.workers.append(p)
        # Forbid master to use the worker end for messaging
        for worker_end in worker_ends:
            worker_end.close()

    def step_async(self, actions):
        """Send one action to each worker without waiting for the results."""
        for master_end, action in zip(self.master_ends, actions):
            master_end.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """Collect the pending step results from every worker.

        Returns:
            ``(obs, rews, dones, infos)``: the first three stacked along a
            leading env axis; ``infos`` is a tuple of per-worker info objects.
        """
        results = [master_end.recv() for master_end in self.master_ends]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset every worker's env and return the stacked initial observations."""
        for master_end in self.master_ends:
            master_end.send(('reset', None))
        return np.stack([master_end.recv() for master_end in self.master_ends])

    def step(self, actions):
        """Synchronous step: ``step_async`` followed by ``step_wait``."""
        self.step_async(actions)
        return self.step_wait()

    def close(self):  # For clean up resources
        """Shut down all workers and release the parent's pipe ends.

        Safe to call more than once; later calls are no-ops.
        """
        if self.closed:
            return
        if self.waiting:
            # Drain outstanding step replies so no stale result is left in
            # the pipes before the shutdown command is sent.
            for master_end in self.master_ends:
                master_end.recv()
            self.waiting = False
        for master_end in self.master_ends:
            master_end.send(('close', None))
        for process in self.workers:
            process.join()
        # Previously leaked: the parent's ends of the pipes were never closed.
        for master_end in self.master_ends:
            master_end.close()
        self.closed = True
def test(step_idx, model):
    """Evaluate ``model`` on a fresh CartPole-v1 env and print the mean episode score.

    Runs ``num_test`` (10) full episodes, sampling each action from the
    policy, and prints the summed reward averaged over episodes.

    Args:
        step_idx: Training step count; used only in the printed message.
        model: Object with ``pi(state, softmax_dim)`` returning action
            probabilities (an ``ActorCritic`` in this script).
    """
    env = gym.make('CartPole-v1')
    score = 0.0
    num_test = 10
    try:
        # Evaluation only: skip building autograd graphs for every step.
        with torch.no_grad():
            for _ in range(num_test):
                s = env.reset()
                done = False
                while not done:
                    prob = model.pi(torch.from_numpy(s).float(), softmax_dim=0)
                    a = Categorical(prob).sample().numpy()
                    s, r, done, _info = env.step(a)
                    score += r
        print(f"Step # :{step_idx}, avg score : {score/num_test:.1f}")
    finally:
        env.close()
def compute_target(v_final, r_lst, mask_lst, discount=None):
    """Compute n-step bootstrapped TD targets for a rollout.

    Walks the rollout backwards with ``G = r + discount * G * mask``, starting
    from ``G = v_final``. A zero mask (episode ended at that step) cuts the
    bootstrap there.

    Args:
        v_final: Value estimates for the states reached after the last
            rollout step; flattened to 1-D.
        r_lst: Per-step reward arrays, one entry per rollout step.
        mask_lst: Per-step ``1 - done`` arrays aligned with ``r_lst``.
        discount: Discount factor. Defaults to the module-level ``gamma``.

    Returns:
        A float32 tensor with one row per rollout step, in forward time
        order. It is an empty 1-D tensor when ``r_lst`` is empty.
    """
    if discount is None:
        discount = gamma
    G = np.asarray(v_final).reshape(-1)
    td_target = list()
    for r, mask in zip(reversed(r_lst), reversed(mask_lst)):
        G = r + discount * G * mask
        td_target.append(G)
    if not td_target:
        return torch.zeros(0)
    td_target.reverse()
    # Stack into a single ndarray first: torch.tensor() on a list of
    # ndarrays converts element by element and is very slow.
    return torch.from_numpy(np.stack(td_target)).float()
if __name__ == '__main__':
    envs = ParallelEnv(n_train_processes)
    model = ActorCritic()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    step_idx = 0
    try:
        s = envs.reset()
        while step_idx < max_train_steps:
            s_lst, a_lst, r_lst, mask_lst = list(), list(), list(), list()

            # Collect `update_interval` synchronous steps from every env.
            # Actions are only sampled here, so no autograd graph is needed;
            # the policy is re-evaluated with gradients for the update below.
            with torch.no_grad():
                for _ in range(update_interval):
                    prob = model.pi(torch.from_numpy(s).float())
                    a = Categorical(prob).sample().numpy()
                    s_prime, r, done, _info = envs.step(a)

                    s_lst.append(s)
                    a_lst.append(a)
                    r_lst.append(r / 100.0)    # scaled-down reward
                    mask_lst.append(1 - done)  # 0 where the episode ended

                    s = s_prime
                    step_idx += 1

                # Bootstrap value of the states reached after the rollout.
                v_final = model.v(torch.from_numpy(s).float()).numpy()

            td_target_vec = compute_target(v_final, r_lst, mask_lst).reshape(-1)

            # Flatten (time, env, ...) -> (time * env, ...): same row order as
            # td_target_vec. s.shape[-1] is the state dimension.
            s_vec = torch.from_numpy(np.stack(s_lst)).float().reshape(-1, s.shape[-1])
            a_vec = torch.from_numpy(np.stack(a_lst)).reshape(-1).unsqueeze(1)

            # Evaluate the critic once and reuse it for the advantage (detached
            # below) and for the value loss.
            v = model.v(s_vec).reshape(-1)
            advantage = td_target_vec - v
            pi = model.pi(s_vec, softmax_dim=1)
            pi_a = pi.gather(1, a_vec).reshape(-1)
            loss = -(torch.log(pi_a) * advantage.detach()).mean() + \
                F.smooth_l1_loss(v, td_target_vec)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step_idx % PRINT_INTERVAL == 0:
                test(step_idx, model)
    finally:
        # Shut the worker processes down even if training raises.
        envs.close()