Skip to content

Commit

Permalink
ensuring data consistency between output images and original header d…
Browse files Browse the repository at this point in the history
…ata types
  • Loading branch information
yw7 committed Jun 4, 2024
1 parent c3fa11b commit 8c51f37
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
22 changes: 20 additions & 2 deletions src/totalsegmri/utils/generate_resampled_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ def generate_resampled_images(
):

image = nib.load(image_path)

# Get the data type of the image
image_data_dtype = np.asanyarray(image.dataobj).dtype
image_header_dtype = getattr(np, image.get_data_dtype().name)

image_data = image.get_fdata().astype(np.float64)

if segs_path:
Expand Down Expand Up @@ -204,13 +209,26 @@ def generate_resampled_images(
))
output_image_data = subject.image.data.numpy()[0, ...].astype(np.float64)

# Rescale the image to the output data type if necessary
# code from https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/6.3/spinalcordtoolbox/image.py#L1217
if "int" in np.dtype(image_header_dtype).name:
# get min/max from output type
min_out = np.iinfo(image_header_dtype).min
max_out = np.iinfo(image_header_dtype).max
min_in = output_image_data.min()
max_in = output_image_data.max()
if (min_in < min_out) or (max_in > max_out):
data_rescaled = output_image_data * (max_out - min_out) / (max_in - min_in)
output_image_data = data_rescaled - (data_rescaled.min() - min_out)

output_image_path = output_images_path / image_path.relative_to(images_path).parent / image_path.name.replace(f'{image_suffix}.nii.gz', f'{output_image_suffix}.nii.gz')

# Make sure output directory exists and save with original image dtype
# Make sure output directory exists and save with original header image dtype
output_image_path.parent.mkdir(parents=True, exist_ok=True)
output_image = nib.Nifti1Image(output_image_data.astype(np.asanyarray(image.dataobj).dtype), subject.image.affine, image.header)
output_image = nib.Nifti1Image(output_image_data.astype(image_header_dtype), subject.image.affine, image.header)
output_image.set_qform(subject.image.affine)
output_image.set_sform(subject.image.affine)
output_image.set_data_dtype(image_header_dtype)
nib.save(output_image, output_image_path)

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion src/totalsegmri/utils/transform_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def transform_norm(

# Rescale the image to the output data type if necessary
# code from https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/6.3/spinalcordtoolbox/image.py#L1217
if image_header_dtype != image_data_dtype and "int" in np.dtype(image_header_dtype).name:
if "int" in np.dtype(image_header_dtype).name:
# get min/max from output type
min_out = np.iinfo(image_header_dtype).min
max_out = np.iinfo(image_header_dtype).max
Expand Down

0 comments on commit 8c51f37

Please sign in to comment.