-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
103 lines (82 loc) · 2.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#----> pytorch imports
import torch
#----> general imports
import pandas as pd
import numpy as np
import pdb
import os
from timeit import default_timer as timer
from datasets.dataset_survival import SurvivalDatasetFactory
from utils.core_utils import _train_val
from utils.file_utils import _save_pkl
from utils.general_utils import _get_start_end, _prepare_for_experiment
from utils.process_args import _process_args
def main(args):
    """Run the cross-validation study and write per-fold and summary results.

    For every fold returned by ``_get_start_end(args)`` this builds the
    datasets through ``args.dataset_factory.return_splits``, runs
    ``_train_val`` on them, pickles the per-fold ``results`` object to
    ``<args.results_dir>/split_<fold>_results.pkl`` and records the six
    validation metrics that ``_train_val`` returns. After the loop the
    metrics are written, one row per fold, to a CSV in ``args.results_dir``.

    Args:
        args: Experiment namespace. This function reads ``dataset_factory``,
            ``split_dir``, ``results_dir`` and ``k`` from it, and passes it
            on unchanged to ``_get_start_end``, ``return_splits`` and
            ``_train_val``.

    Returns:
        None. All output goes to files under ``args.results_dir``.
    """
    #----> prep for the cv study
    folds = _get_start_end(args)

    #----> per-fold validation metrics, one list entry per processed fold
    all_val_cindex = []
    all_val_cindex_ipcw = []
    all_val_BS = []
    all_val_IBS = []
    all_val_iauc = []
    all_val_loss = []

    for i in folds:
        datasets = args.dataset_factory.return_splits(
            args,
            csv_path='{}/splits_{}.csv'.format(args.split_dir, i),
            fold=i
        )
        print("Created train and val datasets for fold {}".format(i))

        results, (val_cindex, val_cindex_ipcw, val_BS, val_IBS, val_iauc, total_loss) = _train_val(datasets, i, args)

        all_val_cindex.append(val_cindex)
        all_val_cindex_ipcw.append(val_cindex_ipcw)
        all_val_BS.append(val_BS)
        all_val_IBS.append(val_IBS)
        all_val_iauc.append(val_iauc)
        all_val_loss.append(total_loss)

        #----> write this fold's results to pkl
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
        print("Saving results...")
        _save_pkl(filename, results)

    # Built after the loop: every column must have one entry per fold.
    # Column order is kept as-is so the CSV layout does not change.
    final_df = pd.DataFrame({
        'folds': folds,
        'val_cindex': all_val_cindex,
        'val_cindex_ipcw': all_val_cindex_ipcw,
        'val_IBS': all_val_IBS,
        'val_iauc': all_val_iauc,
        "val_loss": all_val_loss,
        'val_BS': all_val_BS,
    })

    if len(folds) != args.k:
        # Only some of the k folds were run. The file name is derived from the
        # folds actually processed (first and last fold id, both inclusive).
        # The previous code formatted the names `start` and `end`, which are
        # not defined in this function, so a partial run ended in a NameError
        # after all training had finished and no summary was written.
        if len(folds) > 0:
            save_name = 'summary_partial_{}_{}.csv'.format(folds[0], folds[-1])
        else:
            # No fold was selected, so there is no range to put in the name.
            save_name = 'summary_partial.csv'
    else:
        save_name = 'summary.csv'

    final_df.to_csv(os.path.join(args.results_dir, save_name))
if __name__ == "__main__":
    # Wall-clock timer for the whole run; the elapsed time is printed last.
    start = timer()

    #----> read the args and pass them through _prepare_for_experiment,
    #      whose return value becomes the args used from here on
    args = _prepare_for_experiment(_process_args())

    #----> create dataset factory
    # The factory is attached to args because main() reads
    # args.dataset_factory to build the splits of every fold.
    factory_kwargs = {
        'study': args.study,
        'label_file': args.label_file,
        'omics_dir': args.omics_dir,
        'seed': args.seed,
        'print_info': True,
        'n_bins': args.n_classes,
        'label_col': args.label_col,
        'eps': 1e-6,
        'num_patches': args.num_patches,
        # the two flags are derived from the modality argument
        'is_mcat': "coattn" in args.modality,
        'is_survpath': args.modality == "survpath",
        'type_of_pathway': args.type_of_path,
    }
    args.dataset_factory = SurvivalDatasetFactory(**factory_kwargs)

    #---> perform the experiment
    # main() has no return statement, so results is None; the outputs are the
    # files it writes.
    results = main(args)

    #---> stop timer and print
    end = timer()
    print("finished!")
    print("end script")
    print('Script Time: %f seconds' % (end - start))