Skip to content

Commit

Permalink
Cellseg 3d (#319)
Browse files Browse the repository at this point in the history
Add EmbedSeg and CellSeg3d datasets
  • Loading branch information
constantinpape committed Jul 9, 2024
1 parent 37c9719 commit da53c11
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 1 deletion.
2 changes: 2 additions & 0 deletions torch_em/data/datasets/light_microscopy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .cellseg_3d import get_cellseg_3d_loader, get_cellseg_3d_dataset
from .covid_if import get_covid_if_loader, get_covid_if_dataset
from .ctc import get_ctc_segmentation_loader, get_ctc_segmentation_dataset
from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset
from .dsb import get_dsb_loader, get_dsb_dataset
from .dynamicnuclearnet import get_dynamicnuclearnet_loader, get_dynamicnuclearnet_dataset
from .embedseg_data import get_embedseg_loader, get_embedseg_dataset
from .hpa import get_hpa_segmentation_loader, get_hpa_segmentation_dataset
from .livecell import get_livecell_loader, get_livecell_dataset
from .mouse_embryo import get_mouse_embryo_loader, get_mouse_embryo_dataset
Expand Down
106 changes: 106 additions & 0 deletions torch_em/data/datasets/light_microscopy/cellseg_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""This dataset contains annotation for nucleus segmentation in 3d fluorescence microscopy from mesoSPIM microscopy.
This dataset is from the publication https://doi.org/10.1101/2024.05.17.594691 .
Please cite it if you use this dataset in your research.
"""

import os
from glob import glob
from typing import Optional, Tuple, Union

import torch_em
from torch.utils.data import Dataset, DataLoader
from .. import util

URL = "https://zenodo.org/records/11095111/files/DATASET_WITH_GT.zip?download=1"
CHECKSUM = "6d8e8d778e479000161fdfea70201a6ded95b3958a703f69def63e69bbddf9d6"


def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool) -> str:
"""Download the CellSeg3d training data.
Args:
path: Filepath to a folder where the downloaded data will be saved.
download: Whether to download the data if it is not present.
Returns:
The filepath to the training data.
"""
url = URL
checksum = CHECKSUM

data_path = os.path.join(path, "DATASET_WITH_GT")
if os.path.exists(data_path):
return data_path

os.makedirs(path, exist_ok=True)
zip_path = os.path.join(path, "cellseg3d.zip")
util.download_source(zip_path, url, download, checksum)
util.unzip(zip_path, path, True)

return data_path


def get_cellseg_3d_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
sample_ids: Optional[Tuple[int, ...]] = None,
download: bool = False,
**kwargs
) -> Dataset:
"""Get the CellSeg3d dataset for segmenting nuclei in 3d fluorescence microscopy.
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
sample_ids: The volume ids to load.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
"""
data_root = get_cellseg_3d_data(path, download)

raw_paths = sorted(glob(os.path.join(data_root, "*.tif")))
label_paths = sorted(glob(os.path.join(data_root, "labels", "*.tif")))
assert len(raw_paths) == len(label_paths)
if sample_ids is not None:
assert all(sid < len(raw_paths) for sid in sample_ids)
raw_paths = [raw_paths[i] for i in sample_ids]
label_paths = [label_paths[i] for i in sample_ids]

raw_key, label_key = None, None

return torch_em.default_segmentation_dataset(
raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs
)


def get_cellseg_3d_loader(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
batch_size: int,
sample_ids: Optional[Tuple[int, ...]] = None,
download: bool = False,
**kwargs
) -> DataLoader:
"""Get the CellSeg3d dataloder for segmenting nuclei in 3d fluorescence microscopy.
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
batch_size: The batch size for training.
sample_ids: The volume ids to load.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_cellseg_3d_dataset(
path, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs,
)
loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
return loader
122 changes: 122 additions & 0 deletions torch_em/data/datasets/light_microscopy/embedseg_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""This dataset contains annotation for 3d fluorescence microscopy segmentation
that were introduced by the EmbedSeg publication.
This dataset is from the publication https://proceedings.mlr.press/v143/lalit21a.html.
Please cite it if you use this dataset in your research.
"""

import os
from glob import glob
from typing import Tuple, Union

import torch_em
from torch.utils.data import Dataset, DataLoader
from .. import util

URLS = {
"Mouse-Organoid-Cells-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Organoid-Cells-CBG.zip", # noqa
"Mouse-Skull-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Skull-Nuclei-CBG.zip",
"Platynereis-ISH-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-ISH-Nuclei-CBG.zip", # noqa
"Platynereis-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-Nuclei-CBG.zip",
}
CHECKSUMS = {
"Mouse-Organoid-Cells-CBG": "3695ac340473900ace8c37fd7f3ae0d37217de9f2b86c2341f36b1727825e48b",
"Mouse-Skull-Nuclei-CBG": "3600ec261a48bf953820e0536cacd0bb8a5141be6e7435a4cb0fffeb0caf594e",
"Platynereis-ISH-Nuclei-CBG": "bc9284df6f6d691a8e81b47310d95617252cc98ebf7daeab55801b330ba921e0",
"Platynereis-Nuclei-CBG": "448cb7b46f2fe7d472795e05c8d7dfb40f259d94595ad2cfd256bc2aa4ab3be7",
}


def get_embedseg_data(path: Union[os.PathLike, str], name: str, download: bool) -> str:
"""Download the EmbedSeg training data.
Args:
path: Filepath to a folder where the downloaded data will be saved.
name: Name of the dataset to download.
download: Whether to download the data if it is not present.
Returns:
The filepath to the training data.
"""
if name not in URLS:
raise ValueError(f"The dataset name must be in {list(URLS.keys())}. You provided {name}.")

url = URLS[name]
checksum = CHECKSUMS[name]

data_path = os.path.join(path, name)
if os.path.exists(data_path):
return data_path

os.makedirs(path, exist_ok=True)
zip_path = os.path.join(path, f"{name}.zip")
util.download_source(zip_path, url, download, checksum)
util.unzip(zip_path, path, True)

return data_path


def get_embedseg_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
name: str,
split: str = "train",
download: bool = False,
**kwargs
) -> Dataset:
"""Get an EmbedSeg dataset for 3d fluorescence microscopy segmentation.
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
name: Name of the dataset to download.
split: The split to use for the dataset.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
"""
data_root = get_embedseg_data(path, name, download)

raw_paths = sorted(glob(os.path.join(data_root, split, "images", "*.tif")))
label_paths = sorted(glob(os.path.join(data_root, split, "masks", "*.tif")))
assert len(raw_paths) > 0
assert len(raw_paths) == len(label_paths)

raw_key, label_key = None, None

return torch_em.default_segmentation_dataset(
raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs
)


def get_embedseg_loader(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
batch_size: int,
name: str,
split: str = "train",
download: bool = False,
**kwargs
) -> DataLoader:
"""Get an EmbedSeg dataloader for 3d fluorescence microscopy segmentation.
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
batch_size: The batch size for training.
name: Name of the dataset to download.
split: The split to use for the dataset.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_embedseg_dataset(
path, name=name, split=split, patch_shape=patch_shape, download=download, **ds_kwargs,
)
loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
return loader
2 changes: 1 addition & 1 deletion torch_em/util/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def predict_block(block_id):

if mask is not None:
mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
mask_block = mask_block[inner_bb]
mask_block = mask_block[inner_bb].astype("bool")
if mask_block.sum() == 0:
return

Expand Down

0 comments on commit da53c11

Please sign in to comment.