-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
55 lines (43 loc) · 1.95 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
import wandb
import sys
import os
import subprocess
from utils import *
def main(args, model_params):
    """Launch the training script of the selected model in a subprocess.

    The model's entry script is executed with its own directory as the working
    directory, forwarding ``--dataset``, the optional Weights & Biases flags,
    and any extra (unrecognised) command-line tokens unchanged.

    Args:
        args: Parsed namespace providing ``model``, ``dataset``, ``wandb`` and
            ``wandb_entity``.
        model_params: Extra command-line tokens passed through verbatim to the
            model's training script.

    Returns:
        The exit code of the training subprocess.

    Raises:
        ValueError: If ``args.model`` has no registered training script.
    """
    # Training entry points, relative to the directory containing this file.
    train_files = {
        TRANSE: 'models/embeddings/transe/train_transe.py',
        PGPR: 'models/PGPR/train_agent.py',
        CAFE: 'models/CAFE/train_neural_symbol.py',
        UCPR: 'models/UCPR/train.py',
        KGAT: 'models/knowledge_aware/KGAT/main.py',
        CKE: 'models/knowledge_aware/CKE/main.py',
        CFKG: 'models/knowledge_aware/CFKG/main.py',
        BPRMF: 'models/matrix_factorization/BPRMF/main.py',
        NFM: 'models/matrix_factorization/NFM/main.py',
        FM: 'models/matrix_factorization/FM/main.py',
    }
    # Explicit check (not `assert`) so validation survives `python -O`.
    if args.model not in train_files:
        raise ValueError(
            f'Error, given model name {args.model} not found in available models'
        )
    ensure_dataset_name(args.dataset)

    # Resolve against this file's directory so the launcher also works when
    # invoked from a different current working directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    train_file = os.path.join(base_dir, train_files[args.model])

    # Run the child with the same interpreter (and virtualenv) as this launcher.
    cmd = [sys.executable, os.path.basename(train_file), '--dataset', args.dataset]
    if args.wandb:
        cmd.append('--wandb')
        # Avoid passing a None token, which would make subprocess raise TypeError.
        if args.wandb_entity is not None:
            cmd.extend(['--wandb_entity', args.wandb_entity])
    # Forward unrecognised CLI arguments to the model script untouched.
    cmd.extend(model_params)
    return subprocess.call(cmd, cwd=os.path.dirname(train_file))
if __name__ == '__main__':
    # CLI entry point: parse launcher options, pass everything else through.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default=LFM1M, help='One of {ml1m, lfm1m}')
    parser.add_argument(
        '--model',
        type=str,
        default=PGPR,
        help='Model to train, e.g. pgpr, cafe, ucpr, cke, cfkg, kgat '
             '(see train_files in main() for the full list)',
    )
    parser.add_argument("--wandb", action="store_true", help="If passed, will log to Weights and Biases.")
    parser.add_argument(
        "--wandb_entity",
        type=str,
        help="Entity name to push to the wandb logged data, in case args.wandb is specified.",
    )
    # Unknown arguments are forwarded verbatim to the model's training script.
    args, model_params = parser.parse_known_args()
    # Validate after parsing instead of inspecting sys.argv by hand.
    if args.wandb and args.wandb_entity is None:
        parser.error("--wandb_entity is required when --wandb is passed")
    # Propagate the training subprocess's exit status to the shell.
    sys.exit(main(args, model_params))