Reduce classes of TotalSegmentator2D dataset #709

Open · wants to merge 11 commits into main
@@ -4,7 +4,7 @@ trainer:
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d}
- max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
+ max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000}
callbacks:
- class_path: eva.callbacks.ConfigurationLogger
- class_path: lightning.pytorch.callbacks.TQDMProgressBar
@@ -54,7 +54,7 @@ model:
class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
init_args:
in_features: ${oc.env:IN_FEATURES, 384}
- num_classes: &NUM_CLASSES 118
+ num_classes: &NUM_CLASSES 37
criterion:
class_path: eva.vision.losses.DiceLoss
init_args:
@@ -142,6 +142,7 @@ data:
val:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
+ shuffle: true
test:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
@@ -4,7 +4,7 @@ trainer:
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d}
- max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
+ max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000}
callbacks:
- class_path: eva.callbacks.ConfigurationLogger
- class_path: lightning.pytorch.callbacks.TQDMProgressBar
@@ -44,10 +44,10 @@ model:
out_indices: ${oc.env:OUT_INDICES, 1}
model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
decoder:
- class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
+ class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage
init_args:
in_features: ${oc.env:IN_FEATURES, 384}
- num_classes: &NUM_CLASSES 118
+ num_classes: &NUM_CLASSES 37
criterion:
class_path: eva.vision.losses.DiceLoss
init_args:
@@ -120,6 +120,7 @@ data:
val:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
+ shuffle: true
test:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
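Both configuration files resolve their values through OmegaConf's `oc.env` resolver, so the bumped `MAX_STEPS` default of 40000 and the reduced `NUM_CLASSES` can still be overridden per run via environment variables. A minimal sketch of that resolution behavior, with illustrative values (not part of this PR):

import os

from omegaconf import OmegaConf

# `${oc.env:VAR, default}` falls back to the default when VAR is unset.
cfg = OmegaConf.create({"max_steps": "${oc.env:MAX_STEPS, 40000}"})
print(cfg.max_steps)  # 40000

os.environ["MAX_STEPS"] = "10000"
print(cfg.max_steps)  # "10000" (environment values resolve as strings)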
src/eva/vision/data/datasets/segmentation/_total_segmentator.py (+91, -0)
@@ -0,0 +1,91 @@
"""Utils for TotalSegmentator dataset classes."""

from typing import Dict

reduced_class_mappings: Dict[str, str] = {
# Abdominal Organs
"spleen": "spleen",
"kidney_right": "kidney",
"kidney_left": "kidney",
"gallbladder": "gallbladder",
"liver": "liver",
"stomach": "stomach",
"pancreas": "pancreas",
"small_bowel": "small_bowel",
"duodenum": "duodenum",
"colon": "colon",
# Endocrine System
"adrenal_gland_right": "adrenal_gland",
"adrenal_gland_left": "adrenal_gland",
"thyroid_gland": "thyroid_gland",
# Respiratory System
"lung_upper_lobe_left": "lungs",
"lung_lower_lobe_left": "lungs",
"lung_upper_lobe_right": "lungs",
"lung_middle_lobe_right": "lungs",
"lung_lower_lobe_right": "lungs",
"trachea": "trachea",
"esophagus": "esophagus",
# Urogenital System
"urinary_bladder": "urogenital_system",
"prostate": "urogenital_system",
"kidney_cyst_left": "kidney_cyst",
"kidney_cyst_right": "kidney_cyst",
# Vertebral Column
**{f"vertebrae_{v}": "vertebrae" for v in ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]},
**{f"vertebrae_{v}": "vertebrae" for v in [f"T{i}" for i in range(1, 13)]},
**{f"vertebrae_{v}": "vertebrae" for v in [f"L{i}" for i in range(1, 6)]},
"vertebrae_S1": "vertebrae",
"sacrum": "sacral_spine",
# Cardiovascular System
"heart": "heart",
"aorta": "aorta",
"pulmonary_vein": "veins",
"brachiocephalic_trunk": "arteries",
"subclavian_artery_right": "arteries",
"subclavian_artery_left": "arteries",
"common_carotid_artery_right": "arteries",
"common_carotid_artery_left": "arteries",
"brachiocephalic_vein_left": "veins",
"brachiocephalic_vein_right": "veins",
"atrial_appendage_left": "atrial_appendage",
"superior_vena_cava": "veins",
"inferior_vena_cava": "veins",
"portal_vein_and_splenic_vein": "veins",
"iliac_artery_left": "arteries",
"iliac_artery_right": "arteries",
"iliac_vena_left": "veins",
"iliac_vena_right": "veins",
# Upper Extremity Bones
"humerus_left": "humerus",
"humerus_right": "humerus",
"scapula_left": "scapula",
"scapula_right": "scapula",
"clavicula_left": "clavicula",
"clavicula_right": "clavicula",
# Lower Extremity Bones
"femur_left": "femur",
"femur_right": "femur",
"hip_left": "hip",
"hip_right": "hip",
# Muscles
"gluteus_maximus_left": "gluteus",
"gluteus_maximus_right": "gluteus",
"gluteus_medius_left": "gluteus",
"gluteus_medius_right": "gluteus",
"gluteus_minimus_left": "gluteus",
"gluteus_minimus_right": "gluteus",
"autochthon_left": "autochthon",
"autochthon_right": "autochthon",
"iliopsoas_left": "iliopsoas",
"iliopsoas_right": "iliopsoas",
# Central Nervous System
"brain": "brain",
"spinal_cord": "spinal_cord",
# Skull and Thoracic Cage
"skull": "skull",
**{f"rib_left_{i}": "ribs" for i in range(1, 13)},
**{f"rib_right_{i}": "ribs" for i in range(1, 13)},
"costal_cartilages": "ribs",
"sternum": "sternum",
}
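The mapping above collapses the 117 original TotalSegmentator labels into 36 merged ones; together with the background class that yields the 37 decoder outputs configured earlier. A quick sanity check along these lines, assuming the module path introduced by this PR:

from eva.vision.data.datasets.segmentation import _total_segmentator

mappings = _total_segmentator.reduced_class_mappings
reduced = sorted(set(mappings.values()))
print(len(mappings))  # 117 original classes covered by the mapping
print(len(reduced))   # 36 merged classes; plus background = 37 (NUM_CLASSES)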
src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py (+92, -37)
@@ -1,7 +1,9 @@
"""TotalSegmentator 2D segmentation dataset class."""

import functools
+ import hashlib
import os
+ import re
from glob import glob
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Tuple
@@ -16,7 +18,7 @@
from eva.core.utils import io as core_io
from eva.core.utils import multiprocessing
from eva.vision.data.datasets import _validators, structs
- from eva.vision.data.datasets.segmentation import base
+ from eva.vision.data.datasets.segmentation import _total_segmentator, base
from eva.vision.utils import io


@@ -66,6 +68,7 @@ def __init__(
version: Literal["small", "full"] | None = "full",
download: bool = False,
classes: List[str] | None = None,
+ class_mappings: Dict[str, str] | None = _total_segmentator.reduced_class_mappings,
optimize_mask_loading: bool = True,
decompress: bool = True,
num_workers: int = 10,
@@ -85,6 +88,8 @@
exist yet on disk.
classes: Whether to configure the dataset with a subset of classes.
If `None`, it will use all of them.
+ class_mappings: A dictionary that maps the original class names to a
+ reduced set of classes. If `None`, it will use the original classes.
optimize_mask_loading: Whether to pre-process the segmentation masks
in order to optimize the loading time. In the `setup` method, it
will reformat the binary one-hot masks to a semantic mask and store
@@ -109,11 +114,10 @@ def __init__(
self._optimize_mask_loading = optimize_mask_loading
self._decompress = decompress
self._num_workers = num_workers
+ self._class_mappings = class_mappings

- if self._optimize_mask_loading and self._classes is not None:
- raise ValueError(
- "To use customize classes please set the optimize_mask_loading to `False`."
- )
+ if self._classes and self._class_mappings:
+ raise ValueError("Both 'classes' and 'class_mappings' cannot be set at the same time.")

self._samples_dirs: List[str] = []
self._indices: List[Tuple[int, int]] = []
@@ -125,16 +129,21 @@ def get_filename(path: str) -> str:
"""Returns the filename from the full path."""
return os.path.basename(path).split(".")[0]

- first_sample_labels = os.path.join(
- self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
- )
+ first_sample_labels = os.path.join(self._root, "s0011", "segmentations", "*.nii.gz")
all_classes = sorted(map(get_filename, glob(first_sample_labels)))
if self._classes:
is_subset = all(name in all_classes for name in self._classes)
if not is_subset:
raise ValueError("Provided class names are not subset of the dataset onces.")

return all_classes if self._classes is None else self._classes
raise ValueError("Provided class names are not subset of the original ones.")
classes = sorted(self._classes)
elif self._class_mappings:
is_subset = all(name in all_classes for name in self._class_mappings.keys())
if not is_subset:
raise ValueError("Provided class names are not subset of the original ones.")
classes = sorted(set(self._class_mappings.values()))
else:
classes = all_classes
return ["background"] + classes

@property
@override
@@ -145,6 +154,10 @@ def class_to_idx(self) -> Dict[str, int]:
def _file_suffix(self) -> str:
return "nii" if self._decompress else "nii.gz"

+ @functools.cached_property
+ def _classes_hash(self) -> str:
+ return hashlib.md5(str(self.classes).encode(), usedforsecurity=False).hexdigest()

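# Illustration only (not part of this diff): the md5 digest of the stringified
# class list keys the optimized-mask cache, so changing `classes` or
# `class_mappings` resolves to a different `processed/masks/<hash>` directory.
import hashlib

classes = ["background", "aorta", "ribs"]  # hypothetical class list
digest = hashlib.md5(str(classes).encode(), usedforsecurity=False).hexdigest()
print(f"processed/masks/{digest}")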
@override
def filename(self, index: int) -> str:
sample_idx, _ = self._indices[index]
@@ -170,15 +183,22 @@ def validate(self) -> None:
if self._version is None or self._sample_every_n_slices is not None:
return

+ if self._classes:
+ last_label = self._classes[-1]
+ n_classes = len(self._classes)
+ elif self._class_mappings:
+ classes = sorted(set(self._class_mappings.values()))
+ last_label = classes[-1]
+ n_classes = len(classes)
+ else:
+ last_label = "vertebrae_T9"
+ n_classes = 117

_validators.check_dataset_integrity(
self,
length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
- n_classes=len(self._classes) if self._classes else 117,
- first_and_last_labels=(
- (self._classes[0], self._classes[-1])
- if self._classes
- else ("adrenal_gland_left", "vertebrae_T9")
- ),
+ n_classes=n_classes + 1,
+ first_and_last_labels=("background", last_label),
)

@override
@@ -190,32 +210,31 @@ def load_image(self, index: int) -> tv_tensors.Image:
sample_index, slice_index = self._indices[index]
image_path = self._get_image_path(sample_index)
image_array = io.read_nifti(image_path, slice_index)
- image_rgb_array = image_array.repeat(3, axis=2)
- return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
+ image_array = self._fix_orientation(image_array)
+ return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))

@override
def load_mask(self, index: int) -> tv_tensors.Mask:
if self._optimize_mask_loading:
- return self._load_semantic_label_mask(index)
- return self._load_mask(index)
+ mask = self._load_semantic_label_mask(index)
+ else:
+ mask = self._load_mask(index)
+ mask = self._fix_orientation(mask)
+ return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64)  # type: ignore

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
_, slice_index = self._indices[index]
return {"slice_index": slice_index}

- def _load_mask(self, index: int) -> tv_tensors.Mask:
+ def _load_mask(self, index: int) -> npt.NDArray[Any]:
sample_index, slice_index = self._indices[index]
- semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
- return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+ return self._load_masks_as_semantic_label(sample_index, slice_index)

- def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
+ def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
"""Loads the segmentation mask from a semantic label NifTi file."""
sample_index, slice_index = self._indices[index]
- masks_dir = self._get_masks_dir(sample_index)
- filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
- semantic_labels = io.read_nifti(filename, slice_index)
- return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+ return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)

def _load_masks_as_semantic_label(
self, sample_index: int, slice_index: int | None = None
@@ -227,18 +246,39 @@ def _load_masks_as_semantic_label(
slice_index: Whether to return only a specific slice.
"""
masks_dir = self._get_masks_dir(sample_index)
- mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes]
+ classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
+ mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]

+ if self._class_mappings:
+ mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
+ self.classes[1:]
+ )
+ for original_class, mapped_class in self._class_mappings.items():
+ mapped_index = self.class_to_idx[mapped_class] - 1
+ original_index = list(self._class_mappings.keys()).index(original_class)
+ mapped_binary_masks[mapped_index] = np.logical_or(
+ mapped_binary_masks[mapped_index], binary_masks[original_index]
+ )
+ binary_masks = mapped_binary_masks

background_mask = np.zeros_like(binary_masks[0])
return np.argmax([background_mask] + binary_masks, axis=0)

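# Toy illustration (not part of this diff) of the merge-then-argmax step above:
# binary masks whose classes map to the same reduced class are OR-ed together,
# then argmax over [background] + masks picks a semantic index per pixel.
import numpy as np

kidney_left = np.array([[1, 0], [0, 0]], dtype=bool)
kidney_right = np.array([[0, 1], [0, 0]], dtype=bool)
kidney = np.logical_or(kidney_left, kidney_right)  # both map to "kidney"
background = np.zeros_like(kidney)
print(np.argmax([background, kidney], axis=0))  # [[1 1] [0 0]]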
def _export_semantic_label_masks(self) -> None:
"""Exports the segmentation binary masks (one-hot) to semantic labels."""
+ mask_classes_file = os.path.join(f"{self._get_optimized_masks_root()}/classes.txt")
+ if os.path.isfile(mask_classes_file):
+ with open(mask_classes_file, "r") as file:
+ if file.read() != str(self.classes):
+ raise ValueError(
+ "Optimized masks hash doesn't match the current classes or mappings."
+ )
+ return

total_samples = len(self._samples_dirs)
- masks_dirs = map(self._get_masks_dir, range(total_samples))
semantic_labels = [
- (index, os.path.join(directory, "semantic_labels", "masks.nii"))
- for index, directory in enumerate(masks_dirs)
+ (index, self._get_optimized_masks_file(index)) for index in range(total_samples)
]
to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)

@@ -255,6 +295,16 @@ def _process_mask(sample_index: Any, filename: str) -> None:
return_results=False,
)

+ os.makedirs(os.path.dirname(mask_classes_file), exist_ok=True)
+ with open(mask_classes_file, "w") as file:
+ file.write(str(self.classes))

+ def _fix_orientation(self, array: npt.NDArray):
+ """Fixes orientation such that table is at the bottom & liver on the left."""
+ array = np.rot90(array)
+ array = np.flip(array, axis=1)
+ return array

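# Illustration only (not part of this diff): np.rot90 rotates a slice 90
# degrees counter-clockwise, then the axis-1 flip mirrors it left/right.
import numpy as np

slice_2d = np.array([[1, 2], [3, 4]])
print(np.flip(np.rot90(slice_2d), axis=1))  # [[4 2] [3 1]]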
def _get_image_path(self, sample_index: int) -> str:
"""Returns the corresponding image path."""
sample_dir = self._samples_dirs[sample_index]
@@ -265,10 +315,15 @@ def _get_masks_dir(self, sample_index: int) -> str:
sample_dir = self._samples_dirs[sample_index]
return os.path.join(self._root, sample_dir, "segmentations")

- def _get_semantic_labels_filename(self, sample_index: int) -> str:
+ def _get_optimized_masks_root(self) -> str:
+ """Returns the directory of the optimized masks."""
+ return os.path.join(self._root, f"processed/masks/{self._classes_hash}")
+
+ def _get_optimized_masks_file(self, sample_index: int) -> str:
"""Returns the semantic label filename."""
- masks_dir = self._get_masks_dir(sample_index)
- return os.path.join(masks_dir, "semantic_labels", "masks.nii")
+ return os.path.join(
+ f"{self._get_optimized_masks_root()}/{self._samples_dirs[sample_index]}/masks.nii"
+ )

def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
"""Returns the total amount of slices of a sample."""
@@ -281,7 +336,7 @@ def _fetch_samples_dirs(self) -> List[str]:
sample_filenames = [
filename
for filename in os.listdir(self._root)
- if os.path.isdir(os.path.join(self._root, filename))
+ if os.path.isdir(os.path.join(self._root, filename)) and re.match(r"^s\d{4}$", filename)
]
return sorted(sample_filenames)

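The tightened filter matters now that the optimized masks live under `processed/` inside the dataset root, which would otherwise be picked up as a sample directory. A quick check of the new pattern against hypothetical directory names:

import re

names = ["s0011", "s1234", "processed", "meta.csv", "s12345"]
print([n for n in names if re.match(r"^s\d{4}$", n)])  # ['s0011', 's1234']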