-
Notifications
You must be signed in to change notification settings - Fork 1
/
export.py
131 lines (111 loc) · 6.13 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from copy import deepcopy
from typing import Union, Tuple
import pickle
import SimpleITK as sitk
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from preprocessing import get_lowres_axis, get_do_separate_z, resample_data_or_seg
# Bind ``isfile`` explicitly rather than relying on the wildcard import above.
from os.path import isfile
def save_pickle(obj, file: str, mode: str = 'wb') -> None:
    """Pickle *obj* into the file at path *file*.

    The file is opened with *mode* (binary write by default) and is closed
    again whether ``pickle.dump`` returns normally or raises.
    """
    with open(file, mode) as stream:
        pickle.dump(obj, file=stream)
def _resolve_separate_z(properties_dict: dict, force_separate_z: Union[bool, None]):
    """Decide how the low-resolution axis is handled when resampling.

    Returns a ``(do_separate_z, lowres_axis)`` pair.

    - When ``force_separate_z`` is None, ``get_do_separate_z`` is consulted on
      ``'original_spacing'`` first and on ``'spacing_after_resampling'``
      second. The first spacing that triggers it also supplies
      ``lowres_axis``.
    - Otherwise the forced flag is used and the axis comes from
      ``'original_spacing'``.

    Separate-z is switched off unless exactly one low-resolution axis was found.
    """
    original_spacing = properties_dict.get('original_spacing')
    if force_separate_z is None:
        if get_do_separate_z(original_spacing):
            do_separate_z = True
            lowres_axis = get_lowres_axis(original_spacing)
        elif get_do_separate_z(properties_dict.get('spacing_after_resampling')):
            do_separate_z = True
            lowres_axis = get_lowres_axis(properties_dict.get('spacing_after_resampling'))
        else:
            do_separate_z = False
            lowres_axis = None
    else:
        do_separate_z = force_separate_z
        lowres_axis = get_lowres_axis(original_spacing) if do_separate_z else None
    if lowres_axis is not None and len(lowres_axis) != 1:
        # Fall back to joint resampling when the low-res axis is not unique.
        do_separate_z = False
    return do_separate_z, lowres_axis


def _write_itk_image(array: np.ndarray, properties_dict: dict, fname: str) -> None:
    """Write *array* to *fname* using the ITK spacing/origin/direction in *properties_dict*."""
    image = sitk.GetImageFromArray(array)
    image.SetSpacing(properties_dict['itk_spacing'])
    image.SetOrigin(properties_dict['itk_origin'])
    image.SetDirection(properties_dict['itk_direction'])
    sitk.WriteImage(image, fname)


