Skip to content

Commit

Permalink
Add deepbacs dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jul 14, 2023
1 parent f7f507c commit 2d1da06
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 4 deletions.
12 changes: 12 additions & 0 deletions scripts/datasets/check_deepbacs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch_em.data.datasets import get_deepbacs_loader
from torch_em.util.debug import check_loader


def check_deepbacs():
    """Build the DeepBacs loader for the 'mixed' test split and pass it to check_loader.

    The loader is created for the "./deepbacs" folder with download=True,
    2D patches of shape (256, 256), batch size 1 and shuffling enabled.
    check_loader is then called with the loader, the value 15
    (presumably the number of samples to inspect -- confirm against
    torch_em.util.debug.check_loader) and instance_labels=True.
    """
    loader_kwargs = dict(
        bac_type="mixed",
        download=True,
        patch_shape=(256, 256),
        batch_size=1,
        shuffle=True,
    )
    deepbacs_loader = get_deepbacs_loader("./deepbacs", "test", **loader_kwargs)
    check_loader(deepbacs_loader, 15, instance_labels=True)


# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    check_deepbacs()
1 change: 1 addition & 0 deletions torch_em/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .cem import get_cem_mitolab_loader
from .covid_if import get_covid_if_loader
from .cremi import get_cremi_loader
from .deepbacs import get_deepbacs_loader
from .dsb import get_dsb_loader
from .hpa import get_hpa_segmentation_loader
from .isbi2012 import get_isbi_loader
Expand Down
54 changes: 54 additions & 0 deletions torch_em/data/datasets/deepbacs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import os

from ...segmentation import default_segmentation_loader
from . import util

# Download URL of the zip archive for each known bacteria type.
# The keys of this dict define the accepted `bac_type` values.
URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# Checksum handed to util.download_source for each archive, keyed like URLS.
# NOTE(review): the values look like SHA256 hex digests -- confirm against util.download_source.
# NOTE(review): the "b_subtilis" entry "1" looks like a placeholder, not a real digest;
# replace it with the actual checksum before relying on that bacteria type.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    "b_subtilis": "1",
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}


def _require_deebacs_dataset(path, bac_type, download):
    """Fetch (if necessary) and unpack the archive for one bacteria type.

    Args:
        path: Root folder for the dataset; created if it does not exist.
        bac_type: Key into URLS / CHECKSUMS selecting the archive.
        download: Passed through to util.download_source.

    The archive is stored as ``<path>/<bac_type>.zip`` and unzipped into
    ``<path>/<bac_type>``. The download step is skipped when the zip file
    is already on disk; the unzip step always runs.
    """
    os.makedirs(path, exist_ok=True)

    archive = os.path.join(path, f"{bac_type}.zip")
    extract_dir = os.path.join(path, bac_type)

    archive_missing = not os.path.exists(archive)
    if archive_missing:
        util.download_source(archive, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])

    util.unzip(archive, extract_dir)


def _get_paths(path, bac_type, split):
    """Return the ``(image_folder, label_folder)`` pair for one split.

    Args:
        path: Root folder of the dataset.
        bac_type: Bacteria type; only "mixed" is supported here.
        split: "train" selects the "training" sub-folder; any other value
            selects the "test" sub-folder.

    Returns:
        A tuple of two paths, ``<path>/<bac_type>/<split folder>/source``
        and ``<path>/<bac_type>/<split folder>/target``.

    Raises:
        NotImplementedError: If `bac_type` is anything other than "mixed".
    """
    # The bacteria types other than mixed are a bit more complicated,
    # so there are no dataloaders for them yet.
    # 'mixed' is the combination of all other types.
    if bac_type != "mixed":
        raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}")

    # Images and labels live in sibling folders below the same split folder.
    split_dir = "training" if split == "train" else "test"
    image_folder, label_folder = (
        os.path.join(path, bac_type, split_dir, leaf) for leaf in ("source", "target")
    )
    return image_folder, label_folder


def get_deepbacs_loader(path, split, bac_type="mixed", download=False, **kwargs):
    """Create a segmentation loader for the DeepBacs data.

    Args:
        path: Root folder where the dataset is (or will be) stored.
        split: Either "train" or "test".
        bac_type: One of the keys of URLS. Only "mixed" is currently
            supported further down in _get_paths.
        download: Passed on to the dataset download helper when the
            folder ``<path>/<bac_type>`` does not exist yet.
        **kwargs: Forwarded unchanged to default_segmentation_loader.

    Returns:
        Whatever default_segmentation_loader returns for the "*.tif" files
        in the image and label folders of the selected split.

    Raises:
        AssertionError: If `split` or `bac_type` is not an accepted value.
        NotImplementedError: From _get_paths for unsupported bacteria types.
    """
    assert split in ("train", "test")
    bac_types = list(URLS)
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    # Only fetch and unpack the data when the extracted folder is missing.
    already_extracted = os.path.exists(os.path.join(path, bac_type))
    if not already_extracted:
        _require_deebacs_dataset(path, bac_type, download)

    raw_dir, label_dir = _get_paths(path, bac_type, split)
    return default_segmentation_loader(raw_dir, "*.tif", label_dir, "*.tif", **kwargs)
4 changes: 2 additions & 2 deletions torch_em/data/image_collection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def _get_sample(self, index):
index = np.random.randint(0, len(self.raw_images))
# these are just the file paths
raw, label = self.raw_images[index], self.label_images[index]
raw = load_image(raw)
label = load_image(label)
raw = load_image(raw, memmap=False)
label = load_image(label, memmap=False)

have_raw_channels = raw.ndim == 3
have_label_channels = label.ndim == 3
Expand Down
4 changes: 2 additions & 2 deletions torch_em/util/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def supports_memmap(image_path):
return True


def load_image(image_path):
if supports_memmap(image_path):
def load_image(image_path, memmap=True):
if supports_memmap(image_path) and memmap:
return tifffile.memmap(image_path, mode="r")
elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"):
return tifffile.imread(image_path)
Expand Down

0 comments on commit 2d1da06

Please sign in to comment.