From da53c117d59fb4a91bca77cb700852d7675a8255 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Tue, 9 Jul 2024 20:58:51 +0200
Subject: [PATCH] Cellseg 3d (#319)

Add EmbedSeg and CellSeg3d datasets
---
 .../datasets/light_microscopy/__init__.py     |   2 +
 .../datasets/light_microscopy/cellseg_3d.py   | 106 +++++++++++++++
 .../light_microscopy/embedseg_data.py         | 122 ++++++++++++++++++
 torch_em/util/prediction.py                   |   2 +-
 4 files changed, 231 insertions(+), 1 deletion(-)
 create mode 100644 torch_em/data/datasets/light_microscopy/cellseg_3d.py
 create mode 100644 torch_em/data/datasets/light_microscopy/embedseg_data.py

diff --git a/torch_em/data/datasets/light_microscopy/__init__.py b/torch_em/data/datasets/light_microscopy/__init__.py
index 97e41d35..53d51066 100644
--- a/torch_em/data/datasets/light_microscopy/__init__.py
+++ b/torch_em/data/datasets/light_microscopy/__init__.py
@@ -1,8 +1,10 @@
+from .cellseg_3d import get_cellseg_3d_loader, get_cellseg_3d_dataset
 from .covid_if import get_covid_if_loader, get_covid_if_dataset
 from .ctc import get_ctc_segmentation_loader, get_ctc_segmentation_dataset
 from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset
 from .dsb import get_dsb_loader, get_dsb_dataset
 from .dynamicnuclearnet import get_dynamicnuclearnet_loader, get_dynamicnuclearnet_dataset
+from .embedseg_data import get_embedseg_loader, get_embedseg_dataset
 from .hpa import get_hpa_segmentation_loader, get_hpa_segmentation_dataset
 from .livecell import get_livecell_loader, get_livecell_dataset
 from .mouse_embryo import get_mouse_embryo_loader, get_mouse_embryo_dataset
diff --git a/torch_em/data/datasets/light_microscopy/cellseg_3d.py b/torch_em/data/datasets/light_microscopy/cellseg_3d.py
new file mode 100644
index 00000000..60e5b8ea
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/cellseg_3d.py
@@ -0,0 +1,106 @@
+"""This dataset contains annotations for nucleus segmentation in 3d fluorescence microscopy acquired with mesoSPIM.
+
+This dataset is from the publication https://doi.org/10.1101/2024.05.17.594691.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from typing import Optional, Tuple, Union
+
+import torch_em
+from torch.utils.data import Dataset, DataLoader
+from .. import util
+
+URL = "https://zenodo.org/records/11095111/files/DATASET_WITH_GT.zip?download=1"
+CHECKSUM = "6d8e8d778e479000161fdfea70201a6ded95b3958a703f69def63e69bbddf9d6"
+
+
+def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool) -> str:
+    """Download the CellSeg3d training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+    """
+    url = URL
+    checksum = CHECKSUM
+
+    data_path = os.path.join(path, "DATASET_WITH_GT")
+    if os.path.exists(data_path):
+        return data_path
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, "cellseg3d.zip")
+    util.download_source(zip_path, url, download, checksum)
+    util.unzip(zip_path, path, True)
+
+    return data_path
+
+
+def get_cellseg_3d_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    sample_ids: Optional[Tuple[int, ...]] = None,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the CellSeg3d dataset for segmenting nuclei in 3d fluorescence microscopy.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        sample_ids: The volume ids to load.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    data_root = get_cellseg_3d_data(path, download)
+
+    raw_paths = sorted(glob(os.path.join(data_root, "*.tif")))
+    label_paths = sorted(glob(os.path.join(data_root, "labels", "*.tif")))
+    assert len(raw_paths) == len(label_paths)
+    if sample_ids is not None:
+        assert all(sid < len(raw_paths) for sid in sample_ids)
+        raw_paths = [raw_paths[i] for i in sample_ids]
+        label_paths = [label_paths[i] for i in sample_ids]
+
+    raw_key, label_key = None, None
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs
+    )
+
+
+def get_cellseg_3d_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    sample_ids: Optional[Tuple[int, ...]] = None,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the CellSeg3d dataloader for segmenting nuclei in 3d fluorescence microscopy.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        sample_ids: The volume ids to load.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_cellseg_3d_dataset(
+        path, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs,
+    )
+    loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/light_microscopy/embedseg_data.py b/torch_em/data/datasets/light_microscopy/embedseg_data.py
new file mode 100644
index 00000000..38e32ce1
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/embedseg_data.py
@@ -0,0 +1,122 @@
+"""This dataset contains annotations for 3d fluorescence microscopy segmentation
+that were introduced by the EmbedSeg publication.
+
+This dataset is from the publication https://proceedings.mlr.press/v143/lalit21a.html.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from typing import Tuple, Union
+
+import torch_em
+from torch.utils.data import Dataset, DataLoader
+from .. import util
+
+URLS = {
+    "Mouse-Organoid-Cells-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Organoid-Cells-CBG.zip",  # noqa
+    "Mouse-Skull-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Skull-Nuclei-CBG.zip",
+    "Platynereis-ISH-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-ISH-Nuclei-CBG.zip",  # noqa
+    "Platynereis-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-Nuclei-CBG.zip",
+}
+CHECKSUMS = {
+    "Mouse-Organoid-Cells-CBG": "3695ac340473900ace8c37fd7f3ae0d37217de9f2b86c2341f36b1727825e48b",
+    "Mouse-Skull-Nuclei-CBG": "3600ec261a48bf953820e0536cacd0bb8a5141be6e7435a4cb0fffeb0caf594e",
+    "Platynereis-ISH-Nuclei-CBG": "bc9284df6f6d691a8e81b47310d95617252cc98ebf7daeab55801b330ba921e0",
+    "Platynereis-Nuclei-CBG": "448cb7b46f2fe7d472795e05c8d7dfb40f259d94595ad2cfd256bc2aa4ab3be7",
+}
+
+
+def get_embedseg_data(path: Union[os.PathLike, str], name: str, download: bool) -> str:
+    """Download the EmbedSeg training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: Name of the dataset to download.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+    """
+    if name not in URLS:
+        raise ValueError(f"The dataset name must be in {list(URLS.keys())}. You provided {name}.")
+
+    url = URLS[name]
+    checksum = CHECKSUMS[name]
+
+    data_path = os.path.join(path, name)
+    if os.path.exists(data_path):
+        return data_path
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, f"{name}.zip")
+    util.download_source(zip_path, url, download, checksum)
+    util.unzip(zip_path, path, True)
+
+    return data_path
+
+
+def get_embedseg_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    name: str,
+    split: str = "train",
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get an EmbedSeg dataset for 3d fluorescence microscopy segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        name: Name of the dataset to download.
+        split: The split to use for the dataset.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    data_root = get_embedseg_data(path, name, download)
+
+    raw_paths = sorted(glob(os.path.join(data_root, split, "images", "*.tif")))
+    label_paths = sorted(glob(os.path.join(data_root, split, "masks", "*.tif")))
+    assert len(raw_paths) > 0
+    assert len(raw_paths) == len(label_paths)
+
+    raw_key, label_key = None, None
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs
+    )
+
+
+def get_embedseg_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    name: str,
+    split: str = "train",
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get an EmbedSeg dataloader for 3d fluorescence microscopy segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        name: Name of the dataset to download.
+        split: The split to use for the dataset.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_embedseg_dataset(
+        path, name=name, split=split, patch_shape=patch_shape, download=download, **ds_kwargs,
+    )
+    loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/util/prediction.py b/torch_em/util/prediction.py
index e4ae86ed..eb6ba53d 100644
--- a/torch_em/util/prediction.py
+++ b/torch_em/util/prediction.py
@@ -182,7 +182,7 @@ def predict_block(block_id):
 
         if mask is not None:
            mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
-            mask_block = mask_block[inner_bb]
+            mask_block = mask_block[inner_bb].astype("bool")
            if mask_block.sum() == 0:
                return
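
Usage sketch (not part of the patch): a minimal example of how the two loaders added here could be called once the patch is applied. Only the function names and keyword arguments come from the code above; the local download paths, the 3d patch shape, and the choice of the "Mouse-Skull-Nuclei-CBG" dataset are illustrative assumptions, so adjust them to your data and GPU memory.

    from torch_em.data.datasets.light_microscopy import (
        get_cellseg_3d_loader, get_embedseg_loader
    )

    # CellSeg3d: nucleus segmentation in 3d mesoSPIM volumes.
    # Path and patch shape are placeholder choices, not values from the PR.
    cellseg_loader = get_cellseg_3d_loader(
        path="./data/cellseg_3d",      # data is downloaded here on first use
        patch_shape=(32, 128, 128),    # 3d patch shape (Z, Y, X)
        batch_size=1,
        download=True,
    )

    # EmbedSeg: one of the four 3d datasets; the name must be a key of URLS.
    embedseg_loader = get_embedseg_loader(
        path="./data/embedseg",
        patch_shape=(32, 128, 128),
        batch_size=1,
        name="Mouse-Skull-Nuclei-CBG",
        split="train",
        download=True,
    )

    # The loaders yield (raw, label) batches for training.
    for raw, labels in cellseg_loader:
        print(raw.shape, labels.shape)
        break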