diff --git a/README.md b/README.md index 29079a5..906d251 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ # TotalSpineSeg -TotalSpineSeg is a tool for automatic instance segmentation of all vertebrae, intervertebral discs (IVDs), spinal cord, and spinal canal in MRI images. It is robust to various MRI contrasts, acquisition orientations, and resolutions. The model used in TotalSpineSeg is based on [nnUNet](https://github.com/MIC-DKFZ/nnUNet) as the backbone for training and inference. +TotalSpineSeg is a tool for automatic instance segmentation of all vertebrae, intervertebral discs (IVDs), spinal cord, and spinal canal in MRI images. It is robust to various MRI contrasts, acquisition orientations, and resolutions. The model used in TotalSpineSeg is based on [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) as the backbone for training and inference. If you use this model, please cite our work: > Warszawer Y, Molinier N, Valošek J, Shirbint E, Benveniste PL, Achiron A, Eshaghi A and Cohen-Adad J. _Fully Automatic Vertebrae and Spinal Cord Segmentation Using a Hybrid Approach Combining nnU-Net and Iterative Algorithm_. Proceedings of the 32th Annual Meeting of ISMRM. 2024 -Please also cite nnUNet since our work is heavily based on it: +Please also cite nnU-Net since our work is heavily based on it: > Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211. ![Thumbnail](https://github.com/user-attachments/assets/2c1b1ff9-daaa-479f-8d21-01a66b9c9cb4) @@ -151,11 +151,13 @@ Please ensure that your system meets these requirements before proceeding with t 1. Run the model on a folder containing the images in .nii.gz format, or on a single .nii.gz file: ```bash - totalspineseg INPUT OUTPUT_FOLDER [--step1] + totalspineseg INPUT OUTPUT_FOLDER [--step1] [--iso] ``` This will process the images in INPUT or the single image and save the results in OUTPUT_FOLDER. If you haven't trained the model, the script will automatically download the pre-trained models from the GitHub release. + **Important Note:** By default, the output segmentations are resampled back to the input image space. If you prefer to obtain the outputs in the model's original 1mm isotropic resolution, especially useful for visualization purposes, we strongly recommend using the `--iso` argument. + Additionally, you can use the `--step1` parameter to run only the step 1 model, which outputs a single label for all vertebrae, including the sacrum. For more options, you can use the `--help` parameter: @@ -209,8 +211,8 @@ In this example, main images are placed in the `images` folder and corresponding To use localizer-based labeling: ```bash -# Process localizer images -totalspineseg localizers localizers_output +# Process localizer images. We recommend using the --iso flag for the localizer to ensure consistent resolution. +totalspineseg localizers localizers_output --iso # Run model on main images using localizer output totalspineseg images output --loc localizers_output/step2_output --suffix _T2w --loc-suffix _T1w diff --git a/totalspineseg/inference.py b/totalspineseg/inference.py index acf214a..4956491 100644 --- a/totalspineseg/inference.py +++ b/totalspineseg/inference.py @@ -35,6 +35,10 @@ def main(): 'output', type=Path, help='The output folder where the model outputs will be stored.' ) + parser.add_argument( + '--iso', action="store_true", default=False, + help='Use isotropic output as output by the model instead of resampling output to the input, defaults to false.' + ) parser.add_argument( '--loc', '-l', type=Path, default=None, help=' '.join(f''' @@ -87,6 +91,7 @@ def main(): # Get the command-line argument values input_path = args.input output_path = args.output + output_iso = args.iso loc_path = args.loc suffix = args.suffix loc_suffix = args.loc_suffix @@ -146,6 +151,7 @@ def main(): Running TotalSpineSeg with the following parameters: input = "{input_path}" output = "{output_path}" + iso = {output_iso} loc = "{loc_path}" suffix = {suffix} loc_suffix = "{loc_suffix}" @@ -698,5 +704,61 @@ def main(): }, ) + if not output_iso: + if not quiet: print('\n' 'Resampling step1_output to the input images space:') + transform_seg2image_mp( + input_path, + output_path / 'step1_output', + output_path / 'step1_output', + image_suffix = '', + overwrite=True, + max_workers=max_workers, + quiet=quiet, + ) + if not quiet: print('\n' 'Resampling step1_cord to the input images space:') + transform_seg2image_mp( + input_path, + output_path / 'step1_cord', + output_path / 'step1_cord', + image_suffix = '', + interpolation = 'linear', + overwrite=True, + max_workers=max_workers, + quiet=quiet, + ) + if not quiet: print('\n' 'Resampling step1_canal to the input images space:') + transform_seg2image_mp( + input_path, + output_path / 'step1_canal', + output_path / 'step1_canal', + image_suffix = '', + interpolation = 'linear', + overwrite=True, + max_workers=max_workers, + quiet=quiet, + ) + if not quiet: print('\n' 'Resampling step1_levels to the input images space:') + transform_seg2image_mp( + input_path, + output_path / 'step1_levels', + output_path / 'step1_levels', + image_suffix = '', + interpolation = 'label', + overwrite=True, + max_workers=max_workers, + quiet=quiet, + ) + if not step1_only: + if not quiet: print('\n' 'Resampling step2_output to the input images space:') + transform_seg2image_mp( + input_path, + output_path / 'step2_output', + output_path / 'step2_output', + image_suffix = '', + overwrite=True, + max_workers=max_workers, + quiet=quiet, + ) + if __name__ == '__main__': main() \ No newline at end of file diff --git a/totalspineseg/utils/transform_seg2image.py b/totalspineseg/utils/transform_seg2image.py index 9932b6f..7914cd3 100644 --- a/totalspineseg/utils/transform_seg2image.py +++ b/totalspineseg/utils/transform_seg2image.py @@ -195,11 +195,18 @@ def _transform_seg2image( ) # Ensure correct segmentation dtype, affine and header - output_seg = nib.Nifti1Image( - np.asanyarray(output_seg.dataobj).round().astype(np.uint8), - output_seg.affine, output_seg.header - ) - output_seg.set_data_dtype(np.uint8) + if interpolation == 'linear': + output_seg = nib.Nifti1Image( + np.asanyarray(output_seg.dataobj).astype(np.float32), + output_seg.affine, output_seg.header + ) + output_seg.set_data_dtype(np.float32) + else: + output_seg = nib.Nifti1Image( + np.asanyarray(output_seg.dataobj).round().astype(np.uint8), + output_seg.affine, output_seg.header + ) + output_seg.set_data_dtype(np.uint8) output_seg.set_qform(output_seg.affine) output_seg.set_sform(output_seg.affine) @@ -231,7 +238,11 @@ def transform_seg2image( ''' image_data = np.asanyarray(image.dataobj).astype(np.float64) image_affine = image.affine.copy() - seg_data = np.asanyarray(seg.dataobj).round().astype(np.uint8) + seg_data = np.asanyarray(seg.dataobj) + if interpolation == 'linear': + seg_data = seg_data.astype(np.float32) + else: + seg_data = seg_data.round().astype(np.uint8) seg_affine = seg.affine.copy() # Dilations size - the maximum of factor by which the image zooms are larger than the segmentation zooms @@ -261,7 +272,11 @@ def transform_seg2image( # Resample the segmentation to the image space tio_output_seg = tio.Resample(tio_img)(tio_seg) - output_seg_data = tio_output_seg.data.numpy()[0, ...].astype(np.uint8) + output_seg_data = tio_output_seg.data.numpy()[0, ...] + if interpolation == 'linear': + output_seg_data = output_seg_data.astype(np.float32) + else: + output_seg_data = output_seg_data.round().astype(np.uint8) if interpolation == 'label': # Initialize the output segmentation to zeros