Skip to content

Commit

Permalink
Small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jul 9, 2024
1 parent c0bf70c commit 0e29688
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
13 changes: 11 additions & 2 deletions torch_em/data/datasets/light_microscopy/cellseg_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import os
from glob import glob
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import torch_em
from torch.utils.data import Dataset, DataLoader
Expand Down Expand Up @@ -44,6 +44,7 @@ def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool) -> str:
def get_cellseg_3d_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
sample_ids: Optional[Tuple[int, ...]] = None,
download: bool = False,
**kwargs
) -> Dataset:
Expand All @@ -52,6 +53,7 @@ def get_cellseg_3d_dataset(
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
sample_ids: The volume ids to load.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Expand All @@ -62,6 +64,11 @@ def get_cellseg_3d_dataset(

raw_paths = sorted(glob(os.path.join(data_root, "*.tif")))
label_paths = sorted(glob(os.path.join(data_root, "labels", "*.tif")))
assert len(raw_paths) == len(label_paths)
if sample_ids is not None:
assert all(sid < len(raw_paths) for sid in sample_ids)
raw_paths = [raw_paths[i] for i in sample_ids]
label_paths = [label_paths[i] for i in sample_ids]

raw_key, label_key = None, None

Expand All @@ -74,6 +81,7 @@ def get_cellseg_3d_loader(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
batch_size: int,
sample_ids: Optional[Tuple[int, ...]] = None,
download: bool = False,
**kwargs
) -> DataLoader:
Expand All @@ -83,6 +91,7 @@ def get_cellseg_3d_loader(
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
batch_size: The batch size for training.
sample_ids: The volume ids to load.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Expand All @@ -91,7 +100,7 @@ def get_cellseg_3d_loader(
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_cellseg_3d_dataset(
path, patch_shape, download=download, **ds_kwargs,
path, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs,
)
loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
return loader
2 changes: 1 addition & 1 deletion torch_em/util/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def predict_block(block_id):

if mask is not None:
mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
mask_block = mask_block[inner_bb]
mask_block = mask_block[inner_bb].astype("bool")
if mask_block.sum() == 0:
return

Expand Down

0 comments on commit 0e29688

Please sign in to comment.