-
Notifications
You must be signed in to change notification settings - Fork 161
/
load_log.py
156 lines (136 loc) · 5.39 KB
/
load_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
import ast
import os
import pickle

import numpy as np

import exptag
# NOTE(review): not referenced anywhere in this file — presumably consumed by
# importers of this module; confirm before removing.
separator = '------------'
def parse(key, value):
    """Parse one CSV cell into a Python value.

    Conventions used by the log format:
      - "[a;b;c]"         -> list of ints if possible, else list of strings
      - "[]"              -> empty list
      - bare ";"-joined   -> unparseable, coerced to 0.0
      - "nan", "", "-inf" -> np.nan
      - anything else     -> Python literal (int, float, bool, ...)

    ``key`` is only used in the failure message. Unparseable cells fall back
    to 0.0 rather than aborting the whole CSV load.
    """
    value = value.strip()
    try:
        if value.startswith('['):
            # Bracketed, semicolon-separated list.
            items = value[1:-1].split(';')
            if items == ['']:
                return []  # "[]" splits to [''] -> genuinely empty list
            try:
                return [int(v) for v in items]
            except ValueError:
                return [str(v) for v in items]
        if ';' in value:
            # Semicolons outside brackets: malformed cell, coerce to zero.
            return 0.
        if value in ('nan', '', '-inf'):
            return np.nan
        # literal_eval instead of eval: same result for numeric/bool literals,
        # but cannot execute arbitrary expressions from the log file.
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        print(f"failed to parse value {key}:{value.__repr__()}")
        return 0.
def get_hash(filename):
    """Return the MD5 hex digest of the file's contents, read in 4 KiB chunks."""
    import hashlib
    digest = hashlib.md5()
    with open(filename, "rb") as fh:
        while True:
            block = fh.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def pickle_cache_result(f):
    """Decorator: cache ``f(filename)``'s result on disk next to the input file.

    The cache file is ``filename + '_cache'`` and stores ``(md5_of_input,
    result)``; the cached result is reused only while the input file's hash is
    unchanged. A corrupt or unreadable cache is silently ignored and rebuilt.
    """
    import functools

    @functools.wraps(f)
    def cached_f(filename):
        current_hash = get_hash(filename)
        cache_filename = filename + '_cache'
        if os.path.exists(cache_filename):
            with open(cache_filename, 'rb') as fl:
                try:
                    stored_hash, stored_result = pickle.load(fl)
                    if stored_hash == current_hash:
                        return stored_result
                except Exception:
                    # Best-effort cache: corrupt/truncated/incompatible pickles
                    # fall through to recomputation below.
                    pass
        result = f(filename)
        with open(cache_filename, 'wb') as fl:
            pickle.dump((current_hash, result), fl)
        return result
    return cached_f
@pickle_cache_result
def parse_csv(filename):
    """Parse a progress CSV into ``{column_name: [parsed values, ...]}``.

    The first row supplies the column names; every later row contributes one
    parsed value (see ``parse``) per column. Results are disk-cached by the
    ``pickle_cache_result`` decorator.
    """
    import csv
    timeseries = {}
    keys = []
    with open(filename, 'r') as f:
        reader = csv.reader(f)
        for i, row in enumerate(reader):
            if i == 0:
                keys = row
            else:
                # zip (not indexing) tolerates a truncated, partially
                # written final row without raising IndexError.
                for key, cell in zip(keys, row):
                    if key not in timeseries:
                        timeseries[key] = []
                    timeseries[key].append(parse(key, cell))
    print(f'parsing {filename}')
    if 'opt_featvar' in timeseries:
        # Older logs used 'opt_featvar'; expose it under the newer name too.
        timeseries['opt_feat_var'] = timeseries['opt_featvar']
    return timeseries
def get_filename_from_tag(tag):
    """Resolve an experiment tag to the path of its progress.csv."""
    experiment_dir = exptag.get_last_experiment_folder_by_tag(tag)
    return os.path.join(experiment_dir, "progress.csv")
def get_filenames_from_tags(tags):
    """Resolve each tag to its progress.csv path, preserving order."""
    return list(map(get_filename_from_tag, tags))
def get_timeseries_from_filenames(filenames):
    """Parse every CSV file into a timeseries dict, preserving order."""
    return list(map(parse_csv, filenames))
def get_timeseries_from_tags(tags):
    """Load the parsed timeseries for each experiment tag, in order."""
    filenames = get_filenames_from_tags(tags)
    return get_timeseries_from_filenames(filenames)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Each --tags argument is one comma-separated group of experiment tags;
    # groups are drawn in distinct colors below.
    parser.add_argument('--tags', type=lambda x: x.split(','), nargs='+', default=None)
    parser.add_argument('--x_axis', type=str, choices=['tcount', 'n_updates'], default='tcount')
    args = parser.parse_args()
    x_axis = args.x_axis

    # Load every group of experiment logs.
    timeseries_groups = []
    for tag_group in args.tags:
        timeseries_groups.append(get_timeseries_from_tags(tags=tag_group))

    # Report final room coverage / best return per experiment, sorted by
    # number of rooms reached.
    for tag_group, group in zip(args.tags, timeseries_groups):
        rooms = []
        for tag, t in zip(tag_group, group):
            if 'rooms' in t:
                rooms.append((tag, t['rooms'][-1], t['best_ret'][-1], t[x_axis][-1]))
            else:
                print(f"tag {tag} has no rooms")
        rooms = sorted(rooms, key=lambda x: len(x[1]))
        all_rooms = set.union(*[set(r[1]) for r in rooms])
        import pprint
        for tag, r, best_ret, max_x in rooms:
            print(f'{tag}:{best_ret}:{r}(@{max_x})')
        pprint.pprint(all_rooms)

    # Plot only the keys present in every loaded timeseries.
    keys = set.intersection(*[set(t.keys()) for t in sum(timeseries_groups, [])])
    keys = sorted(list(keys))
    import matplotlib.pyplot as plt
    n_rows = int(np.ceil(np.sqrt(len(keys))))
    n_cols = len(keys) // n_rows + 1
    fig, axes = plt.subplots(n_rows, n_cols, sharex=True)
    for i in range(n_rows):
        for j in range(n_cols):
            ind = i * n_cols + j
            if ind >= len(keys):
                continue
            key = keys[ind]
            for color, group in zip('rgbykcm', timeseries_groups):
                # List-valued series (room/place sets) are plotted by length.
                if key in ['all_places', 'recent_places', 'global_all_places', 'global_recent_places', 'rooms']:
                    if any(isinstance(t[key][-1], list) for t in group):
                        for t in group:
                            t[key] = list(map(len, t[key]))
                # Truncate to the shortest run so the mean is well-defined.
                # (A conversion failure here used to drop into ipdb; now it
                # surfaces as a normal traceback.)
                max_timesteps = min(len(t[x_axis]) for t in group)
                data = np.asarray([t[key][:max_timesteps] for t in group], dtype=np.float32)
                # Faint full-length individual runs, plus one opaque mean
                # curve over the common prefix.
                lines = [np.nan_to_num(t[key]) for t in group]
                lines_x = [np.asarray(t[x_axis]) for t in group]
                alphas = [0.2 / np.sqrt(len(lines))] * len(lines)
                lines += [np.nan_to_num(np.nanmean(data, 0))]
                alphas += [1.]
                lines_x += [np.asarray(group[0][x_axis][:max_timesteps])]
                for alpha, y, x in zip(alphas, lines, lines_x):
                    axes[i, j].plot(x, y, color=color, alpha=alpha)
                axes[i, j].set_title(key)
    plt.show()
    plt.close()