Skip to content

Commit

Permalink
add utils to create metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzocerrone committed Aug 28, 2024
1 parent a5897c9 commit 25bdd41
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 12 deletions.
32 changes: 27 additions & 5 deletions src/ngio/ngff_meta/fractal_image_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def scaling(self) -> float:
raise ValueError(f"Unknown space unit: {self}")
return scaling_factor

@classmethod
def allowed_names(self) -> list[str]:
"""Get the allowed space axis names."""
return list(SpaceUnits.__members__.keys())


class SpaceNames(str, Enum):
"""Allowed space axis names."""
Expand All @@ -83,11 +88,17 @@ class SpaceNames(str, Enum):
y = "y"
z = "z"

@classmethod
def allowed_names(self) -> list[str]:
"""Get the allowed space axis names."""
return list(SpaceNames.__members__.keys())


class TimeUnits(str, Enum):
"""Allowed time units."""

s = "seconds"
seconds = "seconds"
s = "s"

def scaling(self) -> float:
"""Get the scaling factor of the time unit (relative to seconds)."""
Expand All @@ -99,12 +110,22 @@ def scaling(self) -> float:
raise ValueError(f"Unknown time unit: {self}")
return scaling_factor

@classmethod
def allowed_names(self) -> list[str]:
"""Get the allowed time axis names."""
return list(TimeUnits.__members__.keys())


class TimeNames(str, Enum):
"""Allowed time axis names."""

t = "t"

@classmethod
def allowed_names(self) -> list[str]:
"""Get the allowed time axis names."""
return list(TimeNames.__members__.keys())


