diff --git a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml
index 38080f1a..250f18d6 100644
--- a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml
+++ b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml
@@ -4,7 +4,7 @@ trainer:
   init_args:
     n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
     default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d}
-    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
+    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000}
     callbacks:
       - class_path: eva.callbacks.ConfigurationLogger
       - class_path: lightning.pytorch.callbacks.TQDMProgressBar
@@ -54,7 +54,7 @@ model:
         class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
         init_args:
           in_features: ${oc.env:IN_FEATURES, 384}
-          num_classes: &NUM_CLASSES 118
+          num_classes: &NUM_CLASSES 37
     criterion:
       class_path: eva.vision.losses.DiceLoss
       init_args:
@@ -142,6 +142,7 @@ data:
     val:
       batch_size: *BATCH_SIZE
       num_workers: *N_DATA_WORKERS
+      shuffle: true
     test:
       batch_size: *BATCH_SIZE
       num_workers: *N_DATA_WORKERS
diff --git a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml
index 2671ec40..8f584f50 100644
--- a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml
+++ b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml
@@ -4,7 +4,7 @@ trainer:
   init_args:
     n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d}
-    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
+    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000}
     callbacks:
       - class_path: eva.callbacks.ConfigurationLogger
       - class_path: lightning.pytorch.callbacks.TQDMProgressBar
@@ -44,10 +44,10 @@ model:
         out_indices: ${oc.env:OUT_INDICES, 1}
         model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
     decoder:
-      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
+      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage
       init_args:
         in_features: ${oc.env:IN_FEATURES, 384}
-        num_classes: &NUM_CLASSES 118
+        num_classes: &NUM_CLASSES 37
     criterion:
       class_path: eva.vision.losses.DiceLoss
       init_args:
@@ -120,6 +120,7 @@ data:
     val:
       batch_size: *BATCH_SIZE
       num_workers: *N_DATA_WORKERS
+      shuffle: true
     test:
       batch_size: *BATCH_SIZE
       num_workers: *N_DATA_WORKERS
