plot_rewards.py
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pandas as pd
import glob
import numpy as np
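# Plots the average episodic reward over multiple training runs for each
# auxiliary-task variant (No AUX, IR, RP, VVF-1, VVF-5, SF), smoothed with a
# 50-episode sliding window, and optionally shades the observed min/max range.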
sns.set_style("whitegrid", {'axes.grid': True,
                            'axes.edgecolor': 'black'})
fig = plt.figure(figsize=(10, 7))
MAX_RANGE=4000 # total number of episodes to include in the plot
SAVE_NAME='figures/train_relu.png' # where to save the figure
PLOT_REGION=False # whether to shade the observed min/max range
TITLE='Average reward over 5 runs using FTA' # title of the plot
COLORS=1 # which color palette to use (1 or 2, defined below; 2 is for comparison plots)
PATH1='.rewards/train/relu' # directory with the saved reward pickles
PATH2='' # second rewards directory for comparison plots (leave empty to disable)
if PATH2 != '':
    aux_paths = [PATH1+'/', PATH2+'/']
else:
    aux_paths = [PATH1+'/']
labels = ['', 'Scratch']
plt.clf()
if COLORS == 1:
    colors = ['blue', 'red', 'green', 'yellow', 'magenta', 'cyan', 'skyblue', 'tomato', 'lime', 'khaki', 'orchid', 'teal']
else:
    colors = ['blue', 'slateblue', 'darkviolet', 'mediumblue', 'dodgerblue', 'skyblue', 'red', 'tomato', 'orange', 'coral', 'brown', 'chocolate']
# plot
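# For each curve: every file under <PATH>/<variant>/ holds one run as a pickled
# list of per-episode rewards. Runs are padded (with a constant reward of 1) to
# the longest run's length, averaged element-wise, and smoothed with a
# 50-episode sliding window via torch's unfold; 49 zeros are prepended so the
# smoothed curve stays aligned with the episode axis.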
for j, aux_path in enumerate(aux_paths):
    for index, path in enumerate(['no_aux', 'ir', 'rp', 'vvf1', 'vvf5', 'sf']):
        # human-readable label for each auxiliary-task variant
        if path == 'no_aux':
            name = 'No AUX'
        elif path == 'ir':
            name = 'IR'
        elif path == 'rp':
            name = 'RP'
        elif path == 'vvf1':
            name = 'VVF-1'
        elif path == 'vvf5':
            name = 'VVF-5'
        elif path == 'sf':
            name = 'SF'
        # load every run (one pickled reward list per file) for this variant
        files = glob.glob(aux_path + path + '/*')
        rewards = []
        max_len = 0
        for file in files:
            with open(file, 'rb') as fp:
                t = pickle.load(fp)
            if max_len < len(t):
                max_len = len(t)
            rewards.append(t)
        # pad shorter runs with a constant reward of 1 up to the longest run,
        # then average element-wise across runs
        rewards_t = torch.zeros(max_len, dtype=torch.float)
        for i in range(len(rewards)):
            tt = torch.ones(max_len, dtype=torch.float)
            r = torch.tensor(rewards[i], dtype=torch.float)
            tt[:r.shape[0]] = r
            rewards_t += tt
        rewards_t = rewards_t / len(rewards)
        means = rewards_t[0:MAX_RANGE].unfold(0, 50, 1).mean(1).view(-1)
        mins = rewards_t[0:MAX_RANGE].unfold(0, 50, 1).min(1)[0]
        maxs = rewards_t[0:MAX_RANGE].unfold(0, 50, 1).max(1)[0]
        means = torch.cat((torch.zeros(49), means))
        mins = torch.cat((torch.zeros(49), mins))
        maxs = torch.cat((torch.zeros(49), maxs))
        # min/max curves are drawn invisibly (alpha=0) only so their data can
        # be read back from the axes for the shaded region below
        sns.lineplot(means.numpy(), label=name+' '+labels[j], color=colors[j*6+index])
        sns.lineplot(mins.numpy(), alpha=0.0)
        c = sns.lineplot(maxs.numpy(), alpha=0.0)
        line = c.get_lines()
        if PLOT_REGION:
            # each variant adds 3 lines (mean, min, max) to the axes, so this
            # variant's lines start at index j*18 + index*3
            plt.fill_between(line[j*18+index*3+0].get_xdata(),
                             line[j*18+index*3+1].get_ydata(),
                             line[j*18+index*3+2].get_ydata(),
                             color=colors[j*6+index], alpha=.15,
                             label=name+' observed range')
plt.xlabel('Episode', fontsize=14)
plt.ylabel('Average Reward', fontsize=14)
plt.title(TITLE, fontsize=16)
plt.legend(frameon=True, fancybox=True, loc="best", prop={'size': 12})
sns.despine()
plt.tight_layout()
plt.savefig(SAVE_NAME)
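# ---------------------------------------------------------------------------
# A minimal sketch of how the reward files read above could be produced
# (assumption: each file under <PATH>/<variant>/ is a pickled list of
# per-episode rewards; the 'run_0.pkl' name below is only illustrative):
#
#   import pickle
#   episode_rewards = [0.1, 0.3, 0.7]  # one entry per episode
#   with open('.rewards/train/relu/no_aux/run_0.pkl', 'wb') as fp:
#       pickle.dump(episode_rewards, fp)
# ---------------------------------------------------------------------------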