class Axis(BaseModel):
"""Axis infos model.
Expand All @@ -128,29 +149,30 @@ def _check_consistency(self) -> "Axis":
raise ValueError("Channel axes must not have units.")

if self.type == AxisType.time:
print(self)
self.name = TimeNames(self.name)
if not isinstance(self.unit, TimeUnits):
raise ValueError(
"Time axes must have time units."
f" {self.unit} in {list(TimeUnits.__members__.keys())}"
f" {self.unit} in {TimeUnits.allowed_names()}"
)
if not isinstance(self.name, TimeNames):
raise ValueError(
f"Time axes must have time names. "
f"{self.name} in {list(TimeNames.__members__.keys())}"
f"{self.name} in {TimeNames.allowed_names()}"
)

if self.type == AxisType.space:
self.name = SpaceNames(self.name)
if not isinstance(self.unit, SpaceUnits):
raise ValueError(
"Space axes must have space units."
f" {self.unit} in {list(SpaceUnits.__members__.keys())}"
f" {self.unit} in {SpaceUnits.allowed_names()}"
)
if not isinstance(self.name, SpaceNames):
raise ValueError(
f"Space axes must have space names. "
f"{self.name} in {list(SpaceNames.__members__.keys())}"
f"{self.name} in {SpaceNames.allowed_names()}"
)
return self

Expand Down
147 changes: 140 additions & 7 deletions src/ngio/ngff_meta/utils.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,171 @@
"""Utility functions for creating and modifying metadata."""

from typing import Any

from ngio.ngff_meta.fractal_image_meta import (
Axis,
Channel,
Dataset,
FractalImageMeta,
FractalLabelMeta,
Multiscale,
Omero,
ScaleCoordinateTransformation,
SpaceNames,
SpaceUnits,
TimeNames,
TimeUnits,
)


def create_image_metadata(
def _compute_scale(axis_order, pixel_sizes, time_spacing):
scale = []

pixel_sizes_dict = {
"z": pixel_sizes[0],
"x": pixel_sizes[1],
"y": pixel_sizes[2],
}

for ax in axis_order:
if ax in TimeNames.allowed_names():
scale.append(time_spacing)
elif ax in SpaceNames.allowed_names():
scale.append(pixel_sizes_dict[ax])
else:
scale.append(1.0)

return scale


def _create_image_metadata(
axis_order: list[str] = ("t", "c", "z", "y", "x"),
pixel_sizes: tuple[float, float, float] = (1.0, 1.0, 1.0),
scaling_factors: tuple[float, float, float] = (1.0, 2.0, 2.0),
pixel_units: SpaceUnits | str = SpaceUnits.micrometer,
time_spacing: float = 1.0,
time_units: TimeUnits | str = TimeUnits.s,
num_levels: int = 5,
channel_names: list[str] | None = None,
channel_wavelengths: list[str] | None = None,
channel_kwargs: list[dict[str, Any]] | None = None,
omero_kwargs: dict[str, Any] | None = None,
) -> tuple[Multiscale, Omero]:
"""Create a image metadata object from scratch."""
scale = _compute_scale(axis_order, pixel_sizes, time_spacing)

datasets = []
for level in range(num_levels):
transform = [ScaleCoordinateTransformation(type="scale", scale=scale)]
datasets.append(Dataset(path=str(level), coordinateTransformations=transform))

pixel_sizes = [s * f for s, f in zip(pixel_sizes, scaling_factors, strict=True)]
scale = _compute_scale(axis_order, pixel_sizes, time_spacing)

axes = []
for ax_name in axis_order:
if ax_name in TimeNames.allowed_names():
unit = time_units
ax_type = "time"
elif ax_name in SpaceNames.allowed_names():
unit = pixel_units
ax_type = "space"
else:
unit = None
ax_type = "channel"

print(ax_name, unit, ax_type)
axes.append(Axis(name=ax_name, unit=unit, type=ax_type))

multiscale = Multiscale(axes=axes, datasets=datasets)

if channel_names is not None:
if channel_wavelengths is None:
channel_wavelengths = [None] * len(channel_names)

if channel_kwargs is None:
channel_kwargs = [{}] * len(channel_names)

channels = []
for label, wavelenghts, kwargs in zip(
channel_names, channel_wavelengths, channel_kwargs, strict=True
):
channels.append(Channel(label=label, wavelength_id=wavelenghts, **kwargs))

omero_kwargs = {} if omero_kwargs is None else omero_kwargs
omero = Omero(channels=channels, **omero_kwargs)
else:
omero = None

return multiscale, omero


def create_image_metadata(
axis_order: list[str] = ("t", "c", "z", "y", "x"),
pixel_sizes: tuple[float, float, float] = (1.0, 1.0, 1.0),
scaling_factors: tuple[float, float, float] = (1.0, 2.0, 2.0),
pixel_units: SpaceUnits | str = SpaceUnits.micrometer,
time_spacing: float = 1.0,
time_units: TimeUnits | str = TimeUnits.s,
num_levels: int = 5,
name: str | None = None,
channel_names: list[str] | None = None,
channel_wavelengths: list[str] | None = None,
channel_kwargs: list[dict[str, Any]] | None = None,
omero_kwargs: dict[str, Any] | None = None,
version: str = "0.4",
) -> FractalImageMeta:
pass
"""Create a image metadata object from scratch."""
if len(channel_names) != len(set(channel_names)):
raise ValueError("Channel names must be unique.")

mulitscale, omero = _create_image_metadata(
axis_order=axis_order,
pixel_sizes=pixel_sizes,
scaling_factors=scaling_factors,
pixel_units=pixel_units,
time_spacing=time_spacing,
time_units=time_units,
num_levels=num_levels,
channel_names=channel_names,
channel_wavelengths=channel_wavelengths,
channel_kwargs=channel_kwargs,
omero_kwargs=omero_kwargs,
)
return FractalImageMeta(
version=version,
name=name,
multiscale=mulitscale,
omero=omero,
)


def create_label_metadata(
version: str,
name: str,
axis_order: list[str] = ("t", "z", "y", "x"),
pixel_sizes: tuple[float, float, float] = (1.0, 1.0, 1.0),
pixel_units: str = "micrometer",
scaling_factors: tuple[float, float, float] = (1.0, 2.0, 2.0),
pixel_units: SpaceUnits | str = SpaceUnits.micrometer,
time_spacing: float = 1.0,
time_units: str = "second",
time_units: TimeUnits | str = TimeUnits.s,
num_levels: int = 5,
name: str | None = None,
version: str = "0.4",
) -> FractalLabelMeta:
pass
"""Create a label metadata object from scratch."""
multiscale, _ = _create_image_metadata(
axis_order=axis_order,
pixel_sizes=pixel_sizes,
scaling_factors=scaling_factors,
pixel_units=pixel_units,
time_spacing=time_spacing,
time_units=time_units,
num_levels=num_levels,
)
return FractalLabelMeta(
version=version,
name=name,
multiscale=multiscale,
)


def remove_axis_from_metadata(
Expand Down
82 changes: 82 additions & 0 deletions tests/ngff_meta/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np


class TestUtils:
def test_create_fractal_meta_with_t(self):
from ngio.ngff_meta.utils import create_image_metadata

meta = create_image_metadata(
axis_order=("t", "c", "z", "y", "x"),
pixel_sizes=(1.0, 1.0, 1.0),
scaling_factors=(1.0, 2.0, 2.0),
pixel_units="micrometer",
time_spacing=1.0,
time_units="s",
num_levels=5,
name="test",
channel_names=["DAPI", "nanog", "Lamin B1"],
channel_wavelengths=["A01_C01", "A02_C02", "A03_C03"],
channel_kwargs=None,
omero_kwargs=None,
version="0.4",
)

assert meta.get_channel_names() == ["DAPI", "nanog", "Lamin B1"]
assert meta.pixel_size(level=0) == [1.0, 1.0, 1.0]
assert meta.scale(level=0) == [1.0, 1.0, 1.0, 1.0, 1.0]

assert meta.pixel_size(level="2") == [1.0, 4.0, 4.0]
assert meta.scale(level="2") == [1.0, 1.0, 1.0, 4.0, 4.0]

assert meta.num_levels == 5

def test_create_fractal_meta(self):
from ngio.ngff_meta.utils import create_image_metadata

meta = create_image_metadata(
axis_order=("c", "z", "y", "x"),
pixel_sizes=(1.0, 1.0, 1.0),
scaling_factors=(1.0, 2.0, 2.0),
pixel_units="micrometer",
time_spacing=1.0,
time_units="s",
num_levels=5,
name="test",
channel_names=["DAPI", "nanog", "Lamin B1"],
channel_wavelengths=["A01_C01", "A02_C02", "A03_C03"],
channel_kwargs=None,
omero_kwargs=None,
version="0.4",
)

assert meta.get_channel_names() == ["DAPI", "nanog", "Lamin B1"]
assert meta.pixel_size(level=0) == [1.0, 1.0, 1.0]
assert meta.scale(level=0) == [1.0, 1.0, 1.0, 1.0]

assert meta.pixel_size(level="2") == [1.0, 4.0, 4.0]
assert meta.scale(level="2") == [1.0, 1.0, 4.0, 4.0]

assert meta.num_levels == 5

def test_create_fractal_label_meta(self):
from ngio.ngff_meta.utils import create_label_metadata

meta = create_label_metadata(
axis_order=("t", "z", "y", "x"),
pixel_sizes=(1.0, 1.0, 1.0),
scaling_factors=(1.0, 2.0, 2.0),
pixel_units="micrometer",
time_spacing=1.0,
time_units="s",
num_levels=5,
name="test",
version="0.4",
)

assert meta.pixel_size(level=0) == [1.0, 1.0, 1.0]
assert meta.scale(level=0) == [1.0, 1.0, 1.0, 1.0]

assert meta.pixel_size(level="2") == [1.0, 4.0, 4.0]
assert meta.scale(level="2") == [1.0, 1.0, 4.0, 4.0]

assert meta.num_levels == 5

0 comments on commit 25bdd41

Please sign in to comment.