diff --git a/src/eva/vision/data/datasets/segmentation/_total_segmentator.py b/src/eva/vision/data/datasets/segmentation/_total_segmentator.py
new file mode 100644
index 00000000..55ff0766
--- /dev/null
+++ b/src/eva/vision/data/datasets/segmentation/_total_segmentator.py
@@ -0,0 +1,91 @@
+"""Utils for TotalSegmentator dataset classes."""
+
+from typing import Dict
+
+reduced_class_mappings: Dict[str, str] = {
+    # Abdominal Organs
+    "spleen": "spleen",
+    "kidney_right": "kidney",
+    "kidney_left": "kidney",
+    "gallbladder": "gallbladder",
+    "liver": "liver",
+    "stomach": "stomach",
+    "pancreas": "pancreas",
+    "small_bowel": "small_bowel",
+    "duodenum": "duodenum",
+    "colon": "colon",
+    # Endocrine System
+    "adrenal_gland_right": "adrenal_gland",
+    "adrenal_gland_left": "adrenal_gland",
+    "thyroid_gland": "thyroid_gland",
+    # Respiratory System
+    "lung_upper_lobe_left": "lungs",
+    "lung_lower_lobe_left": "lungs",
+    "lung_upper_lobe_right": "lungs",
+    "lung_middle_lobe_right": "lungs",
+    "lung_lower_lobe_right": "lungs",
+    "trachea": "trachea",
+    "esophagus": "esophagus",
+    # Urogenital System
+    "urinary_bladder": "urogenital_system",
+    "prostate": "urogenital_system",
+    "kidney_cyst_left": "kidney_cyst",
+    "kidney_cyst_right": "kidney_cyst",
+    # Vertebral Column
+    **{f"vertebrae_{v}": "vertebrae" for v in ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]},
+    **{f"vertebrae_{v}": "vertebrae" for v in [f"T{i}" for i in range(1, 13)]},
+    **{f"vertebrae_{v}": "vertebrae" for v in [f"L{i}" for i in range(1, 6)]},
+    "vertebrae_S1": "vertebrae",
+    "sacrum": "sacral_spine",
+    # Cardiovascular System
+    "heart": "heart",
+    "aorta": "aorta",
+    "pulmonary_vein": "veins",
+    "brachiocephalic_trunk": "arteries",
+    "subclavian_artery_right": "arteries",
+    "subclavian_artery_left": "arteries",
+    "common_carotid_artery_right": "arteries",
+    "common_carotid_artery_left": "arteries",
+    "brachiocephalic_vein_left": "veins",
+    "brachiocephalic_vein_right": "veins",
+    "atrial_appendage_left": "atrial_appendage",
+    "superior_vena_cava": "veins",
+    "inferior_vena_cava": "veins",
+    "portal_vein_and_splenic_vein": "veins",
+    "iliac_artery_left": "arteries",
+    "iliac_artery_right": "arteries",
+    "iliac_vena_left": "veins",
+    "iliac_vena_right": "veins",
+    # Upper Extremity Bones
+    "humerus_left": "humerus",
+    "humerus_right": "humerus",
+    "scapula_left": "scapula",
+    "scapula_right": "scapula",
+    "clavicula_left": "clavicula",
+    "clavicula_right": "clavicula",
+    # Lower Extremity Bones
+    "femur_left": "femur",
+    "femur_right": "femur",
+    "hip_left": "hip",
+    "hip_right": "hip",
+    # Muscles
+    "gluteus_maximus_left": "gluteus",
+    "gluteus_maximus_right": "gluteus",
+    "gluteus_medius_left": "gluteus",
+    "gluteus_medius_right": "gluteus",
+    "gluteus_minimus_left": "gluteus",
+    "gluteus_minimus_right": "gluteus",
+    "autochthon_left": "autochthon",
+    "autochthon_right": "autochthon",
+    "iliopsoas_left": "iliopsoas",
+    "iliopsoas_right": "iliopsoas",
+    # Central Nervous System
+    "brain": "brain",
+    "spinal_cord": "spinal_cord",
+    # Skull and Thoracic Cage
+    "skull": "skull",
+    **{f"rib_left_{i}": "ribs" for i in range(1, 13)},
+    **{f"rib_right_{i}": "ribs" for i in range(1, 13)},
+    "costal_cartilages": "ribs",
+    "sternum": "sternum",
+}
diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
index fefba203..53cc0c5f 100644
--- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
+++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -1,7 +1,9 @@
 """TotalSegmentator 2D segmentation dataset class."""
 
 import functools
+import hashlib
 import os
+import re
 from glob import glob
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Tuple
@@ -16,7 +18,7 @@
 from eva.core.utils import io as core_io
 from eva.core.utils import multiprocessing
 from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.segmentation import base
+from eva.vision.data.datasets.segmentation import _total_segmentator, base
 from eva.vision.utils import io
