Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ISLES dataset #254

Merged
merged 9 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions scripts/datasets/medical/check_isles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import sys

from torch_em.util.debug import check_loader
from torch_em.data import MinForegroundSampler
from torch_em.data.datasets.medical import get_isles_loader


sys.path.append("..")


def check_isles():
from util import ROOT

loader = get_isles_loader(
path=os.path.join(ROOT, "isles"),
patch_shape=(1, 112, 112),
batch_size=2,
ndim=2,
modality=None,
download=True,
sampler=MinForegroundSampler(min_fraction=0.001),
)

check_loader(loader, 8, plt=True, save_path="./test.png")


if __name__ == "__main__":
check_isles()
9 changes: 5 additions & 4 deletions torch_em/data/datasets/medical/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .acdc import get_acdc_dataset, get_acdc_loader
from .acouslic_ai import get_acouslic_ai_dataset, get_acouslic_ai_loader
from .autopet import get_autopet_loader
from .autopet import get_autopet_dataset, get_autopet_loader
from .amos import get_amos_dataset, get_amos_loader
from .btcv import get_btcv_dataset, get_btcv_loader
from .busi import get_busi_dataset, get_busi_loader
Expand All @@ -13,19 +13,20 @@
from .curvas import get_curvas_dataset, get_curvas_loader
from .dca1 import get_dca1_dataset, get_dca1_loader
from .drive import get_drive_dataset, get_drive_loader
from .jsrt import get_jsrt_dataset, get_jsrt_loader
from .kvasir import get_kvasir_dataset, get_kvasir_loader
from .mbh_seg import get_mbh_seg_dataset, get_mbh_seg_loader
from .duke_liver import get_duke_liver_dataset, get_duke_liver_loader
from .feta24 import get_feta24_dataset, get_feta24_loader
from .han_seg import get_han_seg_dataset, get_han_seg_loader
from .hil_toothseg import get_hil_toothseg_dataset, get_hil_toothseg_loader
from .idrid import get_idrid_dataset, get_idrid_loader
from .isic import get_isic_dataset, get_isic_loader
from .isles import get_isles_dataset, get_isles_loader
from .jsrt import get_jsrt_dataset, get_jsrt_loader
from .jnuifm import get_jnuifm_dataset, get_jnuifm_loader
from .kvasir import get_kvasir_dataset, get_kvasir_loader
from .leg_3d_us import get_leg_3d_us_dataset, get_leg_3d_us_loader
from .lgg_mri import get_lgg_mri_dataset, get_lgg_mri_loader
from .m2caiseg import get_m2caiseg_dataset, get_m2caiseg_loader
from .mbh_seg import get_mbh_seg_dataset, get_mbh_seg_loader
from .micro_usp import get_micro_usp_dataset, get_micro_usp_loader
from .montgomery import get_montgomery_dataset, get_montgomery_loader
from .msd import get_msd_dataset, get_msd_loader
Expand Down
141 changes: 141 additions & 0 deletions torch_em/data/datasets/medical/isles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""The ISLES dataset contains annotations for ischemic stroke lesion segmentation
in multimodal brain MRI scans.

The database is located at https://doi.org/10.5281/zenodo.7960856.
This dataset is from the ISLES 2022 Challenge - https://doi.org/10.1038/s41597-022-01875-5.
Please cite it if you use this dataset for a publication.
"""

import os
from glob import glob
from typing import Union, Tuple, Optional, Literal, List

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URL = "https://zenodo.org/records/7960856/files/ISLES-2022.zip"
CHECKSUM = "f374895e383f725ddd280db41ef36ed975277c33de0e587a631ca7ea7ad45d6b"


def get_isles_data(path: Union[os.PathLike, str], download: bool = False) -> str:
"""Download the ISLES dataset.

Args:
path: Filepath to a folder where the data is downloaded for further processing.
download: Whether to download the data if it is not present.

Returns:
Filepath where the data is downloaded.
"""
data_dir = os.path.join(path, "ISLES-2022")
if os.path.exists(data_dir):
return data_dir

os.makedirs(path, exist_ok=True)

zip_path = os.path.join(path, "ISLES-2022.zip")
util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
util.unzip(zip_path=zip_path, dst=path)

return data_dir


def get_isles_paths(
path: Union[os.PathLike, str], modality: Optional[Literal["dwi", "adc"]] = None, download: bool = False
) -> Tuple[List[str], List[str]]:
"""Get paths to the ISLES data.

Args:
path: Filepath to a folder where the data is downloaded for further processing.
modality: The choice of modality for MRIs. Either 'dwi' or 'adc'.
download: Whether to download the data if it is not present.

Returns:
List of filepaths for the image data.
List of filepaths for the label data.
"""
data_dir = get_isles_data(path=path, download=download)

gt_paths = sorted(glob(os.path.join(data_dir, "derivatives", "sub-*", "**", "*.nii.gz"), recursive=True))

dwi_paths = sorted(glob(os.path.join(data_dir, "sub-*", "**", "dwi", "*_dwi.nii.gz"), recursive=True))
adc_paths = sorted(glob(os.path.join(data_dir, "sub-*", "**", "dwi", "*_adc.nii.gz"), recursive=True))

if modality is None:
image_paths = [(dwi_path, adc_path) for dwi_path, adc_path in zip(dwi_paths, adc_paths)]
else:
if modality == "dwi":
image_paths = dwi_paths
elif modality == "adc":
image_paths = adc_paths
else:
raise ValueError(f"'{modality}' is not a valid modality.")

return image_paths, gt_paths


def get_isles_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
modality: Optional[Literal["dwi", "adc"]] = None,
download: bool = False,
**kwargs
) -> Dataset:
"""Get the ISLES dataset for segmentation of ischemic stroke lesion.

Args:
path: Filepath to a folder where the data is downloaded for further processing.
patch_shape: The patch shape to use for training.
modality: The choice of modality for MRIs. Either 'dwi' or 'adc'.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

Returns:
The segmentation dataset.
"""
image_paths, gt_paths = get_isles_paths(path, modality, download)

dataset = torch_em.default_segmentation_dataset(
raw_paths=image_paths,
raw_key="data",
label_paths=gt_paths,
label_key="data",
patch_shape=patch_shape,
with_channels=modality is None,
**kwargs
)
if "sampler" in kwargs:
for ds in dataset.datasets:
ds.max_sampling_attempts = 5000

return dataset


def get_isles_loader(
path: Union[os.PathLike, str],
batch_size: int,
patch_shape: Tuple[int, int],
modality: Optional[Literal["dwi", "adc"]] = None,
download: bool = False,
**kwargs
) -> DataLoader:
"""Get the ISLES dataloader for segmentation of ischemic stroke lesion.

Args:
path: Filepath to a folder where the data is downloaded for further processing.
batch_size: The batch size for training.
patch_shape: The patch shape to use for training.
modality: The choice of modality for MRIs. Either 'dwi' or 'adc'.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

Returns:
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_isles_dataset(path=path, patch_shape=patch_shape, modality=modality, download=download, **ds_kwargs)
return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)