-
Notifications
You must be signed in to change notification settings - Fork 0
/
light_external_ov.py
144 lines (112 loc) · 6.12 KB
/
light_external_ov.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
from datetime import datetime
from pathlib import Path
from warnings import filterwarnings
import lightning.pytorch as pl
import pandas as pd
import torch
import yaml
from tqdm import tqdm
from dataset import TCGA_Program_Dataset
from datasets_manager import TCGA_Balanced_Datasets_Manager, TCGA_Datasets_Manager
from lit_models import LitFullModel
from model import Classifier, Feature_Extractor, Graph_And_Clinical_Feature_Extractor, Task_Classifier
from utils import config_add_subdict_key, get_logger, override_n_genes, set_random_seed, setup_logging
from external_lightningdatamodule import ExternalDataModule
SEED = 1126
set_random_seed(SEED)
def main():
# Select a config file.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help='Path to the config file.', required=True)
args = parser.parse_args()
with open(args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
override_n_genes(config) # For multi-task graph models.
config_name = Path(args.config).stem
# Setup logging.
setup_logging(log_path := f'Logs/{config_name}/{datetime.now():%Y-%m-%dT%H:%M:%S}/')
logger = get_logger(config_name)
logger.info(f'Using Random Seed {SEED} for this experiment')
get_logger('lightning.pytorch.accelerators.cuda', log_level='WARNING') # Disable cuda logging.
filterwarnings('ignore', r'.*Skipping val loop.*') # Disable val loop warning.
# Create dataset manager.
#here use torch lightning DS
data = {'TCGA_BLC': TCGA_Program_Dataset(**config['datasets'])}
if 'TCGA_Balanced_Datasets_Manager' == config['datasets_manager']['type']:
manager_train = TCGA_Balanced_Datasets_Manager(datasets=data, config=config_add_subdict_key(config))
#manager_test = TCGA_Balanced_Datasets_Manager(datasets=external_testing_data, config=config_add_subdict_key(config))
else:
manager_train = TCGA_Datasets_Manager(datasets=data, config=config_add_subdict_key(config))
#manager_test = TCGA_Datasets_Manager(datasets=external_testing_data, config=config_add_subdict_key(config))
# Cross validation, adapt the code for tensting on the external dataset batches
# still need to create he proper shuffling of the data
for key, values in manager_train['TCGA_BLC']['dataloaders'].items():
if isinstance(key, int) and config['cross_validation']:
models, optimizers = create_models_and_optimizers(config)
lit_model = LitFullModel(models, optimizers, config)
trainer = pl.Trainer( # Create sub-folders for each fold.
default_root_dir=log_path,
max_epochs=config['max_epochs'],
log_every_n_steps=1,
enable_model_summary=False,
enable_checkpointing=False,
)
#trainer.fit(lit_model, train_dataloaders=values['train'], val_dataloaders=values['valid']) #the validation is failing
trainer.fit(lit_model, train_dataloaders=values['train'])
#trainer.test(lit_model, dataloaders=test, verbose=True)
elif key == 'train':
train = values
elif key == 'test':
test = manager_train['TCGA_BLC']['dataloaders']['test']
# Train the final model.
models, optimizers = create_models_and_optimizers(config)
lit_model = LitFullModel(models, optimizers, config)
trainer = pl.Trainer(
default_root_dir=log_path,
max_epochs=config['max_epochs'],
enable_progress_bar=False,
log_every_n_steps=1,
logger=False,
)
trainer.fit(lit_model, train_dataloaders=train)
# Test the final model.
external_testing_data = {'TCGA_BLC': TCGA_Program_Dataset(**config['external_datasets']) }
manager_test = TCGA_Balanced_Datasets_Manager(datasets=external_testing_data, config=config_add_subdict_key(config))
test = manager_test['TCGA_BLC']['dataloaders']['test']
# get baseline for AUPRC in the test set
bootstrap_results = []
for _ in tqdm(range(config['bootstrap_repeats']), desc='Bootstrapping'):
bootstrap_results.append(trainer.test(lit_model, dataloaders=test, verbose=False)[0])
bootstrap_results = pd.DataFrame.from_records(bootstrap_results)
for key, value in bootstrap_results.describe().loc[['mean', 'std']].to_dict().items():
logger.info(f'| {key.ljust(10).upper()} | {value["mean"]:.5f} ± {value["std"]:.5f} |')
def create_models_and_optimizers(config: dict):
models: dict[str, torch.nn.Module] = {}
optimizers: dict[str, torch.optim.Optimizer] = {}
# Setup models. Do not use getattr() for better IDE support.
for model_name, kargs in config['models'].items():
if model_name == 'Graph_And_Clinical_Feature_Extractor':
models['feat_ext'] = Graph_And_Clinical_Feature_Extractor(**kargs)
elif model_name == 'Feature_Extractor':
models['feat_ext'] = Feature_Extractor(**kargs)
elif model_name == 'Task_Classifier':
models['clf'] = Task_Classifier(**kargs)
elif model_name == 'Classifier':
models['clf'] = Classifier(**kargs)
else:
raise ValueError(f'Unknown model type: {model_name}')
# Setup optimizers. If the key is 'all', the optimizer will be applied to all models.
for key, optim_dict in config['optimizers'].items():
opt_name = next(iter(optim_dict))
if key == 'all':
params = [param for model in models.values() for param in model.parameters()]
optimizers[key] = getattr(torch.optim, opt_name)(params, **optim_dict[opt_name])
else:
optimizers[key] = getattr(torch.optim, opt_name)(models[key].parameters(), **optim_dict[opt_name])
# Add models' structure to config for logging. TODO: Prettify.
for model_name, torch_model in models.items():
config[f'model.{model_name}'] = str(torch_model)
return models, optimizers
if __name__ == '__main__':
main()