-
Notifications
You must be signed in to change notification settings - Fork 8
/
path_handler.py
85 lines (72 loc) · 2.57 KB
/
path_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import pathlib
from hydra import utils
def _model_filename(cfg):
    """Build the checkpoint filename that encodes the experiment configuration.

    The name is the dataset name followed by dataset-specific, augmentation,
    model-specific and optimisation keys, terminated by ``.pt``.

    NOTE(review): the resulting path is stored in ``cfg.path`` and presumably
    used elsewhere to reload checkpoints, so the exact format (key order,
    separators, value formatting) is kept stable on purpose -- changing it
    would make previously saved files unreachable.
    """
    parts = [f"{cfg.dataset}"]

    # Dataset-specific keys
    if cfg.dataset in ("AddProblem", "CopyMemory"):
        parts.append(f"_seqlen_{cfg.dataset_params.seq_length}")
        if cfg.dataset == "CopyMemory":
            parts.append(f"_memsize_{cfg.dataset_params.memory_size}")
    elif cfg.dataset == "MNIST":
        parts.append(f"_perm_{cfg.permuted}")
    elif cfg.dataset in ("CharTrajectories", "SpeechCommands"):
        if cfg.dataset == "SpeechCommands":
            parts.append(f"_mfcc_{cfg.mfcc}")
        # sr_train / drop_rate are always encoded for CharTrajectories, and for
        # SpeechCommands only when MFCC features are disabled.
        if cfg.dataset == "CharTrajectories" or not cfg.mfcc:
            parts.append(f"_srtr_{cfg.sr_train}_drop_{cfg.drop_rate}")

    parts.append(f"_augm_{cfg.augment}")

    # Model-specific keys
    parts.append(
        f"_model_{cfg.model}_blcks_{cfg.no_blocks}_nohid_{cfg.no_hidden}"
    )
    parts.append(
        f"_kernnohid_{cfg.kernelnet_no_hidden}"
        f"_kernact_{cfg.kernelnet_activation_function}"
    )
    # Sine kernel nets are identified by omega_0; all others by their norm type.
    if cfg.kernelnet_activation_function == "Sine":
        parts.append(f"_kernomega0_{round(cfg.kernelnet_omega_0, 2)}")
    else:
        parts.append(f"_kernnorm_{cfg.kernelnet_norm_type}")
    # NOTE: a "_kernsize_" key for BFCNN/TCN models was once sketched here but
    # never enabled.

    # Optimization arguments
    parts.append(
        f"_bs_{cfg.batch_size}_optim_{cfg.optimizer}_lr_{cfg.lr}"
        f"_ep_{cfg.epochs}_dpin_{cfg.dropout_in}_dp_{cfg.dropout}"
        f"_wd_{cfg.weight_decay}_seed_{cfg.seed}"
        f"_sched_{cfg.scheduler}_schdec_{cfg.sched_decay_factor}"
    )
    if cfg.scheduler == "plateau":
        parts.append(f"_pat_{cfg.sched_patience}")
    else:
        parts.append(f"_schsteps_{cfg.sched_decay_steps}")

    # Free-form comment, only when non-empty.
    if cfg.comment != "":
        parts.append(f"_comment_{cfg.comment}")

    parts.append(".pt")
    return "".join(parts)


def model_path(cfg, folder="saved", *, root_dir=None) -> pathlib.Path:
    """Compute the checkpoint path for ``cfg``, store it in ``cfg.path`` and return it.

    Args:
        cfg: experiment configuration accessed by attribute (presumably a
            Hydra/OmegaConf config -- verify against callers). Its ``path``
            attribute is overwritten with the resulting path as a string.
        folder: sub-directory of the base directory that holds checkpoints.
            It is created if missing.
        root_dir: base directory. When ``None`` (the default), Hydra's original
            working directory (``hydra.utils.get_original_cwd()``) is used,
            which is only available inside a Hydra-managed run. Pass it
            explicitly to use this function outside Hydra.

    Returns:
        The ``pathlib.Path`` of the checkpoint file. The file itself is not
        created.

    Side effects:
        Creates ``<base>/<folder>`` if needed. Prints a warning when
        ``cfg.train`` is truthy and the checkpoint file already exists.
    """
    base = utils.get_original_cwd() if root_dir is None else root_dir
    root = pathlib.Path(base) / folder

    filename = _model_filename(cfg)

    # Make sure the target directory exists, then warn if training would
    # overwrite an existing checkpoint.
    root.mkdir(parents=True, exist_ok=True)
    path = root / filename
    cfg.path = str(path)
    if cfg.train and path.exists():
        print("WARNING! The model exists in directory and will be overwritten")
    return path