-
Notifications
You must be signed in to change notification settings - Fork 0
/
Train_Conv_Networks.py
43 lines (29 loc) · 2.3 KB
/
Train_Conv_Networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import torch
from pathlib import Path
from functions import train_conv_network
def parse_bool_flag(value):
    """Convert a command-line token such as 'True'/'false'/'1'/'no' to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--debug False`` would enable debug mode); this
    converter parses the text instead.

    Args:
        value: The raw token from the command line (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
# Accepts "--debug", "--debug True", "--debug False", or no flag at all
# (defaults to False). nargs="?" + const=True lets the bare flag mean True.
parser.add_argument(
    "--debug",
    type=parse_bool_flag,
    nargs="?",
    const=True,
    default=False,
    help="run without exception handling (errors raise with a full traceback)",
)
# Parse the real command line only when executed as a script, so importing
# this module does not consume the importer's sys.argv.
args = parser.parse_args(None if __name__ == "__main__" else [])
debug = args.debug
# Convolution-network configurations to train. Each key is built as
# <modelConv_type>_<en_elayers_dim>_<temb_dim>_<conv_dof>, and each value holds
# the four hyper-parameters handed to train_conv_network.
#
# Other configurations (same format), kept for reference:
#   "Fast_16_8_4":       modelConv_type="Fast",       en_elayers_dim=16, temb_dim=8, conv_dof=4
#   "Chebyshev_16_8_8":  modelConv_type="Chebyshev",  en_elayers_dim=16, temb_dim=8, conv_dof=8
#   "ReLU_16_8_4":       modelConv_type="ReLU",       en_elayers_dim=16, temb_dim=8, conv_dof=4
#   "Bottleneck_16_8_2": modelConv_type="Bottleneck", en_elayers_dim=16, temb_dim=8, conv_dof=2
#   "Bernstein_16_8_1":  modelConv_type="Bernstein",  en_elayers_dim=16, temb_dim=8, conv_dof=1
#   "Jacobi_16_8_1":     modelConv_type="Jacobi",     en_elayers_dim=16, temb_dim=8, conv_dof=1
#   "Wav_16_8_4":        modelConv_type="Wav",        en_elayers_dim=16, temb_dim=8, conv_dof=4
conv_models_dict = {
    "Wav_16_8_4": dict(
        modelConv_type="Wav",
        en_elayers_dim=16,
        temb_dim=8,
        conv_dof=4,
    ),
}

# Every configuration above is selected for training (a live dict_keys view).
sel_conv_models = conv_models_dict.keys()
if __name__ == "__main__":
    # Iteration count passed to train_conv_network for every configuration.
    # NOTE(review): debug and normal mode both used n_iter=10 originally;
    # confirm whether normal runs were meant to train for longer.
    n_iter = 10
    for conv_model in sel_conv_models:
        cfg = conv_models_dict[conv_model]
        # Release cached-but-unused CUDA memory left by the previous run
        # before training the next configuration.
        torch.cuda.empty_cache()
        train_args = (
            cfg["modelConv_type"],
            cfg["en_elayers_dim"],
            cfg["temb_dim"],
            cfg["conv_dof"],
        )
        # Log label: <modelConv_type>_<en_elayers_dim>_<temb_dim>_<conv_dof>.
        run_label = "_".join(str(value) for value in train_args)
        if debug:
            # Debug mode: no exception handling, so failures stop the script
            # with a full traceback.
            train_conv_network(*train_args, n_iter=n_iter)
        else:
            # Normal mode: report a failing configuration and continue with
            # the next one. The configs carry no per-model iteration key, so
            # the logs report the n_iter actually used.
            try:
                status = train_conv_network(*train_args, n_iter=n_iter)
            except Exception as err:
                print(f"#----- Falha: Energy: {run_label}, iter{n_iter} -----#")
                print(f"       {type(err).__name__}: {err}")
            else:
                print(f"{status} - Energy: {run_label}, iter{n_iter}")
    print("Done!")