diff --git a/CHANGELOG.md b/CHANGELOG.md index d3f9fd22..8daa873c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning][]. ## [0.1.0] - tbd +### Added + +- Pushed `get_extent` functionality upstream to `spatialdata` (#162) + +### Fixed + +- + ## [0.0.5] - 2023-10-02 ### Added diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index f58d7e04..c941cae1 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -19,7 +19,8 @@ from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from pandas.api.types import is_categorical_dtype from spatial_image import SpatialImage -from spatialdata._logging import logger as logg +from spatialdata._core.data_extent import get_extent +from spatialdata.transformations.operations import get_transformation from spatialdata_plot._accessor import register_spatial_data_accessor from spatialdata_plot.pl.render import ( @@ -40,12 +41,10 @@ ) from spatialdata_plot.pl.utils import ( _get_cs_contents, - _get_extent, _maybe_set_colors, _mpl_ax_contains_elements, _prepare_cmap_norm, _prepare_params_plot, - _robust_transform, _set_outline, save_fig, ) @@ -216,6 +215,8 @@ def render_shapes( na_color=na_color, # type: ignore[arg-type] **kwargs, ) + if isinstance(elements, str): + elements = [elements] outline_params = _set_outline(outline, outline_width, outline_color) sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams( elements=elements, @@ -285,12 +286,15 @@ def render_points( sdata = self._copy() sdata = _verify_plotting_tree(sdata) n_steps = len(sdata.plotting_tree.keys()) + cmap_params = _prepare_cmap_norm( cmap=cmap, norm=norm, na_color=na_color, # type: ignore[arg-type] **kwargs, ) + if isinstance(elements, str): + elements = [elements] sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams( elements=elements, color=color, @@ -370,6 +374,8 @@ def render_images( **kwargs, ) + if isinstance(elements, str): + elements = [elements] sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams( elements=elements, channel=channel, @@ -450,6 +456,8 @@ def render_labels( na_color=na_color, # type: ignore[arg-type] **kwargs, ) + if isinstance(elements, str): + elements = [elements] sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams( elements=elements, color=color, @@ -552,12 +560,12 @@ def show( raise TypeError("All titles must be strings.") # get original axis extent for later comparison - x_min_orig, x_max_orig = (np.inf, -np.inf) - y_min_orig, y_max_orig = (np.inf, -np.inf) + ax_x_min, ax_x_max = (np.inf, -np.inf) + ax_y_min, ax_y_max = (np.inf, -np.inf) if isinstance(ax, Axes) and _mpl_ax_contains_elements(ax): - x_min_orig, x_max_orig = ax.get_xlim() - y_max_orig, y_min_orig = ax.get_ylim() # (0, 0) is top-left + ax_x_min, ax_x_max = ax.get_xlim() + ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left # handle coordinate system coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems @@ -568,50 +576,6 @@ def show( if cs not in sdata.coordinate_systems: raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}") - # Check if user specified only certain elements to be plotted - cs_contents = _get_cs_contents(sdata) - elements_to_be_rendered = [] - for cmd, params in render_cmds.items(): - if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114 - if params.elements is not None: - elements_to_be_rendered += ( - [params.elements] if isinstance(params.elements, str) else params.elements - ) - elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: # noqa: SIM114 - if params.elements is not None: - elements_to_be_rendered += ( - [params.elements] if isinstance(params.elements, str) else params.elements - ) - elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: # noqa: SIM114 - if params.elements is not None: - elements_to_be_rendered += ( - [params.elements] if isinstance(params.elements, str) else params.elements - ) - elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: # noqa: SIM102 - if params.elements is not None: - elements_to_be_rendered += ( - [params.elements] if isinstance(params.elements, str) else params.elements - ) - - extent = _get_extent( - sdata=sdata, - has_images="render_images" in render_cmds, - has_labels="render_labels" in render_cmds, - has_points="render_points" in render_cmds, - has_shapes="render_shapes" in render_cmds, - elements=elements_to_be_rendered, - coordinate_systems=coordinate_systems, - ) - - # Use extent to filter out coordinate system without the relevant elements - valid_cs = [] - for cs in coordinate_systems: - if cs in extent: - valid_cs.append(cs) - else: - logg.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.") - coordinate_systems = valid_cs - # set up canvas fig_params, scalebar_params = _prepare_params_plot( num_panels=len(coordinate_systems), @@ -633,32 +597,25 @@ def show( colorbar=colorbar, ) + cs_contents = _get_cs_contents(sdata) + # go through tree + for i, cs in enumerate(coordinate_systems): sdata = self._copy() - # properly transform all elements to the current coordinate system - members = cs_contents.query(f"cs == '{cs}'") - - if members["has_images"].values[0]: - for key in sdata.images: - sdata.images[key] = _robust_transform(sdata.images[key], cs) - - if members["has_labels"].values[0]: - for key in sdata.labels: - sdata.labels[key] = _robust_transform(sdata.labels[key], cs) - - if members["has_points"].values[0]: - for key in sdata.points: - sdata.points[key] = _robust_transform(sdata.points[key], cs) - - if members["has_shapes"].values[0]: - for key in sdata.shapes: - sdata.shapes[key] = _robust_transform(sdata.shapes[key], cs) - + _, has_images, has_labels, has_points, has_shapes = ( + cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist() + ) ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i] + wants_images = False + wants_labels = False + wants_points = False + wants_shapes = False + wanted_elements = [] + for cmd, params in render_cmds.items(): - if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: + if cmd == "render_images" and has_images: _render_images( sdata=sdata, render_params=params, @@ -667,9 +624,18 @@ def show( fig_params=fig_params, scalebar_params=scalebar_params, legend_params=legend_params, - # extent=extent[cs], ) - elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: + wants_images = True + wanted_images = params.elements if params.elements is not None else list(sdata.images.keys()) + wanted_elements.extend( + [ + image + for image in wanted_images + if cs in set(get_transformation(sdata.images[image], get_all=True).keys()) + ] + ) + + elif cmd == "render_shapes" and has_shapes: _render_shapes( sdata=sdata, render_params=params, @@ -679,8 +645,17 @@ def show( scalebar_params=scalebar_params, legend_params=legend_params, ) + wants_shapes = True + wanted_shapes = params.elements if params.elements is not None else list(sdata.shapes.keys()) + wanted_elements.extend( + [ + shape + for shape in wanted_shapes + if cs in set(get_transformation(sdata.shapes[shape], get_all=True).keys()) + ] + ) - elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: + elif cmd == "render_points" and has_points: _render_points( sdata=sdata, render_params=params, @@ -690,8 +665,17 @@ def show( scalebar_params=scalebar_params, legend_params=legend_params, ) + wants_points = True + wanted_points = params.elements if params.elements is not None else list(sdata.points.keys()) + wanted_elements.extend( + [ + point + for point in wanted_points + if cs in set(get_transformation(sdata.points[point], get_all=True).keys()) + ] + ) - elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: + elif cmd == "render_labels" and has_labels: if sdata.table is not None and isinstance(params.color, str): colors = sc.get.obs_df(sdata.table, params.color) if is_categorical_dtype(colors): @@ -710,33 +694,46 @@ def show( scalebar_params=scalebar_params, legend_params=legend_params, ) + wants_labels = True + wanted_labels = params.elements if params.elements is not None else list(sdata.labels.keys()) + wanted_elements.extend( + [ + label + for label in wanted_labels + if cs in set(get_transformation(sdata.labels[label], get_all=True).keys()) + ] + ) - if title is not None: - if len(title) == 1: - t = title[0] - else: - try: - t = title[i] - except IndexError as e: - raise IndexError("The number of titles must match the number of coordinate systems.") from e - else: + if title is None: t = cs + elif len(title) == 1: + t = title[0] + else: + try: + t = title[i] + except IndexError as e: + raise IndexError("The number of titles must match the number of coordinate systems.") from e ax.set_title(t) ax.set_aspect("equal") - if any( - [ - cs_contents.query(f"cs == '{cs}'")["has_images"][0], - cs_contents.query(f"cs == '{cs}'")["has_labels"][0], - cs_contents.query(f"cs == '{cs}'")["has_points"][0], - cs_contents.query(f"cs == '{cs}'")["has_shapes"][0], - ] - ): + extent = get_extent( + sdata, + coordinate_system=cs, + has_images=has_images and wants_images, + has_labels=has_labels and wants_labels, + has_points=has_points and wants_points, + has_shapes=has_shapes and wants_shapes, + elements=wanted_elements, + ) + cs_x_min, cs_x_max = extent["x"] + cs_y_min, cs_y_max = extent["y"] + + if any([has_images, has_labels, has_points, has_shapes]): # If the axis already has limits, only expand them but not overwrite - x_min = min(x_min_orig, extent[cs][0]) - pad_extent - x_max = max(x_max_orig, extent[cs][1]) + pad_extent - y_min = min(y_min_orig, extent[cs][2]) - pad_extent - y_max = max(y_max_orig, extent[cs][3]) + pad_extent + x_min = min(ax_x_min, cs_x_min) - pad_extent + x_max = max(ax_x_max, cs_x_max) + pad_extent + y_min = min(ax_y_min, cs_y_min) - pad_extent + y_max = max(ax_y_max, cs_y_max) + pad_extent ax.set_xlim(x_min, x_max) ax.set_ylim(y_max, y_min) # (0, 0) is top-left @@ -747,5 +744,4 @@ def show( # https://stackoverflow.com/a/64523765 if not hasattr(sys, "ps1"): plt.show() - return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 106f404b..c411bb0e 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1,12 +1,13 @@ from __future__ import annotations -from collections.abc import Sequence +from collections import abc from copy import copy from typing import Union import dask import geopandas as gpd import matplotlib +import matplotlib.transforms as mtransforms import numpy as np import pandas as pd import scanpy as sc @@ -21,6 +22,9 @@ Labels2DModel, PointsModel, ) +from spatialdata.transformations import ( + get_transformation, +) from spatialdata_plot._logging import logger from spatialdata_plot.pl.render_params import ( @@ -45,7 +49,7 @@ ) from spatialdata_plot.pp.utils import _get_instance_key, _get_region_key -_Normalize = Union[Normalize, Sequence[Normalize]] +_Normalize = Union[Normalize, abc.Sequence[Normalize]] def _render_shapes( @@ -137,6 +141,14 @@ def _render_shapes( cax = ax.add_collection(_cax) + # Apply the transformation to the PatchCollection's paths + trans = get_transformation(sdata_filt.shapes[e], get_all=True)[coordinate_system] + affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) + trans = mtransforms.Affine2D(matrix=affine_trans) + + for path in _cax.get_paths(): + path.vertices = trans.transform(path.vertices) + # Using dict.fromkeys here since set returns in arbitrary order # remove the color of NaN values, else it might be assigned to a category # order of color in the palette should agree to order of occurence @@ -260,6 +272,14 @@ def _render_points( # **kwargs, ) cax = ax.add_collection(_cax) + + trans = get_transformation(sdata.points[e], get_all=True)[coordinate_system] + affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) + trans = mtransforms.Affine2D(matrix=affine_trans) + + for path in _cax.get_paths(): + path.vertices = trans.transform(path.vertices) + if not ( len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color) ): @@ -311,11 +331,11 @@ def _render_images( if elements is None: elements = list(sdata_filt.images.keys()) - images = [sdata.images[e] for e in elements] - for img, img_key in zip(images, elements): + for e in elements: + img = sdata.images[e] if not isinstance(img, spatial_image.SpatialImage): img = Image2DModel.parse(img["scale0"].ds.to_array().squeeze(axis=0)) - logger.warning(f"Multi-scale images not yet supported, using scale0 of multi-scale image '{img_key}'.") + logger.warning(f"Multi-scale images not yet supported, using scale0 of multi-scale image '{e}'.") if render_params.channel is None: channels = img.coords["c"].values @@ -341,6 +361,12 @@ def _render_images( if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels: raise ValueError("If 'cmap' is provided, its length must match the number of channels.") + # prepare transformations + trans = get_transformation(sdata.images[e], get_all=True)[coordinate_system] + affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) + trans = mtransforms.Affine2D(matrix=affine_trans) + trans_data = trans + ax.transData + # 1) Image has only 1 channel if n_channels == 1 and not isinstance(render_params.cmap_params, list): layer = img.sel(c=channels).squeeze() @@ -358,11 +384,12 @@ def _render_images( else: cmap = _get_linear_colormap([render_params.palette], "k")[0] - ax.imshow( + im = ax.imshow( layer, # get rid of the channel dimension cmap=cmap, alpha=render_params.alpha, ) + im.set_transform(trans_data) # 2) Image has any number of channels but 1 else: @@ -387,7 +414,11 @@ def _render_images( # 2A) Image has 3 channels, no palette/cmap info -> use RGB if n_channels == 3 and render_params.palette is None and not got_multiple_cmaps: - ax.imshow(np.stack([layers[c] for c in channels], axis=-1), alpha=render_params.alpha) + im = ax.imshow( + np.stack([layers[c] for c in channels], axis=-1), + alpha=render_params.alpha, + ) + im.set_transform(trans_data) # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors elif render_params.palette is None and not got_multiple_cmaps: @@ -405,10 +436,11 @@ def _render_images( # Remove alpha channel so we can overwrite it from render_params.alpha colored = colored[:, :, :3] - ax.imshow( + im = ax.imshow( colored, alpha=render_params.alpha, ) + im.set_transform(trans_data) # 2C) Image has n channels and palette info elif render_params.palette is not None and not got_multiple_cmaps: @@ -423,10 +455,11 @@ def _render_images( # Remove alpha channel so we can overwrite it from render_params.alpha colored = colored[:, :, :3] - ax.imshow( + im = ax.imshow( colored, alpha=render_params.alpha, ) + im.set_transform(trans_data) elif render_params.palette is None and got_multiple_cmaps: channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr] @@ -437,10 +470,11 @@ def _render_images( # Remove alpha channel so we can overwrite it from render_params.alpha colored = colored[:, :, :3] - ax.imshow( + im = ax.imshow( colored, alpha=render_params.alpha, ) + im.set_transform(trans_data) elif render_params.palette is not None and got_multiple_cmaps: raise ValueError("If 'palette' is provided, 'cmap' must be None.") @@ -473,12 +507,11 @@ def _render_labels( if elements is None: elements = list(sdata_filt.labels.keys()) - labels = [sdata.labels[e] for e in elements] - - for label, label_key in zip(labels, elements): + for e in elements: + label = sdata_filt.labels[e] if not isinstance(label, spatial_image.SpatialImage): label = Labels2DModel.parse(label["scale0"].ds.to_array().squeeze(axis=0)) - logger.warning(f"Multi-scale labels not yet supported, using scale0 of multi-scale label '{label_key}'.") + logger.warning(f"Multi-scale labels not yet supported, using scale0 of multi-scale label '{e}'.") if sdata.table is None: instance_id = np.unique(label) @@ -487,16 +520,21 @@ def _render_labels( instance_key = _get_instance_key(sdata) region_key = _get_region_key(sdata) - table = sdata.table[sdata.table.obs[region_key].isin([label_key])] + table = sdata.table[sdata.table.obs[region_key].isin([e])] # get instance id based on subsetted table instance_id = table.obs[instance_key].values + trans = get_transformation(sdata.labels[e], get_all=True)[coordinate_system] + affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) + trans = mtransforms.Affine2D(matrix=affine_trans) + trans_data = trans + ax.transData + # get color vector (categorical or continuous) color_source_vector, color_vector, categorical = _set_color_source_vec( sdata=sdata_filt, - element=sdata_filt.labels[label_key], - element_name=label_key, + element=sdata_filt.labels[e], + element_name=e, value_to_plot=render_params.color, layer=render_params.layer, groups=render_params.groups, @@ -526,8 +564,8 @@ def _render_labels( norm=render_params.cmap_params.norm if not categorical else None, alpha=render_params.fill_alpha, origin="lower", - # zorder=3, ) + _cax.set_transform(trans_data) cax = ax.add_image(_cax) # Then overlay the contour @@ -549,8 +587,8 @@ def _render_labels( norm=render_params.cmap_params.norm if not categorical else None, alpha=render_params.outline_alpha, origin="lower", - # zorder=4, ) + _cax.set_transform(trans_data) cax = ax.add_image(_cax) else: @@ -573,8 +611,8 @@ def _render_labels( norm=render_params.cmap_params.norm if not categorical else None, alpha=render_params.fill_alpha, origin="lower", - # zorder=4, ) + _cax.set_transform(trans_data) cax = ax.add_image(_cax) _ = _decorate_axs( diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index b8b8b7f5..5a66b667 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -6,7 +6,7 @@ from functools import partial from pathlib import Path from types import MappingProxyType -from typing import Any, Literal +from typing import Any import matplotlib import matplotlib.patches as mpatches @@ -43,17 +43,15 @@ from scanpy import settings from scanpy.plotting._tools.scatterplots import _add_categorical_legend from scanpy.plotting.palettes import default_20, default_28, default_102 -from shapely.geometry import LineString, Point, Polygon +from shapely.geometry import LineString, Polygon from skimage.color import label2rgb from skimage.morphology import erosion, square from skimage.segmentation import find_boundaries from skimage.util import map_array -from spatialdata import transform from spatialdata._core.query.relational_query import _locate_value, get_values from spatialdata._logging import logger as logging from spatialdata._types import ArrayLike -from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement -from spatialdata.transformations import get_transformation +from spatialdata.models import Image2DModel, SpatialElement from spatialdata_plot.pl.render_params import ( CmapParams, @@ -92,8 +90,6 @@ def _prepare_params_plot( scalebar_dx: float | Sequence[float] | None = None, scalebar_units: str | Sequence[str] | None = None, ) -> tuple[FigParams, ScalebarParams]: - # len(list(itertools.product(*iter_panels))) - # handle axes and size wspace = 0.75 / rcParams["figure.figsize"][0] + 0.02 if wspace is None else wspace figsize = rcParams["figure.figsize"] if figsize is None else figsize @@ -103,20 +99,21 @@ def _prepare_params_plot( num_panels=num_panels, hspace=hspace, wspace=wspace, ncols=ncols, dpi=dpi, figsize=figsize ) axs: None | Sequence[Axes] = [plt.subplot(grid[c]) for c in range(num_panels)] - elif num_panels > 1 and ax is not None: - if len(ax) != num_panels: + elif num_panels > 1: + if not isinstance(ax, Sequence): + raise TypeError(f"Expected `ax` to be a `Sequence`, but got {type(ax).__name__}") + if ax is not None and len(ax) != num_panels: raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.") if fig is None: raise ValueError( f"Invalid value of `fig`: {fig}. If a list of `Axes` is passed, a `Figure` must also be specified." ) - assert isinstance(ax, Sequence), f"Invalid type of `ax`: {type(ax)}, expected `Sequence`." + assert ax is None or isinstance(ax, Sequence), f"Invalid type of `ax`: {type(ax)}, expected `Sequence`." axs = ax else: axs = None if ax is None: fig, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True) - # set scalebar if scalebar_dx is not None: scalebar_dx, scalebar_units = _get_scalebar(scalebar_dx, scalebar_units, num_panels) @@ -141,10 +138,10 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame: for cs_name, element_ids in cs_mapping.items(): # determine if coordinate system has the respective elements - cs_has_images = bool(any((e in sdata.images) for e in element_ids)) - cs_has_labels = bool(any((e in sdata.labels) for e in element_ids)) - cs_has_points = bool(any((e in sdata.points) for e in element_ids)) - cs_has_shapes = bool(any((e in sdata.shapes) for e in element_ids)) + cs_has_images = any(e in sdata.images for e in element_ids) + cs_has_labels = any(e in sdata.labels for e in element_ids) + cs_has_points = any(e in sdata.points for e in element_ids) + cs_has_shapes = any(e in sdata.shapes for e in element_ids) cs_contents = pd.concat( [ @@ -283,242 +280,6 @@ def assign_fill_and_outline_to_row( ) -def _get_extent( - sdata: sd.SpatialData, - coordinate_systems: Sequence[str] | str | None = None, - has_images: bool = True, - has_labels: bool = True, - has_points: bool = True, - has_shapes: bool = True, - elements: Iterable[Any] | None = None, - share_extent: bool = False, -) -> dict[str, tuple[int, int, int, int]]: - """Return the extent of all elements in their respective coordinate systems. - - Parameters - ---------- - sdata - The sd.SpatialData object to retrieve the extent from - has_images - Flag indicating whether to consider images when calculating the extent - has_labels - Flag indicating whether to consider labels when calculating the extent - has_points - Flag indicating whether to consider points when calculating the extent - has_shapes - Flag indicating whether to consider shapes when calculating the extent - elements - Optional list of element names to be considered. When None, all are used. - share_extent - Flag indicating whether to use the same extent for all coordinate systems - - Returns - ------- - A dict of tuples with the shape (xmin, xmax, ymin, ymax). The keys of the - dict are the coordinate_system keys. - - """ - extent: dict[str, dict[str, Sequence[int]]] = {} - cs_mapping = _get_coordinate_system_mapping(sdata) - cs_contents = _get_cs_contents(sdata) - - if elements is None: # to shut up ruff - elements = [] - - if not isinstance(elements, list): - raise ValueError(f"Invalid type of `elements`: {type(elements)}, expected `list`.") - - if coordinate_systems is not None: - if isinstance(coordinate_systems, str): - coordinate_systems = [coordinate_systems] - cs_contents = cs_contents[cs_contents["cs"].isin(coordinate_systems)] - cs_mapping = {k: v for k, v in cs_mapping.items() if k in coordinate_systems} - - for cs_name, element_ids in cs_mapping.items(): - extent[cs_name] = {} - if len(elements) > 0: - element_ids = [e for e in element_ids if e in elements] - - def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[int]: - tmp = element.copy() - if len(tmp.shape) == 3: - x_idx = 2 - y_idx = 1 - elif len(tmp.shape) == 2: - x_idx = 1 - y_idx = 0 - - transformations = get_transformation(tmp, to_coordinate_system=cs_name) - transformations = _flatten_transformation_sequence(transformations) - - if len(transformations) == 1 and isinstance( - transformations[0], sd.transformations.transformations.Identity - ): - result = (0, tmp.shape[x_idx], 0, tmp.shape[y_idx]) - - else: - origin = { - "x": 0, - "y": 0, - } - for t in transformations: - if isinstance(t, sd.transformations.transformations.Translation): - tmp = _translate_image(image=tmp, translation=t) - - for idx, ax in enumerate(t.axes): - origin["x"] += t.translation[idx] if ax == "x" else 0 - origin["y"] += t.translation[idx] if ax == "y" else 0 - - else: - tmp = transform(tmp, t) - - if isinstance(t, sd.transformations.transformations.Scale): - for idx, ax in enumerate(t.axes): - origin["x"] *= t.scale[idx] if ax == "x" else 1 - origin["y"] *= t.scale[idx] if ax == "y" else 1 - - elif isinstance(t, sd.transformations.transformations.Affine): - pass - - result = (origin["x"], tmp.shape[x_idx], origin["y"], tmp.shape[y_idx]) - - del tmp - return result - - if has_images and cs_contents.query(f"cs == '{cs_name}'")["has_images"][0]: - for images_key in sdata.images: - for e_id in element_ids: - if images_key == e_id: - if isinstance(sdata.images[e_id], spatial_image.SpatialImage): - extent[cs_name][e_id] = _get_extent_after_transformations(sdata.images[e_id], cs_name) - else: - img = Image2DModel.parse(sdata.images[e_id]["scale0"].ds.to_array().squeeze(axis=0)) - extent[cs_name][e_id] = _get_extent_after_transformations(img, cs_name) - - if has_labels and cs_contents.query(f"cs == '{cs_name}'")["has_labels"][0]: - for labels_key in sdata.labels: - for e_id in element_ids: - if labels_key == e_id: - if isinstance(sdata.labels[e_id], spatial_image.SpatialImage): - extent[cs_name][e_id] = _get_extent_after_transformations(sdata.labels[e_id], cs_name) - else: - label = Labels2DModel.parse(sdata.labels[e_id]["scale0"].ds.to_array().squeeze(axis=0)) - extent[cs_name][e_id] = _get_extent_after_transformations(label, cs_name) - - if has_shapes and cs_contents.query(f"cs == '{cs_name}'")["has_shapes"][0]: - for shapes_key in sdata.shapes: - for e_id in element_ids: - if shapes_key == e_id: - - def get_point_bb( - point: Point, radius: int, method: Literal["topleft", "bottomright"], buffer: int = 0 - ) -> Point: - x, y = point.coords[0] - if method == "topleft": - point_bb = Point(x - radius - buffer, y - radius - buffer) - else: - point_bb = Point(x + radius + buffer, y + radius + buffer) - - return point_bb - - y_dims = [] - x_dims = [] - - # Split by Point and Polygon: - tmp_points = sdata.shapes[e_id][ - sdata.shapes[e_id]["geometry"].apply( - lambda geom: (geom.geom_type == "Point" and not geom.is_empty) - ) - ] - tmp_polygons = sdata.shapes[e_id][ - sdata.shapes[e_id]["geometry"].apply( - lambda geom: (geom.geom_type in ["Polygon", "MultiPolygon"] and not geom.is_empty) - ) - ] - - if not tmp_points.empty: - tmp_points["point_topleft"] = tmp_points.apply( - lambda row: get_point_bb(row["geometry"], row["radius"], "topleft"), - axis=1, - ) - tmp_points["point_bottomright"] = tmp_points.apply( - lambda row: get_point_bb(row["geometry"], row["radius"], "bottomright"), - axis=1, - ) - xmin_tl, ymin_tl, xmax_tl, ymax_tl = tmp_points["point_topleft"].total_bounds - xmin_br, ymin_br, xmax_br, ymax_br = tmp_points["point_bottomright"].total_bounds - y_dims += [min(ymin_tl, ymin_br), max(ymax_tl, ymax_br)] - x_dims += [min(xmin_tl, xmin_br), max(xmax_tl, xmax_br)] - - if not tmp_polygons.empty: - xmin, ymin, xmax, ymax = tmp_polygons.total_bounds - y_dims += [ymin, ymax] - x_dims += [xmin, xmax] - - del tmp_points - del tmp_polygons - - xmin = np.min(x_dims) - xmax = np.max(x_dims) - ymin = np.min(y_dims) - ymax = np.max(y_dims) - - extent[cs_name][e_id] = [xmin, xmax, ymin, ymax] - - transformations = get_transformation(sdata.shapes[e_id], to_coordinate_system=cs_name) - transformations = _flatten_transformation_sequence(transformations) - - for t in transformations: - if isinstance(t, sd.transformations.transformations.Translation): - for idx, ax in enumerate(t.axes): - extent[cs_name][e_id][0] += t.translation[idx] if ax == "x" else 0 # type: ignore - extent[cs_name][e_id][1] += t.translation[idx] if ax == "x" else 0 # type: ignore - extent[cs_name][e_id][2] += t.translation[idx] if ax == "y" else 0 # type: ignore - extent[cs_name][e_id][3] += t.translation[idx] if ax == "y" else 0 # type: ignore - - else: - if isinstance(t, sd.transformations.transformations.Scale): - for idx, ax in enumerate(t.axes): - extent[cs_name][e_id][1] *= t.scale[idx] if ax == "x" else 1 # type: ignore - extent[cs_name][e_id][3] *= t.scale[idx] if ax == "y" else 1 # type: ignore - - elif isinstance(t, sd.transformations.transformations.Affine): - pass - - if has_points and cs_contents.query(f"cs == '{cs_name}'")["has_points"][0]: - for points_key in sdata.points: - for e_id in element_ids: - if points_key == e_id: - tmp = sdata.points[points_key] - xmin = tmp["x"].min().compute() - xmax = tmp["x"].max().compute() - ymin = tmp["y"].min().compute() - ymax = tmp["y"].max().compute() - extent[cs_name][e_id] = [xmin, xmax, ymin, ymax] - - cswise_extent = {} - for cs_name, cs_contents in extent.items(): - if len(cs_contents) > 0: - xmin = min([v[0] for v in cs_contents.values()]) - xmax = max([v[1] for v in cs_contents.values()]) - ymin = min([v[2] for v in cs_contents.values()]) - ymax = max([v[3] for v in cs_contents.values()]) - cswise_extent[cs_name] = (xmin, xmax, ymin, ymax) - - if share_extent: - global_extent = {} - if len(cs_contents) > 0: - xmin = min([v[0] for v in cswise_extent.values()]) - xmax = max([v[1] for v in cswise_extent.values()]) - ymin = min([v[2] for v in cswise_extent.values()]) - ymax = max([v[3] for v in cswise_extent.values()]) - for cs_name in cswise_extent: - global_extent[cs_name] = (xmin, xmax, ymin, ymax) - return global_extent - - return cswise_extent - - def _panel_grid( num_panels: int, hspace: float, @@ -578,10 +339,8 @@ def _prepare_cmap_norm( cmap = copy(matplotlib.colormaps[rcParams["image.cmap"] if cmap is None else cmap]) cmap.set_bad("lightgray" if na_color is None else na_color) - if isinstance(norm, Normalize): + if isinstance(norm, Normalize) or not norm: pass # TODO - elif not norm: - pass elif vcenter is None: norm = Normalize(vmin=vmin, vmax=vmax) else: @@ -598,17 +357,17 @@ def _set_outline( ) -> OutlineParams: # Type checks for outline_width if isinstance(outline_width, int): - outline_width = float(outline_width) + outline_width = outline_width if not isinstance(outline_width, float): raise TypeError(f"Invalid type of `outline_width`: {type(outline_width)}, expected `float`.") if outline_width == 0.0: outline = False if outline_width < 0.0: logging.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1)*outline_width}") - outline_width = (-1) * outline_width + outline_width *= -1 # the default black and white colors can be changed using the contour_config parameter - if (len(outline_color) == 3 or len(outline_color) == 4) and all(isinstance(c, float) for c in outline_color): + if len(outline_color) in {3, 4} and all(isinstance(c, float) for c in outline_color): outline_color = matplotlib.colors.to_hex(outline_color) if outline: @@ -778,18 +537,15 @@ def _get_colors_for_categorical_obs( elif len(rcParams["axes.prop_cycle"].by_key()["color"]) >= len_cat: cc = rcParams["axes.prop_cycle"]() palette = [next(cc)["color"] for _ in range(len_cat)] + elif len_cat <= 20: + palette = default_20 + elif len_cat <= 28: + palette = default_28 + elif len_cat <= len(default_102): # 103 colors + palette = default_102 else: - if len_cat <= 20: - palette = default_20 - elif len_cat <= 28: - palette = default_28 - elif len_cat <= len(default_102): # 103 colors - palette = default_102 - else: - palette = ["grey" for _ in range(len_cat)] - logging.info( - "input has more than 103 categories. Uniform " "'grey' color will be used for all categories." - ) + palette = ["grey" for _ in range(len_cat)] + logging.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.") # otherwise, single channels turn out grey color_idx = np.linspace(0, 1, len_cat) if len_cat > 1 else [0.7] @@ -1045,12 +801,6 @@ def _decorate_axs( # TODO: na_in_legend should have some effect here plt.colorbar(cax, ax=ax, pad=0.01, fraction=0.08, aspect=30) - # if img is not None: - # ax.imshow(img, cmap=img_cmap, alpha=img_alpha) - # else: - # ax.set_aspect("equal") - # ax.invert_yaxis() - if isinstance(scalebar_dx, list) and isinstance(scalebar_units, list): scalebar = ScaleBar(scalebar_dx, units=scalebar_units, **scalebar_kwargs) ax.add_artist(scalebar) @@ -1184,11 +934,7 @@ def _translate_image( image: spatial_image.SpatialImage, translation: sd.transformations.transformations.Translation, ) -> spatial_image.SpatialImage: - shifts: dict[str, int] = {} - - for idx, axis in enumerate(translation.axes): - shifts[axis] = int(translation.translation[idx]) - + shifts: dict[str, int] = {axis: int(translation.translation[idx]) for idx, axis in enumerate(translation.axes)} img = image.values.copy() shifted_channels = [] @@ -1228,81 +974,24 @@ def _convert_polygon_to_linestrings(polygon: Polygon) -> list[LineString]: return [list(ls.coords) for ls in linestrings] -def _flatten_transformation_sequence( - transformation_sequence: list[sd.transformations.transformations.Sequence], -) -> list[sd.transformations.transformations.Sequence]: - if isinstance(transformation_sequence, sd.transformations.transformations.Sequence): - transformations = list(transformation_sequence.transformations) - found_bottom_of_tree = False - while not found_bottom_of_tree: - if all(not isinstance(t, sd.transformations.transformations.Sequence) for t in transformations): - found_bottom_of_tree = True - else: - for idx, t in enumerate(transformations): - if isinstance(t, sd.transformations.transformations.Sequence): - transformations.pop(idx) - transformations += t.transformations - - return transformations - - if isinstance(transformation_sequence, sd.transformations.transformations.BaseTransformation): - return [transformation_sequence] - - raise TypeError("Parameter 'transformation_sequence' must be a Sequence.") - - -def _robust_transform(element: Any, cs: str) -> Any: - try: - transformations = get_transformation(element, get_all=True) - if cs not in transformations: - return element - transformations = transformations[cs] - transformations = _flatten_transformation_sequence(transformations) - for _, t in enumerate(transformations): - if isinstance(t, sd.transformations.transformations.Translation): - element = _translate_image(image=element, translation=t) - - elif isinstance(t, sd.transformations.transformations.Affine): - # edge case, waiting for Luca to decompose affine into components - # element = transform(element, t) - # new_transformations = get_transformation(element, get_all=True) - # new_transformations = new_transformations[cs] - # new_transformations = _flatten_transformation_sequence(new_transformations) - # seq = new_transformations[:len(new_transformations) - len(transformations)] - # seq = sd.transformations.Sequence(seq) - # set_transformation(element, seq, to_coordinate_system=cs) - # element = _robust_transform(element, cs) - # print(element.shape) - pass - - else: - element = transform(element, t) - - except ValueError as e: - # hack, talk to Luca - raise ValueError("Unable to transform element.") from e - - return element - - def _split_multipolygon_into_outer_and_inner(mp: shapely.MultiPolygon): # type: ignore # https://stackoverflow.com/a/21922058 for geom in mp.geoms: - if geom.geom_type == "Polygon": - exterior_coords = geom.exterior.coords[:] - interior_coords = [] - for interior in geom.interiors: - interior_coords += interior.coords[:] - elif geom.geom_type == "MultiPolygon": + if geom.geom_type == "MultiPolygon": exterior_coords = [] interior_coords = [] for part in geom: epc = _split_multipolygon_into_outer_and_inner(part) # Recursive call exterior_coords += epc["exterior_coords"] interior_coords += epc["interior_coords"] + elif geom.geom_type == "Polygon": + exterior_coords = geom.exterior.coords[:] + interior_coords = [] + for interior in geom.interiors: + interior_coords += interior.coords[:] else: - raise ValueError("Unhandled geometry type: " + repr(geom.type)) + raise ValueError(f"Unhandled geometry type: {repr(geom.type)}") return interior_coords, exterior_coords diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_affine.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_affine.png new file mode 100644 index 00000000..ae096987 Binary files /dev/null and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_affine.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_composition.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_composition.png new file mode 100644 index 00000000..bf95debd Binary files /dev/null and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_composition.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_inverse.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_inverse.png new file mode 100644 index 00000000..4bec4e75 Binary files /dev/null and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_inverse.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_mapaxis.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_mapaxis.png index bf31e4e9..c1c561a7 100644 Binary files a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_mapaxis.png and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_mapaxis.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_overlay.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_overlay.png index f63b889a..6cac3865 100644 Binary files a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_overlay.png and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_overlay.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_rotation.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_rotation.png new file mode 100644 index 00000000..7f9674a4 Binary files /dev/null and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_rotation.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png index a291034b..f0c3d619 100644 Binary files a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_split.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_split.png index d61c1a62..e3f0c8e3 100644 Binary files a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_split.png and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_split.png differ diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_translation.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_translation.png new file mode 100644 index 00000000..fb943a23 Binary files /dev/null and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_translation.png differ diff --git a/tests/pl/test_upstream_plots.py b/tests/pl/test_upstream_plots.py index 7dda9a42..f10c7bf1 100644 --- a/tests/pl/test_upstream_plots.py +++ b/tests/pl/test_upstream_plots.py @@ -1,11 +1,16 @@ +import math + import matplotlib import matplotlib.pyplot as plt import scanpy as sc import spatialdata_plot # noqa: F401 from spatialdata import SpatialData from spatialdata.transformations import ( + Affine, MapAxis, Scale, + Sequence, + Translation, set_transformation, ) @@ -47,6 +52,82 @@ def test_plot_can_render_transformations_raccoon_mapaxis(self, sdata_raccoon: Sp sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show() + def test_plot_can_render_transformations_raccoon_rotation(self, sdata_raccoon: SpatialData): + theta = math.pi / 6 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + + set_transformation(sdata_raccoon.images["raccoon"], rotation, to_coordinate_system="global") + + sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show() + + def test_plot_can_render_transformations_raccoon_translation(self, sdata_raccoon: SpatialData): + translation = Translation([500, 300], axes=("x", "y")) + set_transformation(sdata_raccoon.images["raccoon"], translation, to_coordinate_system="global") + + sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show() + + def test_plot_can_render_transformations_raccoon_affine(self, sdata_raccoon: SpatialData): + theta = math.pi / 6 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + scale = Scale([2.0], axes=("x",)) + sequence = Sequence([rotation, scale]) + + set_transformation(sdata_raccoon.images["raccoon"], sequence, to_coordinate_system="global") + + sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show() + + def test_plot_can_render_transformations_raccoon_composition(self, sdata_raccoon: SpatialData): + theta = math.pi / 6 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + scale = Scale([2.0], axes=("x",)) + + set_transformation(sdata_raccoon.images["raccoon"], scale, to_coordinate_system="global") + set_transformation(sdata_raccoon.shapes["circles"], scale, to_coordinate_system="global") + set_transformation(sdata_raccoon.labels["segmentation"], rotation, to_coordinate_system="global") + + sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show() + + def test_plot_can_render_transformations_raccoon_inverse(self, sdata_raccoon: SpatialData): + theta = math.pi / 6 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + scale = Scale([2.0], axes=("x",)) + sequence = Sequence([rotation, rotation.inverse(), scale, scale.inverse()]) + set_transformation(sdata_raccoon.images["raccoon"], sequence, to_coordinate_system="global") + + sdata_raccoon.pl.render_images().pl.render_labels().pl.render_shapes().pl.show() + def test_plot_can_render_blobs_images(sdata_blobs: SpatialData): sdata_blobs.pl.render_images().pl.show()