forked from selimsef/xview2_solution
-
Notifications
You must be signed in to change notification settings - Fork 4
/
ensemble.py
76 lines (59 loc) · 2.18 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
import os
from os import cpu_count
# The MKL / NumExpr / OpenMP thread-count variables are pinned to "1" here, ahead of
# the numpy and cv2 imports below. NOTE(review): presumably this keeps each worker
# process of the multiprocessing Pool single-threaded so that workers do not
# oversubscribe the CPU — confirm intent before reordering these lines.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import skimage
import skimage.io
from multiprocessing.pool import Pool
import numpy as np
import cv2
# NOTE(review): per the OpenCV docs, setNumThreads(0) should disable OpenCV's own
# threading and run its functions sequentially — confirm against the installed version.
cv2.setNumThreads(0)
def average_strategy(images):
    """Return the element-wise mean of ``images`` taken along axis 0.

    ``images`` is anything NumPy can view as an array; axis 0 is the one that
    gets averaged away.
    """
    stacked = np.asanyarray(images)
    return stacked.mean(axis=0)
def hard_voting(images):
    """Combine a stack of masks by majority vote along axis 0.

    Every value is first turned into a 0/1 vote by rounding ``value / 255``.
    The votes are summed along axis 0, divided by ``images.shape[0]``, rounded
    again and scaled back by 255, so the result holds only 0.0 and 255.0.

    ``np.round`` rounds halves to the nearest even value, so an exact 50/50
    split of votes resolves to 0.

    ``images`` must be a NumPy array (it is divided and its ``shape`` is read).
    """
    votes = np.round(images / 255.)
    # Share of inputs that voted "on" for each position.
    positive_share = votes.sum(axis=0) / images.shape[0]
    majority = np.round(positive_share)
    return majority * 255.
def ensemble_image(params):
    """Ensemble one image file across several directories and save the result.

    Args:
        params: a 4-tuple ``(file, dirs, ensembling_dir, strategy)`` packed
            into one argument so the function can be handed to ``Pool.map``:

            * ``file`` - file name looked up inside every directory of ``dirs``
              and reused as the output file name.
            * ``dirs`` - directories that each contain a copy of ``file``.
            * ``ensembling_dir`` - directory the combined image is written to.
            * ``strategy`` - ``'average'`` or ``'hard_voting'``.

    Raises:
        ValueError: if ``strategy`` is not one of the known strategies. This is
            checked before any image is read.
    """
    file, dirs, ensembling_dir, strategy = params

    # Fail fast: there is no point reading every image when the strategy name
    # is wrong and the result could never be computed.
    if strategy not in ('average', 'hard_voting'):
        raise ValueError('Unknown ensembling strategy')

    images = []
    for src_dir in dirs:
        file_path = os.path.join(src_dir, file)
        try:
            image = skimage.io.imread(file_path)
        except Exception:
            # Best-effort single retry of the read; a second failure propagates.
            # Deliberately not a bare ``except`` so that KeyboardInterrupt and
            # SystemExit are never swallowed by the retry.
            image = skimage.io.imread(file_path)
        images.append(image)
    images = np.array(images)

    if strategy == 'average':
        ensembled = average_strategy(images)
    else:
        ensembled = hard_voting(images)

    # The result is cast to uint8 before saving (fractional parts are dropped).
    skimage.io.imsave(os.path.join(ensembling_dir, file), ensembled.astype(np.uint8))
def ensemble(dirs, strategy, ensembling_dir, n_threads, mode="localization"):
    """Ensemble every ``*<mode>.png`` file found in ``dirs[0]`` across ``dirs``.

    The file list is taken from the first directory only; each selected file
    name is then processed by ``ensemble_image`` in a process pool.

    Args:
        dirs: directories to combine; ``dirs[0]`` supplies the file names.
        strategy: strategy name forwarded to ``ensemble_image``.
        ensembling_dir: output directory forwarded to ``ensemble_image``.
        n_threads: number of worker processes in the pool.
        mode: only files whose name ends with ``mode + ".png"`` are processed.
    """
    suffix = mode + ".png"
    params = [
        (file, dirs, ensembling_dir, strategy)
        for file in os.listdir(dirs[0])
        if file.endswith(suffix)
    ]
    if not params:
        # Nothing matched: avoid spawning worker processes for no work.
        return

    # The context manager guarantees the worker processes are torn down even
    # if a task raises; ``map`` blocks until every task has finished, so no
    # work is cut short when the pool is terminated on exit.
    with Pool(n_threads) as pool:
        pool.map(ensemble_image, params)
if __name__ == '__main__':
    # Command-line entry point: resolve the directories to combine relative to
    # --folds_dir, check that they exist, create the output directory and run
    # the ensembling.
    parser = argparse.ArgumentParser("Ensemble masks")
    arg = parser.add_argument
    arg('--ensembling_cpu_threads', type=int, default=cpu_count())
    # The three path arguments are mandatory: without them the code below would
    # fail on ``None`` with an unhelpful TypeError instead of a usage message.
    arg('--ensembling_dir', type=str, required=True)
    # Reject unknown strategies at parse time rather than inside a worker.
    arg('--strategy', type=str, default='average', choices=['average', 'hard_voting'])
    arg('--folds_dir', type=str, required=True)
    arg('--dirs_to_ensemble', nargs='+', required=True)
    arg('--mode', default='localization')
    args = parser.parse_args()

    dirs = [os.path.join(args.folds_dir, d) for d in args.dirs_to_ensemble]
    for d in dirs:
        if not os.path.exists(d):
            raise ValueError(d + " doesn't exist")
    os.makedirs(args.ensembling_dir, exist_ok=True)
    ensemble(dirs, args.strategy, args.ensembling_dir, args.ensembling_cpu_threads, mode=args.mode)