-
Notifications
You must be signed in to change notification settings - Fork 7
/
pipeline.py
160 lines (100 loc) · 5.9 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# SYSTEM
import os
import sys
import logging
# EXTERNALS
import yaml
# UTILS
from utils.pipeline_utils import parse_args
# CLASSIFICATION PIPELINE IMPORTS
sys.path.append('.')
from MetricLearning.src.preprocess_with_dir import preprocess
from MetricLearning.src.metric_learning_adjacent import build_graphs as build_doublet_graphs
from GraphLearning import build_triplets as build_triplet_graphs
from Seeding import seed
from Labelling import label
# TRAINING PIPELINE IMPORTS
from MetricLearning.src.metric_learning_adjacent.preprocess import preprocess_stage_1, preprocess_stage_2
from MetricLearning.src.metric_learning_adjacent.train_embed import train_embed
from MetricLearning.src.metric_learning_adjacent.train_filter import train_filter
from GraphLearning import train as train_gnn
def classify(args):
    """Run the inference (classification) pipeline up to ``args.stage``.

    Stages execute in order: preprocess -> build_doublets -> build_triplets,
    then either seed or label depending on the requested stage.  Each stage
    checks for existing output on disk and is only re-run when forced.
    ``args.force`` names the earliest stage whose data must be rebuilt; that
    stage and every stage after it are then rebuilt.
    """
    # Translate args.force into one boolean per stage: forcing a stage also
    # forces every downstream stage, since their inputs become stale.
    stage_index = {"all": 0, "preprocess": 1, "build_doublets": 2, "build_triplets": 3, "seed": 4, "label": 5}
    force = [False] * len(stage_index)
    if args.force is not None:
        for i in range(stage_index[args.force], len(force)):
            force[i] = True

    # BUILD / INFERENCE PIPELINE
    # -------------------------------

    ## PREPROCESS DATA - Calculate cell features, select barrel, no noise, construct event-by-event files
    # preprocess.main checks for an existing preprocessed dataset unless forced.
    print('--------------------- \n Preprocessing... \n--------------------')
    args.data_storage_path = os.path.join(args.data_storage_path, "classify", args.name)
    preprocess_data_path, feature_names = preprocess.main(args, force=force[stage_index["preprocess"]])
    if args.stage == 'preprocess':
        return

    ## BUILD DOUBLET GRAPHS - Apply embedding, run filter inference, post-process, construct event doublet files
    # Checks for existence of doublet graphs.
    print('--------------------- \n Building doublets... \n--------------------')
    build_doublet_graphs.main(args, force=force[stage_index["build_doublets"]])
    if args.stage == 'build_doublets':
        return

    ## BUILD TRIPLET GRAPHS - Apply doublet GNN model to doublet graphs for classification, construct triplets using score & cut
    # Checks for existence of triplet graphs
    print('--------------------- \n Building triplets... \n--------------------')
    build_triplet_graphs.main(args, force=force[stage_index["build_triplets"]])
    if args.stage == 'build_triplets':
        return

    ## CLASSIFY TRIPLET GRAPHS FOR SEEDS - Classify, apply cut to triplet scores, print out into specified format
    # Checks for existence of seed dataset
    if args.stage == "seed":
        print('--------------------- \n Seeding... \n--------------------')
        seed.main(args, force=force[stage_index["seed"]])
        return

    ## CLASSIFY TRIPLET GRAPHS FOR TRACK LABELS - Classify, convert to doublets, construct sparse undirected matrix, run DBSCAN, print out into specified format
    if args.stage == "label":
        print('--------------------- \n Labelling... \n--------------------')
        label.main(args, force=force[stage_index["label"]])
        return
def train(args):
    """Run the training pipeline up to ``args.stage``.

    Stages execute in order: preprocess -> train_embedding -> train_filter ->
    train_doublets -> train_triplets.  ``args.force`` names the earliest stage
    whose data/model must be rebuilt; that stage and every later stage are
    then re-run even if their outputs already exist.
    """
    # Handle the logic flow of forced data re-building (i.e. ignoring existing data):
    # forcing a stage forces every downstream stage as well.
    force_order = {"all": 0, "preprocess": 1, "train_embedding": 2, "train_filter": 3, "train_doublets": 4, "train_triplets": 5}
    force = [False]*len(force_order)
    if args.force is not None:
        force[force_order[args.force]:] = [True]*(len(force_order)-force_order[args.force])

    # TRAIN PIPELINE
    # ------------------------------

    ## PREPROCESS DATA
    print('--------------------- \n Preprocessing... \n--------------------')
    args.data_storage_path, args.artifact_storage_path = os.path.join(args.data_storage_path, "train", args.name), os.path.join(args.data_storage_path, "artifacts", args.name)
    preprocess_data_path, feature_names = preprocess.main(args, force=force[force_order["preprocess"]])

    print('--------------------- \n Training embedding... \n--------------------')
    ## BUILD EMBEDDING DATA
    preprocess_stage_1.main(args, force=force[force_order["train_embedding"]])
    ## TRAIN EMBEDDING
    train_embed.main(args, force=force[force_order["train_embedding"]])
    if args.stage == 'train_embedding':
        return

    print('--------------------- \n Training filter... \n--------------------')
    ## BUILD FILTERING DATA
    preprocess_stage_2.main(args, force=force[force_order["train_filter"]])
    ## TRAIN FILTERING
    train_filter.main(args, force=force[force_order["train_filter"]])
    # BUG FIX: this previously compared against 'train_filtering', a stage name
    # that exists nowhere else (force_order above and the __main__ dispatch both
    # use 'train_filter'), so requesting stage 'train_filter' never stopped here
    # and training wrongly continued into the doublet/triplet stages.
    if args.stage == 'train_filter':
        return

    print('--------------------- \n Training doublets... \n--------------------')
    ## BUILD DOUBLET GRAPHS
    build_doublet_graphs.main(args, force=force[force_order["train_doublets"]])
    ## TRAIN DOUBLET GNN
    train_gnn.main(args, force=force[force_order["train_doublets"]], gnn='doublet')
    if args.stage == 'train_doublets':
        return

    print('--------------------- \n Training triplets... \n--------------------')
    ## BUILD TRIPLET GRAPHS
    build_triplet_graphs.main(args, force=force[force_order["train_triplets"]])
    ## TRAIN TRIPLET GNN
    train_gnn.main(args, force=force[force_order["train_triplets"]], gnn='triplet')
if __name__ == "__main__":
    args = parse_args()

    # Define verbosity for whole pipeline
    log_format = '%(asctime)s %(levelname)s %(message)s'
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level, format=log_format)

    # Dispatch on the requested stage: inference stages run classify(),
    # training stages run train().
    if args.stage in ["seed", "label", "preprocess", "build_doublets", "build_triplets"]:
        classify(args)
    elif args.stage in ["train", "train_embedding", "train_filter", "train_doublets", "train_triplets"]:
        train(args)
    else:
        # Previously an unrecognised stage fell through and the script exited
        # silently having done nothing; fail loudly instead.
        sys.exit("Unknown stage: %r" % (args.stage,))