
Commit

Merge pull request #9 from angelolab/pyramid_ome_tiffs
Add support for pyramid OME-TIFFs
JLrumberger authored Apr 18, 2024
2 parents 5250d0b + 179b95d commit af30966
Showing 4 changed files with 43 additions and 16 deletions.
src/nimbus_inference/nimbus.py (3 changes: 2 additions & 1 deletion)
@@ -140,6 +140,7 @@ def initialize_model(self, padding="reflect"):
)
if not os.path.exists(self.checkpoint_path):
local_dir = os.path.join(path, "assets")
+ os.makedirs(local_dir, exist_ok=True)
print("Downloading weights from Hugging Face Hub...")
self.checkpoint_path = hf_hub_download(
repo_id="JLrumberger/Nimbus-Inference",
@@ -281,7 +282,7 @@ def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"):
pad_h0, pad_w0 = np.array(tile_size) - (
np.array(image.shape[-2:]) % np.array(output_shape)
)
- pad_h1, pad_w1 = pad_h0 // 2, pad_h0 - pad_h0 // 2
+ pad_h1, pad_w1 = pad_h0 // 2, pad_w0 // 2
pad_h0, pad_w0 = pad_h0 - pad_h1, pad_w0 - pad_w1
image = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode=pad_mode)
b, c = image.shape[:2]
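
The fix above derives each axis's trailing pad from that axis's own total; previously the width pad reused the height total. A minimal standalone sketch of the corrected split, with hypothetical tile and image sizes:

import numpy as np

tile_size, output_shape = (512, 512), (512, 512)
image = np.zeros((1, 2, 700, 900), dtype=np.float32)  # (batch, channel, H, W)

# total pad needed per spatial axis so the image tiles evenly
pad_h0, pad_w0 = np.array(tile_size) - (np.array(image.shape[-2:]) % np.array(output_shape))
# split each axis total into a leading and a trailing part
pad_h1, pad_w1 = pad_h0 // 2, pad_w0 // 2
pad_h0, pad_w0 = pad_h0 - pad_h1, pad_w0 - pad_w1
padded = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode="reflect")
print(padded.shape)  # (1, 2, 1024, 1024): both spatial dims are now multiples of 512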
src/nimbus_inference/utils.py (42 changes: 30 additions & 12 deletions)
@@ -75,7 +75,11 @@ def get_shape(self):
"""
with tifffile.imread(str(self.fpath), aszarr=True) as store:
z = zarr.open(store, mode='r')
- shape = z.shape
+ if hasattr(z, "shape"):
+     zz = z
+ else:
+     zz = z[0]
+ shape = zz.shape
return shape
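
Why the hasattr check: tifffile exposes a pyramidal OME-TIFF as a zarr group of resolution levels, while a flat TIFF opens as a single zarr array. A minimal sketch, assuming a hypothetical file path:

import tifffile
import zarr

with tifffile.imread("example.ome.tiff", aszarr=True) as store:
    z = zarr.open(store, mode="r")
    # a group (pyramid) has no .shape; take level 0, the full-resolution array
    zz = z if hasattr(z, "shape") else z[0]
    print(zz.shape)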

def get_channel(self, channel_name: str):
@@ -90,10 +94,17 @@ def get_channel(self, channel_name: str):
z = zarr.open(store, mode='r')
# correct DimOrder, often DimOrder is TZCYX, but image is stored as CYX,
# thus we remove the trailing dimensions
+ if hasattr(z, "shape"):
+     zz = z
+ else:
+     zz = z[0]
dim_order = self.metadata["DimOrder"]
- dim_order = dim_order[-len(z.shape):]
+ if len(dim_order) > len(zz.shape):
+     dim_order = dim_order.replace("T", "")
+ if len(dim_order) > len(zz.shape):
+     dim_order = dim_order.replace("Z", "")
channel_idx = dim_order.find("C")
slice_string = "z[" + ":," * channel_idx + str(idx) + "]"
slice_string = "zz[" + ":," * channel_idx + str(idx) + "]"
channel = eval(slice_string)
return channel
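
The trimming above reconciles the OME DimOrder (often "TZCYX") with arrays actually stored as CYX by dropping the T and Z entries until the lengths match, then indexing the C axis. A sketch of the same logic on a hypothetical 3-channel array, using np.take instead of the string-built slice:

import numpy as np

zz = np.random.rand(3, 64, 64)       # stored as CYX
dim_order = "TZCYX"                  # as reported by the metadata
if len(dim_order) > len(zz.shape):
    dim_order = dim_order.replace("T", "")
if len(dim_order) > len(zz.shape):
    dim_order = dim_order.replace("Z", "")
channel_idx = dim_order.find("C")    # 0 after trimming to "CYX"
channel = np.take(zz, indices=1, axis=channel_idx)  # second channel
print(channel.shape)                 # (64, 64)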

@@ -232,7 +243,10 @@ def get_segmentation(self, fov: str):
idx = self.fovs.index(fov)
fov_path = self.fov_paths[idx]
instance_path = self.segmentation_naming_convention(fov_path)
- instance_mask = np.squeeze(io.imread(instance_path))
+ if isinstance(instance_path, str):
+     instance_mask = np.squeeze(io.imread(instance_path))
+ else:
+     instance_mask = instance_path
return instance_mask
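
With this change the segmentation naming convention may return either a path to a mask file or an in-memory array. A sketch of the array variant (paths and mask are hypothetical):

import numpy as np
from skimage import io

precomputed_mask = np.zeros((256, 256), dtype=np.int32)

def segmentation_naming_convention(fov_path):
    # returning a str path to a mask image would work just as well
    return precomputed_mask

instance_path = segmentation_naming_convention("data/fov_0.ome.tiff")
if isinstance(instance_path, str):
    instance_mask = np.squeeze(io.imread(instance_path))
else:
    instance_mask = instance_path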


@@ -345,7 +359,7 @@ def predict_fovs(
for fov_path, fov in zip(dataset.fov_paths, dataset.fovs):
print(f"Predicting {fov_path}...")
out_fov_path = os.path.join(
- os.path.normpath(output_dir), os.path.basename(fov_path)
+ os.path.normpath(output_dir), os.path.basename(fov_path).replace(suffix, "")
)
df_fov = pd.DataFrame()
instance_mask = dataset.get_segmentation(fov)
@@ -356,9 +370,9 @@
scale = 0.5
input_data = np.squeeze(input_data)
_, h,w = input_data.shape
- img = cv2.resize(input_data[0], [int(h*scale), int(w*scale)])
+ img = cv2.resize(input_data[0], [int(w*scale), int(h*scale)])
binary_mask = cv2.resize(
- input_data[1], [int(h*scale), int(w*scale)], interpolation=0
+ input_data[1], [int(w*scale), int(h*scale)], interpolation=0
)
input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...]
if test_time_augmentation:
@@ -375,7 +389,7 @@
)
prediction = np.squeeze(prediction)
if half_resolution:
- prediction = cv2.resize(prediction, (h, w))
+ prediction = cv2.resize(prediction, (w, h))
df = pd.DataFrame(segment_mean(instance_mask, prediction))
if df_fov.empty:
df_fov["label"] = df["label"]
@@ -432,8 +446,11 @@ def calculate_normalization(dataset: MultiplexDataset, quantile: float):
for channel in dataset.channels:
mplex_img = dataset.get_channel(dataset.fovs[0], channel)
mplex_img = mplex_img.astype(np.float32)
- foreground = mplex_img[mplex_img > 0]
- normalization_values[channel] = np.quantile(foreground, quantile)
+ if np.any(mplex_img):
+     foreground = mplex_img[mplex_img > 0]
+     normalization_values[channel] = np.quantile(foreground, quantile)
+ else:
+     normalization_values[channel] = None
return normalization_values
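
The guard avoids calling np.quantile on an empty array when a channel contains no signal; such channels now map to None and are filtered out later. A sketch with hypothetical channel images:

import numpy as np

channels = {"CD4": np.random.rand(64, 64), "blank": np.zeros((64, 64))}
quantile = 0.999
normalization_values = {}
for name, mplex_img in channels.items():
    mplex_img = mplex_img.astype(np.float32)
    if np.any(mplex_img):
        foreground = mplex_img[mplex_img > 0]
        normalization_values[name] = np.quantile(foreground, quantile)
    else:
        normalization_values[name] = None  # no foreground pixels at all
print(normalization_values)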


@@ -458,7 +475,7 @@ def prepare_normalization_dict(
random.shuffle(fov_paths)
fov_paths = fov_paths[:n_subset]
print("Iterate over fovs...")
- if n_jobs > 1:
+ if n_jobs > 1 and len(fov_paths) > 1:
normalization_values = Parallel(n_jobs=n_jobs)(
delayed(calculate_normalization)(
MultiplexDataset(
@@ -480,7 +497,8 @@
for channel, normalization_value in norm_dict.items():
if channel not in normalization_dict:
normalization_dict[channel] = []
- normalization_dict[channel].append(normalization_value)
+ if normalization_value:
+     normalization_dict[channel].append(normalization_value)
if n_jobs > 1:
get_reusable_executor().shutdown(wait=True)
for channel in normalization_dict.keys():
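Spawning joblib workers for a single FOV only adds overhead, so the parallel path now also requires more than one FOV; the serial fallback produces the same list of per-FOV dictionaries. A sketch of the dispatch pattern with a hypothetical worker:

from joblib import Parallel, delayed

def calculate_norm(fov_path):
    return {"CD4": 0.5}  # stand-in for the real per-FOV quantile computation

fov_paths = ["fov_0"]
n_jobs = 4
if n_jobs > 1 and len(fov_paths) > 1:
    normalization_values = Parallel(n_jobs=n_jobs)(
        delayed(calculate_norm)(p) for p in fov_paths
    )
else:
    normalization_values = [calculate_norm(p) for p in fov_paths]
print(normalization_values)
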
src/nimbus_inference/viewer_widget.py (10 changes: 9 additions & 1 deletion)
@@ -7,11 +7,13 @@
import numpy as np
from natsort import natsorted
from skimage.segmentation import find_boundaries
+ from skimage.transform import rescale
from nimbus_inference.utils import MultiplexDataset

class NimbusViewer(object):
def __init__(
- self, dataset: MultiplexDataset, output_dir: str, img_width='600px', suffix=".tiff"
+ self, dataset: MultiplexDataset, output_dir: str, img_width='600px', suffix=".tiff",
+ max_resolution=(2048, 2048)
):
"""Viewer for Nimbus application.
Args:
@@ -20,11 +22,13 @@ def __init__(
segmentation_naming_convention (fn): Function that maps input path to segmentation path
img_width (str): Width of images in viewer.
suffix (str): Suffix of images in dataset.
+ max_resolution (tuple): Maximum resolution of images in viewer.
"""
self.image_width = img_width
self.dataset = dataset
self.output_dir = output_dir
self.suffix = suffix
+ self.max_resolution = max_resolution
self.fov_names = natsorted(copy(self.dataset.fovs))
self.update_button = widgets.Button(description="Update Image")
self.update_button.on_click(self.update_button_click)
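
A usage sketch for the new max_resolution argument; the import path is inferred from the file layout, and the input paths, naming convention, and output directory are hypothetical:

from nimbus_inference.utils import MultiplexDataset
from nimbus_inference.viewer_widget import NimbusViewer

fov_paths = ["data/fov_0.ome.tiff"]

def segmentation_naming_convention(fov_path):
    return fov_path.replace(".ome.tiff", "_segmentation.tiff")

dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".ome.tiff")
viewer = NimbusViewer(
    dataset, output_dir="nimbus_output", img_width="600px", suffix=".tiff",
    max_resolution=(2048, 2048),  # composites above this size are downscaled for display
)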
@@ -199,6 +203,10 @@ def update_img(self, image_viewer, composite_image):
composite_image (np.array): Composite image to display.
"""
# Convert composite image to bytes and assign it to the output_image widget
+ if composite_image.shape[0] > self.max_resolution[0] or composite_image.shape[1] > self.max_resolution[1]:
+     scale = float(np.max(self.max_resolution)/np.max(composite_image.shape))
+     composite_image = rescale(composite_image, (scale, scale, 1), preserve_range=True)
+     composite_image = composite_image.astype(np.uint8)
with BytesIO() as output_buffer:
io.imsave(output_buffer, composite_image, format="png")
output_buffer.seek(0)
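
The downscaling keeps the colour axis untouched (scale factor 1 on the last dimension) and casts back to uint8 so the PNG encoder accepts it. A sketch with a hypothetical oversized composite:

import numpy as np
from skimage.transform import rescale

max_resolution = (2048, 2048)
composite_image = np.random.randint(0, 255, (4096, 3000, 3), dtype=np.uint8)

if composite_image.shape[0] > max_resolution[0] or composite_image.shape[1] > max_resolution[1]:
    scale = float(np.max(max_resolution) / np.max(composite_image.shape))
    composite_image = rescale(composite_image, (scale, scale, 1), preserve_range=True)
    composite_image = composite_image.astype(np.uint8)
print(composite_image.shape)  # (2048, 1500, 3)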
tests/test_utils.py (4 changes: 2 additions & 2 deletions)
@@ -296,7 +296,7 @@ def segmentation_naming_convention(fov_path):
# check if we get the correct channels and fov_paths
dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".ome.tiff")
assert len(dataset) == 1
assert dataset.channels == ["CD4", "CD56"]
assert set(dataset.channels) == set(["CD4", "CD56"])
assert dataset.fov_paths == fov_paths
assert dataset.multi_channel == True
cd4_channel = io.imread(fov_paths[0])[0]
@@ -312,7 +312,7 @@
)
dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".tiff")
assert len(dataset) == 1
- assert dataset.channels == ["CD4", "CD56"]
+ assert set(dataset.channels) == set(["CD4", "CD56"])
assert dataset.fov_paths == fov_paths
assert dataset.multi_channel == False
cd4_channel_ = dataset.get_channel(fov="fov_0", channel="CD4")
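
The set comparison makes the assertions independent of the order in which channel names come back from the OME metadata or the directory listing, for example (hypothetical values):

channels = ["CD56", "CD4"]                     # order as returned by the reader
assert set(channels) == set(["CD4", "CD56"])   # passes for any ordering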
