# run_experiment.py
import json
import sys

from ml_core import *
# Loss functions to train with; every training stage below is repeated once
# per entry.
loss_types = [
    'mse',
    'top_1',
    'top_10',
    'static_weight',
    'dynamic_weight',
    'static_subgraph',
    'dynamic_subgraph',
]

# Parse the experiment number first so that a bad invocation fails with a
# readable message before any configuration or statistics are loaded.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} EXP_NUM")
try:
    exp_num = int(sys.argv[1])
except ValueError:
    sys.exit(f"EXP_NUM must be an integer, got {sys.argv[1]!r}")

# Load the experiment settings. Each entry read below is a sequence that is
# indexed by the experiment number, so one file describes every experiment.
with open('experiment.json', 'r', encoding='utf-8') as file:
    json_dict = json.load(file)

compression_ratio = json_dict['compression_ratio']
frame_length = json_dict['frame_length']
num_message_passings = json_dict['num_message_passings']
rotation = json_dict['rotation']
maxK = json_dict['maxK']

with_transform_stats = load_transform_stats('with_transform_stats')

modelTools = model_tools(
    frame_length=frame_length[exp_num],
    transformation=with_transform_stats,
    rotation=rotation[exp_num],
    maxK=maxK[exp_num],
)


def _train_and_store(model_type, loss_type):
    """Initialise, fit and store one model for the selected experiment.

    Args:
        model_type: Value forwarded as ``model_type`` to
            ``modelTools.initialize_training``.
        loss_type: One of the entries of ``loss_types``, forwarded as
            ``loss_type``.
    """
    modelTools.initialize_training(
        compression_ratio=compression_ratio[exp_num],
        exp_num=exp_num,
        loss_type=loss_type,
        model_type=model_type,
    )
    modelTools.fit()
    modelTools.store_data()


# Stage 1: dataset prepared with zero message passings; train and store a
# 'dae' model for every loss type.
modelTools.prepare_ds(num_message_passings=0)
for loss_type in loss_types:
    _train_and_store('dae', loss_type)

# Stage 2: dataset prepared with the experiment's configured number of message
# passings. For every loss type the stored 'dae' data is loaded back before
# the 'combined' model is trained, followed by the 'gnn' model.
# NOTE(review): the 'gnn' run is taken to be per loss type because it consumes
# loss_type -- confirm against the original file's indentation.
modelTools.prepare_ds(num_message_passings=num_message_passings[exp_num])
for loss_type in loss_types:
    modelTools.load_data(loss_type, exp_num, 'dae')
    _train_and_store('combined', loss_type)
    _train_and_store('gnn', loss_type)