-
Notifications
You must be signed in to change notification settings - Fork 1
/
args.py
executable file
·68 lines (55 loc) · 3.17 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
def get_args():
parser = argparse.ArgumentParser(description='Demographic Aware Probabilistic Medical Knowledge Graph Embeddings of Electronic Medical Records')
# general
parser.add_argument('--seed', default=1234, type=int)
parser.add_argument('--no-cuda', action='store_true')
parser.add_argument('--cuda_device', default=0, type=int)
# data
parser.add_argument('--data_path', default='/data/medical_kg')
# experiments
parser.add_argument('--snapshots', default='experiments/snapshots', type=str)
parser.add_argument('--results_path', default='experiments/results', type=str)
parser.add_argument('--resume', default='experiments/snapshots', type=str)
parser.add_argument('--checkpoint_path', default='experiments/snapshots', type=str)
# model
parser.add_argument('--model', default='DARLING', choices=['TransE',
'TransH',
'TransR',
'TransD',
'PrTransE',
'PrTransH',
'DARLING'], type=str)
# task
parser.add_argument('--task', default='both', choices=['both',
'treatment_recommendation',
'medicine_recommendation'], type=str)
# model parameters
parser.add_argument('--emb_dim', default=100, type=int)
parser.add_argument('--dropout', default=1e-1, type=int)
parser.add_argument('--d_norm', default=2, type=int)
parser.add_argument('--gamma', default=1, type=int)
parser.add_argument('--target', default=-1, type=int)
parser.add_argument('--reduction', default='sum', choices=['none', 'mean', 'sum'], type=str)
# training
parser.add_argument('--lr', default=1e-3, type=float)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--start_epoch', default=0, type=int)
parser.add_argument('--valfreq', default=1, type=int)
parser.add_argument('--clip', default=5, type=int)
parser.add_argument('--batch_size', default=128, type=int)
# other
parser.add_argument('--negative_prob', default=1e-15, type=float)
parser.add_argument('--scaling_prob', default=1e-2, type=float)
args, argv = parser.parse_known_args()
if args.model in ['PrTransE', 'PrTransH']:
parser.add_argument('--demographic_aware', default=False, action='store_true')
parser.add_argument('--prob_embedding', default=True, action='store_true')
elif args.model in ['DARLING']:
parser.add_argument('--demographic_aware', default=True, action='store_true')
parser.add_argument('--prob_embedding', default=True, action='store_true')
else:
parser.add_argument('--demographic_aware', default=False, action='store_true')
parser.add_argument('--prob_embedding', default=False, action='store_true')
parser.parse_args(argv, namespace=args)
return args