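"""Hydra entry point for training and evaluating coreference models.

Training derives a deterministic model directory from a hash of the config
(see get_model_name below); evaluation reuses an existing model directory
and, when wandb is enabled, resumes the run with the matching id.
"""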
import hashlib
import json
import logging
import os
from os import path

import hydra
import wandb
from omegaconf import OmegaConf

from experiment import Experiment

logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logger = logging.getLogger()

# Silence the HuggingFace tokenizers fork-parallelism warning in dataloader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def get_model_name(config):
    """Derive a deterministic model name from the config and seed.

    Hashing only the training-relevant sub-trees means that re-running
    with the same settings maps back to the same model directory.
    """
    masked_copy = OmegaConf.masked_copy(
        config, ["datasets", "model", "trainer", "optimizer", "new_max_ent", "run_num"]
    )
    encoded = json.dumps(OmegaConf.to_container(masked_copy), sort_keys=True).encode()

    hash_obj = hashlib.md5()
    hash_obj.update(encoded)
    hash_obj.update(f"seed: {config.seed}".encode())
    model_hash = str(hash_obj.hexdigest())

    if len(config.datasets) > 1:
        dataset_name = "joint"
    else:
        dataset_name = list(config.datasets.keys())[0]
        if dataset_name == "litbank":
            # LitBank runs are cross-validated; encode the split in the name.
            cross_val_split = config.datasets[dataset_name].cross_val_split
            dataset_name += f"_cv_{cross_val_split}"

    model_name = f"{dataset_name}_{model_hash}"
    return model_name
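
# Example of the resulting naming scheme (hash value is hypothetical):
#   a single LitBank config with cross_val_split 0 -> "litbank_cv_0_<32-char md5 hex>"
#   a multi-dataset config                         -> "joint_<32-char md5 hex>"
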
def main_train(config):
    """Resolve model/checkpoint paths for training and dump the config to disk."""
    if config.paths.model_name is None:
        model_name = get_model_name(config)
    else:
        model_name = config.paths.model_name

    config.paths.model_dir = path.join(
        config.paths.base_model_dir, config.paths.model_name_prefix + model_name
    )
    config.paths.best_model_dir = path.join(config.paths.model_dir, "best")

    for model_dir in [config.paths.model_dir, config.paths.best_model_dir]:
        if not path.exists(model_dir):
            os.makedirs(model_dir)

    if config.paths.model_path is None:
        config.paths.model_path = path.abspath(
            path.join(config.paths.model_dir, config.paths.model_filename)
        )
        config.paths.best_model_path = path.abspath(
            path.join(config.paths.best_model_dir, config.paths.model_filename)
        )

    if config.paths.best_model_path is None and (config.paths.model_path is not None):
        # Can happen when a model path was supplied explicitly in the config.
        config.paths.best_model_path = config.paths.model_path

    # Dump config file
    config_file = path.join(config.paths.model_dir, "config.json")
    with open(config_file, "w") as f:
        f.write(json.dumps(OmegaConf.to_container(config), indent=4, sort_keys=True))

    return model_name
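
# Resulting on-disk layout (sketch; actual filenames come from the config):
#   <base_model_dir>/<model_name_prefix><model_name>/
#       config.json
#       <model_filename>
#       best/<model_filename>
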
def main_eval(config):
    """Resolve the best-model paths for evaluating a previously trained model."""
    if config.paths.model_dir is None:
        raise ValueError("config.paths.model_dir must be set for evaluation")

    # Prefer the "best" checkpoint subdirectory when it exists.
    best_model_dir = path.join(config.paths.model_dir, "best")
    if path.exists(best_model_dir):
        config.paths.best_model_dir = best_model_dir
    else:
        config.paths.best_model_dir = config.paths.model_dir

    config.paths.best_model_path = path.abspath(
        path.join(config.paths.best_model_dir, config.paths.model_filename)
    )

@hydra.main(config_path="conf", config_name="config")
def main(config):
    print(config)
    if config.train:
        model_name = main_train(config)
    else:
        main_eval(config)
        model_name = path.basename(path.normpath(config.paths.model_dir))
        # Strip the prefix so the wandb run id matches the one used in training.
        if model_name.startswith(config.paths.model_name_prefix):
            model_name = model_name[len(config.paths.model_name_prefix):]

    if config.use_wandb:
        # Wandb initialization; resume the existing run with the same id if any.
        try:
            wandb.init(
                id=model_name,
                project="Coreference",
                config=dict(config),
                resume=True,
            )
        except Exception:
            # Turn off wandb if initialization fails (e.g. no network/credentials).
            config.use_wandb = False

    logger.info(f"Model name: {model_name}")
    Experiment(config)

if __name__ == "__main__":
    import sys

    # Pin Hydra's working directory to the script's own directory and disable
    # Hydra's job logging (logging is configured at the top of this module).
    sys.argv.append(f"hydra.run.dir={path.dirname(path.realpath(__file__))}")
    sys.argv.append("hydra/job_logging=none")
    main()
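
# Example invocations (sketch; the real override keys live in conf/config.yaml):
#   python main.py train=True seed=42
#   python main.py train=False paths.model_dir=<path/to/trained/model>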