def save_segmentation_nifti_from_softmax(segmentation_softmax: Union[str, np.ndarray], out_fname: str,
                                         properties_dict: dict, order: int = 1,
                                         region_class_order: Tuple[Tuple[int]] = None,
                                         seg_postprogess_fn: callable = None, seg_postprocess_args: tuple = None,
                                         resampled_npz_fname: str = None,
                                         non_postprocessed_fname: str = None, force_separate_z: bool = None,
                                         interpolation_order_z: int = 0, verbose: bool = True, task_type=None,
                                         primary_task=0, out_softmax_fname=None):
    """Resample a network output back to the original image geometry and export it.

    The output is resampled to the cropped shape and turned into a label map
    (or a regression map). It is then pasted back into the uncropped volume,
    optionally post-processed, and written with SimpleITK.

    Args:
        segmentation_softmax: Channel-first network output, or a path that
            ``np.load`` can read into one. A path is DELETED after loading.
            Assumed shape (C, X, Y, Z); argmax over axis 0 and the
            (1, 2, 3, 0) transpose rely on this.
        out_fname: Destination of the exported (post-processed) map.
        properties_dict: Per-case metadata.
            Keys read: 'size_after_cropping', 'original_size_of_raw_data',
            'original_spacing', 'spacing_after_resampling', 'crop_bbox',
            'itk_spacing', 'itk_origin', 'itk_direction'.
            'regions_class_order' is written into it when both
            ``resampled_npz_fname`` and ``region_class_order`` are given.
        order: Interpolation order used for resampling.
        region_class_order: If given, channel ``i`` is thresholded at 0.5 and
            painted with label ``region_class_order[i]``. Later channels
            overwrite earlier ones.
        seg_postprogess_fn: Optional callable applied to a copy of the
            uncropped map. Its name is kept misspelled for backward
            compatibility.
        seg_postprocess_args: Extra positional arguments for
            ``seg_postprogess_fn``. None means no extra arguments.
        resampled_npz_fname: If given, the resampled output is saved there as
            float16 under key 'softmax'. ``properties_dict`` is pickled next to
            it, with the last four characters of the name replaced by '.pkl'.
        non_postprocessed_fname: If given together with
            ``seg_postprogess_fn``, the map before post-processing is also
            written there.
        force_separate_z: None to decide automatically; otherwise forces
            (True) or disables (False) separate resampling of the
            low-resolution axis.
        interpolation_order_z: Interpolation order along the low-resolution
            axis when resampling it separately.
        verbose: Print progress messages.
        task_type: List of task-type strings. Defaults to ["CLASSIFICATION"].
            If ``task_type[primary_task] == "REGRESSION"``, a float map is
            exported instead of uint8 labels.
        primary_task: Index into ``task_type``.
        out_softmax_fname: If given, the resampled output is written there as
            a float32 channel-last image.

    Returns:
        A copy of the output resampled to the cropped shape. It is taken
        before argmax/thresholding and before uncropping.
    """
    if task_type is None:
        task_type = ["CLASSIFICATION"]
    if seg_postprocess_args is None:
        # Allow a post-processing fn without extra args (`*None` would raise TypeError).
        seg_postprocess_args = ()
    if verbose:
        print("force_separate_z:", force_separate_z, "interpolation order:", order)

    if isinstance(segmentation_softmax, str):
        assert isfile(segmentation_softmax), "If isinstance(segmentation_softmax, str) then " \
                                             "isfile(segmentation_softmax) must be True"
        del_file = segmentation_softmax  # str is immutable; no copy needed
        segmentation_softmax = np.load(segmentation_softmax)
        os.remove(del_file)

    current_shape = segmentation_softmax.shape
    shape_original_after_cropping = properties_dict.get('size_after_cropping')
    shape_original_before_cropping = properties_dict.get('original_size_of_raw_data')

    # Resample only if the spatial shape (axis 0 is channels) differs from the cropped shape.
    if any(i != j for i, j in zip(current_shape[1:], shape_original_after_cropping)):
        do_separate_z, lowres_axis = _resolve_separate_z(properties_dict, force_separate_z)
        if verbose:
            print("separate z:", do_separate_z, "lowres axis", lowres_axis)
        seg_old_spacing = resample_data_or_seg(segmentation_softmax, shape_original_after_cropping, is_seg=False,
                                               axis=lowres_axis, order=order, do_separate_z=do_separate_z,
                                               order_z=interpolation_order_z)
    else:
        if verbose:
            print("no resampling necessary")
        seg_old_spacing = segmentation_softmax

    softmax_resampled = np.copy(seg_old_spacing)

    if resampled_npz_fname is not None:
        np.savez_compressed(resampled_npz_fname, softmax=seg_old_spacing.astype(np.float16))
        if region_class_order is not None:
            properties_dict['regions_class_order'] = region_class_order
        save_pickle(properties_dict, resampled_npz_fname[:-4] + ".pkl")

    regression = task_type[primary_task] == "REGRESSION"
    if regression:
        if seg_old_spacing.shape[0] == 1:
            # Drop only the channel axis. A bare np.squeeze would also drop
            # singleton spatial axes (e.g. one-slice volumes), and the 3-axis
            # bbox paste below would then fail.
            seg_old_spacing = seg_old_spacing[0]
        else:
            seg_old_spacing = np.squeeze(seg_old_spacing)
    elif region_class_order is None:
        seg_old_spacing = seg_old_spacing.argmax(0)
    else:
        seg_old_spacing_final = np.zeros(seg_old_spacing.shape[1:])
        for i, c in enumerate(region_class_order):
            seg_old_spacing_final[seg_old_spacing[i] > 0.5] = c
        seg_old_spacing = seg_old_spacing_final

    bbox = properties_dict.get('crop_bbox')
    if bbox is not None:
        # Paste the cropped result back into a volume of the raw-data size.
        seg_old_size = np.zeros(shape_original_before_cropping)
        # Copy the bounds so the caller's 'crop_bbox' is not overwritten
        # (and so tuple-valued bounds can be clamped).
        bbox = [list(axis_bounds) for axis_bounds in bbox]
        for c in range(3):
            bbox[c][1] = np.min((bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c]))
        seg_old_size[bbox[0][0]:bbox[0][1],
                     bbox[1][0]:bbox[1][1],
                     bbox[2][0]:bbox[2][1]] = seg_old_spacing
    else:
        seg_old_size = seg_old_spacing

    if seg_postprogess_fn is not None:
        seg_old_size_postprocessed = seg_postprogess_fn(np.copy(seg_old_size), *seg_postprocess_args)
    else:
        seg_old_size_postprocessed = seg_old_size

    # np.float (removed in NumPy 1.24) was an alias of the builtin float, i.e. float64.
    out_dtype = np.float64 if regression else np.uint8
    _write_itk_image(seg_old_size_postprocessed.astype(out_dtype), properties_dict, out_fname)

    if (non_postprocessed_fname is not None) and (seg_postprogess_fn is not None):
        # Same dtype as the main export so regression values are not truncated to uint8.
        _write_itk_image(seg_old_size.astype(out_dtype), properties_dict, non_postprocessed_fname)

    if out_softmax_fname is not None:
        # NOTE(review): softmax_resampled is in the cropped shape but is written
        # with the same itk_origin as the uncropped export. If cropping removed
        # voxels, the geometry is presumably offset. TODO confirm.
        _write_itk_image(np.transpose(softmax_resampled, (1, 2, 3, 0)).astype(np.float32),
                         properties_dict, out_softmax_fname)

    return softmax_resampled