#!/usr/bin/env python3
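"""Hyperparameter comparison of EWC, online EWC, SI and XdG on split MNIST.

For each method, main_cl.run() is executed over a grid of regularisation
strengths, the best value per method is printed on screen, and a summary
plot is written to the plotting directory (args.p_dir).

Example invocation (a sketch only; the flag names are assumed to follow the
options defined in options.py and may need adjusting):
    python rep_compare_MNIST_hyperParams.py --experiment splitMNIST --scenario task --tasks 5
"""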
import os
import numpy as np
from param_stamp import get_param_stamp_from_args
import options
from visual import plt as my_plt
from matplotlib.pyplot import get_cmap
import main_cl
## Parameter-values to compare
lamda_list = [1., 10., 100., 1000., 10000., 100000., 1000000., 10000000., 100000000., 1000000000., 10000000000.,
              100000000000.]
gamma_list = [1.]
c_list = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1., 5., 10., 50., 100., 500., 1000., 5000., 10000.]
xdg_list = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
## Repulsion hyperparameters...
lamda_diff_list = [1., 10., 100., 1000., 10000., 100000., 1000000., 10000000., 100000000., 1000000000., 10000000000.,
                   100000000000.]
# Whether to use KL or JS divergence...
kl_js_list = ['kl', 'js']
# Selection factors...
f_list = [1.25, 1.5, 2, 2.5, 3, 3.5, 4, 5]

## Function for specifying input-options and organizing / checking them
def handle_inputs():
    # Set indicator-dictionary for correctly retrieving / checking input options
    kwargs = {'single_task': False, 'only_MNIST': True, 'generative': False, 'compare_code': 'hyper'}
    # Define input options
    parser = options.define_args(filename="_compare_MNIST_hyperParams",
                                 description='Compare hyperparameters of EWC, online EWC, SI and XdG on different '
                                             '"scenarios" of splitMNIST.')
    parser = options.add_general_options(parser, **kwargs)
    parser = options.add_eval_options(parser, **kwargs)
    parser = options.add_task_options(parser, **kwargs)
    parser = options.add_model_options(parser, **kwargs)
    parser = options.add_train_options(parser, **kwargs)
    parser = options.add_allocation_options(parser, **kwargs)
    parser.add_argument('--no-online', action='store_true', help="don't do online EWC")
    # Parse and process (i.e., set defaults for unselected options) options
    args = parser.parse_args()
    options.set_defaults(args, **kwargs)
    return args

def get_result(args):
    # -get param-stamp
    param_stamp = get_param_stamp_from_args(args)
    # -check whether this run has already been done; if not, run it now
    if os.path.isfile('{}/prec-{}.txt'.format(args.r_dir, param_stamp)):
        print("{}: already run".format(param_stamp))
    else:
        print("{}: ...running...".format(param_stamp))
        main_cl.run(args)
    # -get average precision
    file_name = '{}/prec-{}.txt'.format(args.r_dir, param_stamp)
    with open(file_name) as result_file:
        ave = float(result_file.readline())
    # -return it
    return ave


if __name__ == '__main__':

    ## Load input-arguments & set default values
    args = handle_inputs()

    ## Add default arguments (will be different for different runs)
    args.ewc = False
    args.online = False
    args.si = False
    args.xdg = False

    ## If we shouldn't do online-EWC
    if args.no_online:
        gamma_list = []

    ## If needed, create plotting directory
    if not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    #-------------------------------------------------------------------------------------------------#

    #--------------------------#
    #----- RUN ALL MODELS -----#
    #--------------------------#

    ## Baseline
    BASE = get_result(args)
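
    ## EWC and online EWC
    # NOTE: the collection code further down reads results from `EWC` and `OEWC`, but the script as
    # found never fills those dictionaries. The two loops below are a reconstruction, modelled on the
    # SI block and on how the results are indexed later; the attribute names `args.ewc_lambda` and
    # `args.gamma` are assumed to match the options defined in options.py.
    EWC = {}
    args.ewc = True
    for ewc_lambda in lamda_list:
        args.ewc_lambda = ewc_lambda
        EWC[ewc_lambda] = get_result(args)
    OEWC = {}
    args.online = True
    for gamma in gamma_list:
        OEWC[gamma] = {}
        args.gamma = gamma
        for ewc_lambda in lamda_list:
            args.ewc_lambda = ewc_lambda
            OEWC[gamma][ewc_lambda] = get_result(args)
    args.online = False
    args.ewc = False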

    ## SI
    SI = {}
    args.si = True
    for si_c in c_list:
        args.si_c = si_c
        SI[si_c] = get_result(args)
    args.si = False

    ## XdG
    if args.scenario=="task":
        XDG = {}
        args.xdg = True
        for xdg in xdg_list:
            args.xdg_prop = xdg
            XDG[xdg] = get_result(args)
        args.xdg_prop = 0.

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------------#
    #----- COLLECT DATA & PRINT ON SCREEN-----#
    #-----------------------------------------#

    ext_c_list = [0] + c_list
    ext_lambda_list = [0] + lamda_list
    ext_xdg_list = [0] + xdg_list
    print("\n")

    ###---EWC + online EWC---###
    # -collect data
    ave_prec_ewc = [BASE] + [EWC[ewc_lambda] for ewc_lambda in lamda_list]
    ave_prec_per_lambda = [ave_prec_ewc]
    for gamma in gamma_list:
        ave_prec_temp = [BASE] + [OEWC[gamma][ewc_lambda] for ewc_lambda in lamda_list]
        ave_prec_per_lambda.append(ave_prec_temp)
    # -print on screen
    print("\n\nELASTIC WEIGHT CONSOLIDATION (EWC)")
    print(" param-list (lambda): {}".format(ext_lambda_list))
    print(" {}".format(ave_prec_ewc))
    print("---> lambda = {} -- {}".format(ext_lambda_list[np.argmax(ave_prec_ewc)], np.max(ave_prec_ewc)))
    if len(gamma_list) > 0:
        print("\n\nONLINE EWC")
        print(" param-list (lambda): {}".format(ext_lambda_list))
        curr_max = 0
        for gamma in gamma_list:
            ave_prec_temp = [BASE] + [OEWC[gamma][ewc_lambda] for ewc_lambda in lamda_list]
            print(" (gamma={}): {}".format(gamma, ave_prec_temp))
            if np.max(ave_prec_temp) > curr_max:
                gamma_max = gamma
                lamda_max = ext_lambda_list[np.argmax(ave_prec_temp)]
                curr_max = np.max(ave_prec_temp)
        print("---> gamma = {} - lambda = {} -- {}".format(gamma_max, lamda_max, curr_max))

    ###---SI---###
    # -collect data
    ave_prec_si = [BASE] + [SI[c] for c in c_list]
    # -print on screen
    print("\n\nSYNAPTIC INTELLIGENCE (SI)")
    print(" param list (si_c): {}".format(ext_c_list))
    print(" {}".format(ave_prec_si))
    print("---> si_c = {} -- {}".format(ext_c_list[np.argmax(ave_prec_si)], np.max(ave_prec_si)))

    ###---XdG---###
    if args.scenario=="task":
        # -collect data
        ave_prec_xdg = [BASE] + [XDG[xdg] for xdg in xdg_list]
        # -print on screen
        print("\n\nCONTEXT-DEPENDENT GATING (XdG)")
        print(" param list (gating_prop): {}".format(ext_xdg_list))
        print(" {}".format(ave_prec_xdg))
        print("---> gating_prop = {} -- {}".format(ext_xdg_list[np.argmax(ave_prec_xdg)], np.max(ave_prec_xdg)))
    print('\n')

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # name for plot
    plot_name = "hyperParams-{}{}-{}".format(args.experiment, args.tasks, args.scenario)
    scheme = "incremental {} learning".format(args.scenario)
    title = "{} - {}".format(args.experiment, scheme)
    ylabel = "Test accuracy (after all tasks)"

    # calculate y-axes (to have equal for EWC, SI and XdG)
    full_list = [item for sublist in ave_prec_per_lambda for item in sublist] + ave_prec_si
    if args.scenario=="task":
        full_list += ave_prec_xdg
    miny = np.min(full_list)
    maxy = np.max(full_list)
    marginy = 0.1*(maxy-miny)

    # open pdf
    pp = my_plt.open_pdf("{}/{}.pdf".format(args.p_dir, plot_name))
    figure_list = []

    ###---EWC + online EWC---###
    # - select colors
    colors = ["darkgreen"]
    colors += get_cmap('Greens')(np.linspace(0.7, 0.3, len(gamma_list))).tolist()
    # - make plot (line plot - only average)
    figure = my_plt.plot_lines(ave_prec_per_lambda, x_axes=ext_lambda_list, ylabel=ylabel,
                               line_names=["EWC"] + ["Online EWC - gamma = {}".format(gamma) for gamma in gamma_list],
                               title=title, x_log=True, xlabel="EWC: lambda (log-scale)",
                               ylim=(miny-marginy, maxy+marginy),
                               with_dots=True, colors=colors, h_line=BASE, h_label="None")
    figure_list.append(figure)

    ###---SI---###
    figure = my_plt.plot_lines([ave_prec_si], x_axes=ext_c_list, ylabel=ylabel, line_names=["SI"],
                               colors=["yellowgreen"], title=title, x_log=True, xlabel="SI: c (log-scale)",
                               with_dots=True, ylim=(miny-marginy, maxy+marginy), h_line=BASE, h_label="None")
    figure_list.append(figure)

    ###---XdG---###
    if args.scenario=="task":
        figure = my_plt.plot_lines([ave_prec_xdg], x_axes=ext_xdg_list, ylabel=ylabel,
                                   line_names=["XdG"], colors=["deepskyblue"], ylim=(miny-marginy, maxy+marginy),
                                   title=title, x_log=False, xlabel="XdG: % of nodes gated",
                                   with_dots=True, h_line=BASE, h_label="None")
        figure_list.append(figure)

    # add figures to pdf
    for figure in figure_list:
        pp.savefig(figure)
    # close the pdf
    pp.close()
    # Print name of generated plot on screen
    print("\nGenerated plot: {}/{}.pdf\n".format(args.p_dir, plot_name))