@@ -66,6 +68,7 @@ def __init__(
         version: Literal["small", "full"] | None = "full",
         download: bool = False,
         classes: List[str] | None = None,
+        class_mappings: Dict[str, str] | None = _total_segmentator.reduced_class_mappings,
         optimize_mask_loading: bool = True,
         decompress: bool = True,
         num_workers: int = 10,
@@ -85,6 +88,8 @@
                 exist yet on disk.
             classes: Whether to configure the dataset with a subset of classes.
                 If `None`, it will use all of them.
+            class_mappings: A dictionary that maps the original class names to a
+                reduced set of classes. If `None`, it will use the original classes.
             optimize_mask_loading: Whether to pre-process the segmentation masks
                 in order to optimize the loading time. In the `setup` method, it
                 will reformat the binary one-hot masks to a semantic mask and store
@@ -109,11 +114,10 @@
         self._optimize_mask_loading = optimize_mask_loading
         self._decompress = decompress
         self._num_workers = num_workers
+        self._class_mappings = class_mappings
 
-        if self._optimize_mask_loading and self._classes is not None:
-            raise ValueError(
-                "To use customize classes please set the optimize_mask_loading to `False`."
-            )
+        if self._classes and self._class_mappings:
+            raise ValueError("'classes' and 'class_mappings' cannot both be set at the same time.")
 
         self._samples_dirs: List[str] = []
         self._indices: List[Tuple[int, int]] = []
@@ -125,16 +129,21 @@ def get_filename(path: str) -> str:
             """Returns the filename from the full path."""
             return os.path.basename(path).split(".")[0]
 
-        first_sample_labels = os.path.join(
-            self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
-        )
+        first_sample_labels = os.path.join(self._root, "s0011", "segmentations", "*.nii.gz")
         all_classes = sorted(map(get_filename, glob(first_sample_labels)))
         if self._classes:
             is_subset = all(name in all_classes for name in self._classes)
             if not is_subset:
-                raise ValueError("Provided class names are not subset of the dataset onces.")
-
-        return all_classes if self._classes is None else self._classes
+                raise ValueError("Provided class names are not a subset of the original ones.")
+            classes = sorted(self._classes)
+        elif self._class_mappings:
+            is_subset = all(name in all_classes for name in self._class_mappings.keys())
+            if not is_subset:
+                raise ValueError("Provided class names are not a subset of the original ones.")
+            classes = sorted(set(self._class_mappings.values()))
+        else:
+            classes = all_classes
+        return ["background"] + classes
 
     @property
     @override
     def class_to_idx(self) -> Dict[str, int]:
@@ -145,6 +154,10 @@
     def _file_suffix(self) -> str:
         return "nii" if self._decompress else "nii.gz"
 
+    @functools.cached_property
+    def _classes_hash(self) -> str:
+        return hashlib.md5(str(self.classes).encode(), usedforsecurity=False).hexdigest()
+
     @override
     def filename(self, index: int) -> str:
         sample_idx, _ = self._indices[index]
@@ -170,15 +183,22 @@ def validate(self) -> None:
         if self._version is None or self._sample_every_n_slices is not None:
             return
 
+        if self._classes:
+            last_label = self._classes[-1]
+            n_classes = len(self._classes)
+        elif self._class_mappings:
+            classes = sorted(set(self._class_mappings.values()))
+            last_label = classes[-1]
+            n_classes = len(classes)
+        else:
+            last_label = "vertebrae_T9"
+            n_classes = 117
+
         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
-            n_classes=len(self._classes) if self._classes else 117,
-            first_and_last_labels=(
-                (self._classes[0], self._classes[-1])
-                if self._classes
-                else ("adrenal_gland_left", "vertebrae_T9")
-            ),
+            n_classes=n_classes + 1,
+            first_and_last_labels=("background", last_label),
         )
 
     @override
@@ -190,32 +210,31 @@ def load_image(self, index: int) -> tv_tensors.Image:
         sample_index, slice_index = self._indices[index]
         image_path = self._get_image_path(sample_index)
         image_array = io.read_nifti(image_path, slice_index)
-        image_rgb_array = image_array.repeat(3, axis=2)
-        return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
+        image_array = self._fix_orientation(image_array)
+        return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))
 
     @override
     def load_mask(self, index: int) -> tv_tensors.Mask:
         if self._optimize_mask_loading:
-            return self._load_semantic_label_mask(index)
-        return self._load_mask(index)
+            mask = self._load_semantic_label_mask(index)
+        else:
+            mask = self._load_mask(index)
+        mask = self._fix_orientation(mask)
+        return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64)  # type: ignore
 
     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
         _, slice_index = self._indices[index]
         return {"slice_index": slice_index}
 
-    def _load_mask(self, index: int) -> tv_tensors.Mask:
+    def _load_mask(self, index: int) -> npt.NDArray[Any]:
         sample_index, slice_index = self._indices[index]
-        semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
-        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+        return self._load_masks_as_semantic_label(sample_index, slice_index)
 
-    def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
+    def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
         """Loads the segmentation mask from a semantic label NifTi file."""
         sample_index, slice_index = self._indices[index]
