Add padding for raw and label channels
anwai98 committed Oct 20, 2023
1 parent c20c8f0 commit 944cce9
Showing 5 changed files with 72 additions and 19 deletions.
8 changes: 6 additions & 2 deletions scripts/datasets/check_monusac.py
@@ -11,17 +11,21 @@ def check_monusac():
download=True,
patch_shape=(512, 512),
batch_size=2,
split="train"
split="train",
organ_type=["breast", "lung"]
)
print("Length of train loader: ", len(train_loader))
check_loader(train_loader, 8, instance_labels=True, rgb=True, plt=True, save_path="./monusac_train.png")

test_loader = get_monusac_loader(
path=MONUSAC_ROOT,
download=True,
patch_shape=(512, 512),
batch_size=1,
split="test"
split="test",
organ_type=["breast", "prostate"]
)
print("Length of test loader: ", len(test_loader))
check_loader(test_loader, 8, instance_labels=True, rgb=True, plt=True, save_path="./monusac_test.png")


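For reference, a minimal sketch of the same loader call without the new organ filter. This assumes get_monusac_loader is importable from torch_em.data.datasets (the script's imports are not shown in this hunk) and that omitting organ_type falls back to the default organ_type=None, i.e. all organs are kept; MONUSAC_ROOT is given a placeholder value here:

from torch_em.data.datasets import get_monusac_loader

MONUSAC_ROOT = "./monusac"  # hypothetical data root, stands in for the constant used in the script

# Omitting organ_type should load patients from all organs,
# since get_monusac_dataset defaults to organ_type=None.
all_organs_loader = get_monusac_loader(
    path=MONUSAC_ROOT,
    download=True,
    patch_shape=(512, 512),
    batch_size=2,
    split="train",
)
print("Length of loader over all organs:", len(all_organs_loader))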
43 changes: 34 additions & 9 deletions torch_em/data/datasets/monusac.py
@@ -2,12 +2,13 @@
import shutil
from glob import glob
from tqdm import tqdm
from pathlib import Path
from typing import Optional, List

import imageio.v2 as imageio
import imageio.v3 as imageio

import torch_em
from torch_em.data.datasets import util
from . import util


URL = {
Expand Down Expand Up @@ -72,12 +73,16 @@ def _process_monusac(path, split):
os.makedirs(root_img_save_dir, exist_ok=True)
os.makedirs(root_label_save_dir, exist_ok=True)

all_patient_dir = sorted(glob(os.path.join(path, "MoNuSAC_images_and_annotations", "*")))
all_patient_dir = sorted(glob(os.path.join(path, "MoNuSAC*", "*")))

for patient_dir in tqdm(all_patient_dir, desc=f"Converting {split} inputs for all patients"):
all_img_dir = sorted(glob(os.path.join(patient_dir, "*.tif")))
all_xml_label_dir = sorted(glob(os.path.join(patient_dir, "*.xml")))

if len(all_img_dir) != len(all_xml_label_dir):
_convert_missing_tif_from_svs(patient_dir)
all_img_dir = sorted(glob(os.path.join(patient_dir, "*.tif")))

assert len(all_img_dir) == len(all_xml_label_dir)

for img_path, xml_label_path in zip(all_img_dir, all_xml_label_dir):
@@ -94,6 +99,30 @@ def _process_monusac(path, split):
shutil.rmtree(glob(os.path.join(path, "MoNuSAC*"))[0])


def _convert_missing_tif_from_svs(patient_dir):
"""This function activates when we see some missing tiff inputs (and converts svs to tiff)
Cause: Happens only in the test split, maybe while converting the data, some were missed
Fix: We have the original svs scans. We convert the svs scans to tiff
"""
all_svs_dir = sorted(glob(os.path.join(patient_dir, "*.svs")))
for svs_path in all_svs_dir:
save_tif_path = os.path.splitext(svs_path)[0] + ".tif"
if not os.path.exists(save_tif_path):
img_array = util.convert_svs_to_tiff(svs_path)
imageio.imwrite(save_tif_path, img_array)


def get_patient_id(path, split_wrt="-01Z-00-"):
"""Gets us the patient id in the expected format
Input Names: "TCGA-<XX>-<XXXX>-01z-00-DX<X>-(<X>, <00X>).tif" (example: TCGA-2Z-A9JG-01Z-00-DX1_1.tif)
Expected: "TCGA-<XX>-<XXXX>" (example: TCGA-2Z-A9JG)
"""
patient_image_id = Path(path).stem
patient_id = patient_image_id.split(split_wrt)[0]
return patient_id


def get_monusac_dataset(
path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False,
offsets=None, boundaries=False, binary=False, **kwargs
@@ -109,12 +138,8 @@ def get_monusac_dataset(
# get all patients for multiple organ selection
all_organ_splits = sum([ORGAN_SPLITS[split][o] for o in organ_type], [])

image_paths = [
_path for _path in image_paths if os.path.split(_path)[-1].split(".")[0] in all_organ_splits
]
label_paths = [
_path for _path in label_paths if os.path.split(_path)[-1].split(".")[0] in all_organ_splits
]
image_paths = [_path for _path in image_paths if get_patient_id(_path) in all_organ_splits]
label_paths = [_path for _path in label_paths if get_patient_id(_path) in all_organ_splits]

kwargs, _ = util.add_instance_label_transform(
kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
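For reference, a minimal, self-contained sketch of what the new get_patient_id helper does and how the organ filter uses it. The helper is duplicated here so the snippet runs on its own; the second file name and the organ selection are hypothetical, only the first name comes from the docstring example above:

from pathlib import Path


def get_patient_id(path, split_wrt="-01Z-00-"):
    # same logic as the helper added above: drop the extension,
    # then keep everything before the "-01Z-00-" marker
    return Path(path).stem.split(split_wrt)[0]


image_paths = [
    "TCGA-2Z-A9JG-01Z-00-DX1_1.tif",  # example name from the docstring
    "TCGA-5P-A9K0-01Z-00-DX1_2.tif",  # hypothetical second patient
]
all_organ_splits = ["TCGA-2Z-A9JG"]  # hypothetical organ selection

filtered = [p for p in image_paths if get_patient_id(p) in all_organ_splits]
print(filtered)  # ['TCGA-2Z-A9JG-01Z-00-DX1_1.tif']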
27 changes: 27 additions & 0 deletions torch_em/data/datasets/util.py
@@ -212,3 +212,30 @@ def generate_labeled_array_from_xml(shape, xml_file):
r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
mask[r, c] = i
return mask


def convert_svs_to_tiff(path, location=(0, 0), level=0, img_size=None):
"""Converts .svs files to .tif format
Argument:
- path: [str] - Path to the svs file
(below mentioned arguments are used for multi-resolution images)
- location: tuple[int, int] - pixel location (x, y) in level 0 of the image (default: (0, 0))
- level: [int] - target level used to read the image (default: 0)
- img_size: tuple[int, int] - expected size of the image (default: None -> obtains the original shape at the expected level)
Returns:
the image as numpy array
"""
assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"

from tiffslide import TiffSlide

_slide = TiffSlide(path)

if img_size is None:
img_size = _slide.level_dimensions[level]

img_arr = _slide.read_region(location=location, level=level, size=img_size, as_array=True)

return img_arr
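A minimal usage sketch for the new helper, assuming tiffslide is installed and torch_em.data.datasets.util is importable; the input path is hypothetical:

import imageio.v3 as imageio

from torch_em.data.datasets import util

svs_path = "patient_scan.svs"  # hypothetical .svs file
img_arr = util.convert_svs_to_tiff(svs_path)  # reads the full image at level 0

# write the converted image next to the original scan
imageio.imwrite(svs_path.replace(".svs", ".tif"), img_arr)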
7 changes: 3 additions & 4 deletions torch_em/data/image_collection_dataset.py
@@ -102,11 +102,10 @@ def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channel
if have_raw_channels and channel_first:
shape = shape[1:]
if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
if have_raw_channels or have_label_channels:
raise NotImplementedError("Padding is not implemented for data with channels")
assert len(shape) == len(self.patch_shape)
pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
raw, labels = np.pad(raw, pw), np.pad(labels, pw)
pw_raw = [*pw, (0, 0)] if have_raw_channels else pw
pw_labels = [*pw, (0, 0)] if have_label_channels else pw
raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw_labels)
return raw, labels

def _get_sample(self, index):
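The change above appends a zero-width pad entry for the trailing channel axis, so that images and labels with channels no longer raise NotImplementedError. A standalone sketch of the same idea (the shapes are illustrative; the dataset code additionally drops the channel axis from the shape comparison when the raw data is channel-first):

import numpy as np

patch_shape = (512, 512)
raw = np.zeros((300, 400, 3))  # H x W x C image, smaller than the patch
labels = np.zeros((300, 400))  # matching label image without channels

# pad width per spatial axis, as in _ensure_patch_shape
pw = [(0, max(0, psh - sh)) for sh, psh in zip(raw.shape[:2], patch_shape)]

pw_raw = [*pw, (0, 0)]  # extra (0, 0) entry: do not pad the channel axis
pw_labels = pw          # these labels have no channel axis

raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw_labels)
print(raw.shape, labels.shape)  # (512, 512, 3) (512, 512)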
6 changes: 2 additions & 4 deletions torch_em/data/raw_image_collection_dataset.py
@@ -81,11 +81,9 @@ def _sample_bounding_box(self, shape):
def _ensure_patch_shape(self, raw, have_raw_channels):
shape = raw.shape
if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
if have_raw_channels:
raise NotImplementedError("Padding is not implemented for data with channels")
assert len(shape) == len(self.patch_shape)
pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
raw = np.pad(raw, pw)
pw_raw = [*pw, (0, 0)] if have_raw_channels else pw
raw = np.pad(raw, pw_raw)
return raw

def _get_sample(self, index):
