a2c.py
# A2C
from contextlib import contextmanager
from collections import deque
import numpy as np
import time
import os
import gym
import pprint as pp
import pickle
import sys
from rlkits.sampler import ParallelEnvTrajectorySampler
from rlkits.sampler import estimate_Q
from rlkits.sampler import aggregate_experience
from rlkits.policies import PolicyWithValue
import rlkits.utils as U
import rlkits.utils.logger as logger
from rlkits.utils.math import explained_variance
from rlkits.env_batch import ParallelEnvBatch
from rlkits.env_wrappers import AutoReset, StartWithRandomActions
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_


def compute_loss(pi, trajectory, log_dir):
    """Compute the policy and value losses for one batch of experience."""
    obs = torch.from_numpy(trajectory['obs'])
    dist = pi.dist(pi.policy_net(obs))
    if dist is None:
        # The policy net produced NaNs; checkpoint the policy and dump the
        # offending trajectory for debugging, then abort.
        logger.log('Got NaN -- bad')
        pi.save_ckpt()
        args = {"trajectory": trajectory}
        with open(os.path.join(log_dir, 'local.pkl'), 'wb') as f:
            pickle.dump(args, f)
        sys.exit()

    # Policy gradient loss: -log pi(a|s) * advantage
    actions = torch.from_numpy(trajectory['actions'])
    log_prob = dist.log_prob(actions)
    adv = torch.from_numpy(trajectory['adv'])
    if len(log_prob.shape) > 1:
        log_prob = log_prob.squeeze(dim=1)
    assert log_prob.shape == adv.shape, \
        f"log_prob shape: {log_prob.shape}, adv shape: {adv.shape}"
    pi_loss = -log_prob * adv

    # Value loss: MSE between value predictions and the estimated returns Q
    vpreds = pi.value_net(obs)
    Q = torch.from_numpy(trajectory['Q'])
    if len(vpreds.shape) > 1:
        vpreds = vpreds.squeeze(dim=1)
    assert vpreds.shape == Q.shape, \
        f"vpreds shape: {vpreds.shape}, Q shape: {Q.shape}"
    v_loss = F.mse_loss(vpreds, Q)

    return {
        "pi_loss": pi_loss.mean(),
        "v_loss": v_loss.mean(),
        "entropy": dist.entropy().mean(),
    }


def sync_policies(oldpi, pi):
    """Copy the current policy's weights into oldpi (oldpi <- pi)."""
    oldpi.policy_net.load_state_dict(pi.policy_net.state_dict())
    oldpi.value_net.load_state_dict(pi.value_net.state_dict())


def policy_diff(oldpi, pi):
    """Average absolute parameter difference between oldpi and pi."""
    diff = 0.0
    cnt = 0
    for p1, p2 in zip(oldpi.policy_net.parameters(), pi.policy_net.parameters()):
        diff += torch.mean(torch.abs(p1.data - p2.data))
        cnt += 1
    return diff / cnt


def A2C(
    env,
    nsteps,
    gamma,
    total_timesteps,
    pi_lr,
    v_lr,
    ent_coef,
    log_interval,
    max_grad_norm,
    reward_transform,
    ckpt_dir,
    **network_kwargs
):
    """Advantage Actor-Critic (A2C).

    Args:
        env (ParallelEnvBatch): batch of parallel gym environments; must expose
            `observation_space`, `action_space`, `nenvs`, and `close()`
        nsteps (int): length of the trajectory sampled from each parallel env
            per update
        gamma (float): discount factor
        total_timesteps (int): total number of frames to sample from the env
        pi_lr (float): policy net learning rate
        v_lr (float): value net learning rate
        ent_coef (float): entropy bonus coefficient
        log_interval (int): number of updates between consecutive
            logging/checkpointing steps
        max_grad_norm (float or None): max norm of the gradients of the policy
            and value networks. For example, if `max_grad_norm=0.1`, the global
            L2 norm of each network's gradients is clipped to 0.1. If `None`,
            no gradient clipping is applied.
        reward_transform (callable): callback applied to the reward at each
            step of experience sampling (see the example sketch after this
            function)
        ckpt_dir (str): directory for logs and checkpoints
        **network_kwargs: forwarded to `PolicyWithValue`

    Returns:
        None
    """
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    logger.configure(dir=ckpt_dir)

    ob_space = env.observation_space
    ac_space = env.action_space

    pi = PolicyWithValue(ob_space=ob_space, ac_space=ac_space,
                         ckpt_dir=ckpt_dir, **network_kwargs)
    # only used to compute the per-update policy difference (step size)
    oldpi = PolicyWithValue(ob_space=ob_space, ac_space=ac_space,
                            ckpt_dir=ckpt_dir, **network_kwargs)

    poptimizer = optim.Adam(pi.policy_net.parameters(), lr=pi_lr)
    voptimizer = optim.Adam(pi.value_net.parameters(), lr=v_lr)

    sampler = ParallelEnvTrajectorySampler(env, pi, nsteps,
                                           reward_transform=reward_transform,
                                           gamma=gamma)

    # moving averages of the last 10 episode returns and lengths
    rolling_buf_episode_rets = deque(maxlen=10)
    rolling_buf_episode_lens = deque(maxlen=10)

    nframes = env.nenvs * nsteps  # number of frames processed per update
    nupdates = total_timesteps // nframes

    start = time.perf_counter()
    best_ret = float('-inf')
    for update in range(1, nupdates + 1):
        sync_policies(oldpi, pi)
        tstart = time.perf_counter()

        # sample a batch of trajectories and estimate the returns Q
        trajectory = sampler(callback=estimate_Q)

        # aggregate experience from the parallel envs into flat arrays
        for k, v in trajectory.items():
            if isinstance(v, np.ndarray):
                trajectory[k] = aggregate_experience(v)

        # advantage estimate A(s, a) = Q(s, a) - V(s), normalized per batch
        adv = trajectory['Q'] - trajectory['vpreds']
        trajectory['adv'] = (adv - adv.mean()) / adv.std()

        losses = compute_loss(pi=pi, trajectory=trajectory, log_dir=ckpt_dir)

        # anneal the entropy bonus linearly over the course of training
        frac = 1.0 - (update - 1.0) / nupdates
        loss = losses['pi_loss'] + losses['v_loss'] \
            - ent_coef * frac * losses['entropy']

        poptimizer.zero_grad()
        voptimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            clip_grad_norm_(pi.policy_net.parameters(), max_norm=max_grad_norm)
            clip_grad_norm_(pi.value_net.parameters(), max_norm=max_grad_norm)
        poptimizer.step()
        voptimizer.step()

        tnow = time.perf_counter()
        # logging
        if update % log_interval == 0 or update == 1:
            fps = int(nframes / (tnow - tstart))  # frames per second

            logger.record_tabular('iteration/nupdates', f"{update}/{nupdates}")
            logger.record_tabular('frac', frac)
            logger.record_tabular('policy_loss',
                                  losses['pi_loss'].detach().numpy())
            logger.record_tabular('value_loss',
                                  losses['v_loss'].detach().numpy())
            logger.record_tabular('entropy',
                                  losses['entropy'].detach().numpy())

            for ep_rets in trajectory['ep_rets']:
                rolling_buf_episode_rets.extend(ep_rets)
            for ep_lens in trajectory['ep_lens']:
                rolling_buf_episode_lens.extend(ep_lens)

            # average parameter change since the last update
            step_size = policy_diff(oldpi, pi)
            logger.record_tabular('step_size', step_size.numpy())

            # how well the value net explains the variance of the returns
            ev = explained_variance(trajectory['vpreds'], trajectory['Q'])
            logger.record_tabular('explained_variance', ev)

            piw, vw = pi.average_weight()
            logger.record_tabular('policy_net_weight', piw.numpy())
            logger.record_tabular('value_net_weight', vw.numpy())

            vqdiff = np.mean((trajectory['Q'] - trajectory['vpreds'])**2)
            logger.record_tabular('vqdiff', vqdiff)
            logger.record_tabular('Q', np.mean(trajectory['Q']))
            logger.record_tabular('vpreds', np.mean(trajectory['vpreds']))
            logger.record_tabular('FPS', fps)

            ret = safemean(rolling_buf_episode_rets)
            logger.record_tabular('ma_ep_ret', ret)
            logger.record_tabular('ma_ep_len',
                                  safemean(rolling_buf_episode_lens))
            logger.record_tabular('mean_rew_step',
                                  np.mean(trajectory['rews']))

            # checkpoint the best policy so far by moving-average return
            if not np.isnan(ret) and ret > best_ret:
                best_ret = ret
                pi.save_ckpt('best')

            logger.dump_tabular()
    pi.save_ckpt('last')
    # persist the optimizer states alongside the checkpoints
    torch.save(poptimizer, os.path.join(ckpt_dir, 'poptim.pth'))
    torch.save(voptimizer, os.path.join(ckpt_dir, 'voptim.pth'))

    end = time.perf_counter()
    env.close()
    logger.log(f"Total time elapsed: {end - start}")
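

# Example (sketch) of a reward_transform callback as consumed by A2C above.
# This helper is illustrative only -- the name and the scaling constant are
# assumptions, not part of rlkits. Per the A2C docstring, the sampler applies
# the callback to the reward at each step of experience sampling.
def scale_pendulum_reward(rew):
    """Rescale Pendulum-style rewards (roughly in [-16, 0]) to a smaller range."""
    return rew / 16.0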


def safemean(l):
    """Mean of a sequence, or NaN if it is empty."""
    return np.nan if len(l) == 0 else np.mean(l)
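

# --- Usage sketch -----------------------------------------------------------
# A minimal, hedged example of how A2C above might be driven. The
# ParallelEnvBatch constructor call and the "Pendulum-v1" env id are
# assumptions, not facts taken from this file; A2C itself only requires that
# `env` expose `observation_space`, `action_space`, `nenvs`, and `close()`.
if __name__ == "__main__":
    def make_env():
        # Hypothetical single-env factory; the AutoReset and
        # StartWithRandomActions wrappers imported above could be applied here.
        return gym.make("Pendulum-v1")

    # Assumption: ParallelEnvBatch accepts a list of env factories.
    env = ParallelEnvBatch([make_env for _ in range(4)])

    A2C(
        env=env,
        nsteps=32,
        gamma=0.99,
        total_timesteps=1_000_000,
        pi_lr=1e-4,
        v_lr=1e-3,
        ent_coef=0.01,
        log_interval=10,
        max_grad_norm=0.5,
        reward_transform=scale_pendulum_reward,
        ckpt_dir="/tmp/a2c-pendulum",
    )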