-        masks_dir = self._get_masks_dir(sample_index)
-        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
-        semantic_labels = io.read_nifti(filename, slice_index)
-        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+        return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
 
     def _load_masks_as_semantic_label(
         self, sample_index: int, slice_index: int | None = None
@@ -227,18 +246,39 @@
             slice_index: Whether to return only a specific slice.
         """
         masks_dir = self._get_masks_dir(sample_index)
-        mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes]
+        classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
+        mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
         binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
+
+        if self._class_mappings:
+            mapped_binary_masks = [
+                np.zeros_like(binary_masks[0], dtype=np.bool_) for _ in self.classes[1:]
+            ]
+            for original_index, mapped_class in enumerate(self._class_mappings.values()):
+                mapped_index = self.class_to_idx[mapped_class] - 1
+                mapped_binary_masks[mapped_index] = np.logical_or(
+                    mapped_binary_masks[mapped_index], binary_masks[original_index]
+                )
+            binary_masks = mapped_binary_masks
+
         background_mask = np.zeros_like(binary_masks[0])
         return np.argmax([background_mask] + binary_masks, axis=0)
 
     def _export_semantic_label_masks(self) -> None:
         """Exports the segmentation binary masks (one-hot) to semantic labels."""
+        mask_classes_file = os.path.join(self._get_optimized_masks_root(), "classes.txt")
+        if os.path.isfile(mask_classes_file):
+            with open(mask_classes_file, "r") as file:
+                if file.read() != str(self.classes):
+                    raise ValueError(
+                        "Existing optimized masks do not match the current classes or mappings."
+                    )
+            return
+
         total_samples = len(self._samples_dirs)
-        masks_dirs = map(self._get_masks_dir, range(total_samples))
         semantic_labels = [
-            (index, os.path.join(directory, "semantic_labels", "masks.nii"))
-            for index, directory in enumerate(masks_dirs)
+            (index, self._get_optimized_masks_file(index)) for index in range(total_samples)
         ]
         to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
 
@@ -255,6 +295,16 @@ def _process_mask(sample_index: Any, filename: str) -> None:
             return_results=False,
         )
 
+        os.makedirs(os.path.dirname(mask_classes_file), exist_ok=True)
+        with open(mask_classes_file, "w") as file:
+            file.write(str(self.classes))
+
+    def _fix_orientation(self, array: npt.NDArray) -> npt.NDArray:
+        """Fixes the orientation so that the table is at the bottom and the liver on the left."""
+        array = np.rot90(array)
+        array = np.flip(array, axis=1)
+        return array
+
     def _get_image_path(self, sample_index: int) -> str:
         """Returns the corresponding image path."""
         sample_dir = self._samples_dirs[sample_index]
@@ -265,10 +315,15 @@ def _get_masks_dir(self, sample_index: int) -> str:
         sample_dir = self._samples_dirs[sample_index]
         return os.path.join(self._root, sample_dir, "segmentations")
 
-    def _get_semantic_labels_filename(self, sample_index: int) -> str:
+    def _get_optimized_masks_root(self) -> str:
+        """Returns the directory of the optimized masks."""
+        return os.path.join(self._root, "processed", "masks", self._classes_hash)
+
+    def _get_optimized_masks_file(self, sample_index: int) -> str:
         """Returns the semantic label filename."""
-        masks_dir = self._get_masks_dir(sample_index)
-        return os.path.join(masks_dir, "semantic_labels", "masks.nii")
+        return os.path.join(
+            self._get_optimized_masks_root(), self._samples_dirs[sample_index], "masks.nii"
+        )
 
     def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
         """Returns the total amount of slices of a sample."""
@@ -281,7 +336,7 @@ def _fetch_samples_dirs(self) -> List[str]:
         sample_filenames = [
             filename
             for filename in os.listdir(self._root)
-            if os.path.isdir(os.path.join(self._root, filename))
+            if os.path.isdir(os.path.join(self._root, filename)) and re.match(r"^s\d{4}$", filename)
         ]
         return sorted(sample_filenames)
diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py
index 06baffdd..3a86920b 100644
--- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py
+++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py
@@ -2,17 +2,29 @@
 
 import os
 import shutil
-from typing import Literal
+from typing import Dict, Literal
+from unittest.mock import patch
 
 import pytest
 from torchvision import tv_tensors
 
 from eva.vision.data import datasets
 
+_class_mappings = {
+    "aorta_small": "class_1",
+    "brain_small": "class_1",
+    "colon_small": "class_2",
+}
+
 
 @pytest.mark.parametrize(
-    "split, expected_length",
-    [("train", 6), ("val", 3), (None, 9)],
+    "split, expected_length, class_mappings, optimize_mask_loading",
+    [
+        ("train", 6, None, False),
+        ("train", 6, None, True),
+        ("val", 3, None, False),
+        (None, 9, None, False),
+    ],
 )
 def test_length(
     total_segmentator_dataset: datasets.TotalSegmentator2D, expected_length: int
@@ -22,11 +34,13 @@
 
 @pytest.mark.parametrize(
-    "split, index",
+    "split, index, class_mappings, optimize_mask_loading",
     [
-        (None, 0),
-        ("train", 0),
-        ("val", 0),
+        (None, 0, None, False),
+        (None, 0, None, True),
+        ("train", 0, None, False),
+        ("val", 0, None, False),
+        ("train", 0, _class_mappings, False),
     ],
 )
 def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: int) -> None:
@@ -38,16 +52,62 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i
     # assert the format of the `image` and `mask`
     image, mask, metadata = sample
     assert isinstance(image, tv_tensors.Image)
-    assert image.shape == (3, 16, 16)
+    assert image.shape == (1, 16, 16)
     assert isinstance(mask, tv_tensors.Mask)
     assert mask.shape == (16, 16)
     assert isinstance(metadata, dict)
     assert "slice_index" in metadata
 
+    # check the number of classes with vs. without class mappings
+    n_classes_expected = 3 if total_segmentator_dataset._class_mappings is not None else 4
+    assert len(total_segmentator_dataset.classes) == n_classes_expected
+
+
+@pytest.mark.parametrize(
+    "split, class_mappings, optimize_mask_loading",
+    [
+        ("train", None, False),
+        ("train", None, True),
+        ("train", _class_mappings, True),
+    ],
+)
+def test_optimize_mask_loading(total_segmentator_dataset: datasets.TotalSegmentator2D) -> None:
+    """Tests the optimized mask loading."""
+    optimize = total_segmentator_dataset._optimize_mask_loading
+
+    if optimize:
+        expected_masks_dir = os.path.join(
+            total_segmentator_dataset._root,
+            f"processed/masks/{total_segmentator_dataset._classes_hash}",
+        )
+        expected_classes_file = os.path.join(expected_masks_dir, "classes.txt")
+        assert os.path.isdir(expected_masks_dir)
+        assert os.path.isfile(expected_classes_file)
+
+        with open(expected_classes_file, "r") as f:
+            assert f.read() == str(total_segmentator_dataset.classes)
+
+    with (
+        patch.object(total_segmentator_dataset, "_load_semantic_label_mask") as mock_load_optimized,
+        patch.object(total_segmentator_dataset, "_load_mask") as mock_load,
+        patch.object(total_segmentator_dataset, "_fix_orientation") as _,
+    ):
+        _ = total_segmentator_dataset.load_mask(0)
+        if optimize:
+            mock_load_optimized.assert_called_once_with(0)
+            mock_load.assert_not_called()
+        else:
+            mock_load.assert_called_once_with(0)
+            mock_load_optimized.assert_not_called()
+
 
 @pytest.fixture(scope="function")
 def total_segmentator_dataset(
-    tmp_path: str, split: Literal["train", "val"] | None, assets_path: str
+    tmp_path: str,
+    split: Literal["train", "val"] | None,
+    assets_path: str,
+    class_mappings: Dict[str, str] | None,
+    optimize_mask_loading: bool,
 ) -> datasets.TotalSegmentator2D:
     """TotalSegmentator2D dataset fixture."""
     dataset_dir = os.path.join(
@@ -62,6 +122,8 @@
         root=tmp_path,
         split=split,
         version=None,
+        class_mappings=class_mappings,
+        optimize_mask_loading=optimize_mask_loading,
     )
     dataset.prepare_data()
     dataset.configure()