-
Notifications
You must be signed in to change notification settings - Fork 161
/
recorder.py
74 lines (62 loc) · 2.49 KB
/
recorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import pickle
from collections import defaultdict
import numpy as np
from baselines import logger
from mpi4py import MPI
def is_square(n):
    """Return True if ``n`` is a perfect square (0, 1, 4, 9, ...).

    Args:
        n: A real number, normally a non-negative integer count.

    Returns:
        A truthy value when ``n`` equals the square of some integer,
        otherwise a falsy value. Negative inputs return False instead of
        failing on the NaN that ``np.sqrt`` produces for them.
    """
    if n < 0:
        return False
    root = int(np.sqrt(n))
    # The floating-point square root can fall just below the exact root for
    # very large n, so the next integer up is checked as well.
    return root * root == n or (root + 1) * (root + 1) == n
class Recorder(object):
    """Collect per-environment rollout slices and pickle selected episodes.

    Buffer slices are accumulated per environment until an episode boundary
    is reported; the finished episode is then appended (as one pickle record)
    to a file named after the MPI rank if ``episode_worth_saving`` accepts it.
    """

    # Class-level default so instances created without __init__ still have it.
    # Episodes whose info['places'] contains this value are always saved.
    always_save_place = 15

    def __init__(self, nenvs, score_multiple=1, always_save_place=15):
        """Set up empty per-environment buffers and the output file name.

        Args:
            nenvs: Number of environments; one chunk buffer is kept for each.
            score_multiple: Bucket width used by ``get_score``. ``None``
                disables saving entirely (see ``episode_worth_saving``).
            always_save_place: Place id that forces an episode to be saved
                whenever it appears in ``info['places']``. Defaults to 15,
                the value that used to be hard-coded.
        """
        # episodes[env_id][key] is a list of array chunks of the episode
        # currently in progress for that environment.
        self.episodes = [defaultdict(list) for _ in range(nenvs)]
        self.total_episodes = 0
        self.filename = self.get_filename()
        self.score_multiple = score_multiple
        self.always_save_place = always_save_place
        # Both dicts map a key to the number of *earlier* occurrences seen.
        self.all_scores = {}
        self.all_places = {}

    def record(self, bufs, infos):
        """Split rollout buffers at episode boundaries and record episodes.

        Args:
            bufs: Mapping from name to an array indexed as
                ``array[env_id, step_slice]``.
            infos: One mapping per environment from a step index to that
                episode's info dict. The step indices are used as slice
                bounds (presumably the steps at which episodes ended --
                confirm against the caller).
        """
        for env_id, ep_infos in enumerate(infos):
            chunks = self.episodes[env_id]
            left_step = 0
            for right_step in sorted(ep_infos):
                # Close the running episode with the steps up to the boundary.
                for key in bufs:
                    chunks[key].append(bufs[key][env_id, left_step:right_step].copy())
                self.record_episode(env_id, ep_infos[right_step])
                left_step = right_step
                # Start the next episode with empty chunk lists.
                for key in bufs:
                    chunks[key].clear()
            # Whatever is left belongs to an episode that is still running.
            for key in bufs:
                chunks[key].append(bufs[key][env_id, left_step:].copy())

    def record_episode(self, env_id, info):
        """Count a finished episode and append it to the file if selected.

        Args:
            env_id: Index of the environment the episode came from.
            info: Episode info dict. When the episode is saved, the key
                ``'env_id'`` is written into this dict in place.
        """
        self.total_episodes += 1
        if not self.episode_worth_saving(env_id, info):
            return
        chunks = self.episodes[env_id]
        episode = {key: np.concatenate(chunks[key]) for key in chunks}
        # NOTE: this writes into the caller's dict; kept for compatibility.
        info['env_id'] = env_id
        episode['info'] = info
        # Append mode: the file holds a sequence of independent pickles.
        with open(self.filename, 'ab') as f:
            pickle.dump(episode, f, protocol=-1)

    def get_score(self, info):
        """Return ``info['r']`` bucketed to a multiple of ``score_multiple``.

        The quotient is truncated toward zero by ``int`` before being scaled
        back up.
        """
        return int(info['r'] / self.score_multiple) * self.score_multiple

    @staticmethod
    def _bump(counter, key):
        """Record one more sighting of ``key``; return the earlier count."""
        seen_before = counter.get(key, -1) + 1
        counter[key] = seen_before
        return seen_before

    def episode_worth_saving(self, env_id, info):
        """Decide whether a finished episode should be written to disk.

        The score and place counters are updated on every call (unless
        ``score_multiple`` is ``None``). An episode is kept when the number
        of earlier episodes with the same bucketed score, or with the same
        set of places, is a perfect square (0, 1, 4, 9, ...), or when
        ``always_save_place`` is among ``info['places']``.

        Args:
            env_id: Unused here; kept for interface compatibility.
            info: Episode info dict providing ``'r'`` and ``'places'``.

        Returns:
            True if the episode should be saved, otherwise False.
        """
        if self.score_multiple is None:
            return False
        score_repeats = self._bump(self.all_scores, self.get_score(info))
        # Sorting makes the key independent of the iteration order of places.
        hashable_places = tuple(sorted(info['places']))
        place_repeats = self._bump(self.all_places, hashable_places)
        if self.always_save_place in info['places']:
            return True
        return bool(is_square(score_repeats) or is_square(place_repeats))

    def get_filename(self):
        """Return the output path inside the logger directory for this rank."""
        rank = MPI.COMM_WORLD.Get_rank()
        return os.path.join(logger.get_dir(), 'videos_{}.pk'.format(rank))