Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored transformations implementation #162

Merged
merged 8 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning][].

## [0.1.0] - tbd

### Added

- Pushed `get_extent` functionality upstream to `spatialdata` (#162)

### Fixed

-

## [0.0.5] - 2023-10-02

### Added
Expand Down
190 changes: 93 additions & 97 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from pandas.api.types import is_categorical_dtype
from spatial_image import SpatialImage
from spatialdata._logging import logger as logg
from spatialdata._core.data_extent import get_extent
from spatialdata.transformations.operations import get_transformation

from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot.pl.render import (
Expand All @@ -40,12 +41,10 @@
)
from spatialdata_plot.pl.utils import (
_get_cs_contents,
_get_extent,
_maybe_set_colors,
_mpl_ax_contains_elements,
_prepare_cmap_norm,
_prepare_params_plot,
_robust_transform,
_set_outline,
save_fig,
)
Expand Down Expand Up @@ -216,6 +215,8 @@ def render_shapes(
na_color=na_color, # type: ignore[arg-type]
**kwargs,
)
if isinstance(elements, str):
elements = [elements]
outline_params = _set_outline(outline, outline_width, outline_color)
sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams(
elements=elements,
Expand Down Expand Up @@ -285,12 +286,15 @@ def render_points(
sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())

cmap_params = _prepare_cmap_norm(
cmap=cmap,
norm=norm,
na_color=na_color, # type: ignore[arg-type]
**kwargs,
)
if isinstance(elements, str):
elements = [elements]
sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams(
elements=elements,
color=color,
Expand Down Expand Up @@ -370,6 +374,8 @@ def render_images(
**kwargs,
)

if isinstance(elements, str):
elements = [elements]
sdata.plotting_tree[f"{n_steps+1}_render_images"] = ImageRenderParams(
elements=elements,
channel=channel,
Expand Down Expand Up @@ -450,6 +456,8 @@ def render_labels(
na_color=na_color, # type: ignore[arg-type]
**kwargs,
)
if isinstance(elements, str):
elements = [elements]
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
elements=elements,
color=color,
Expand Down Expand Up @@ -552,12 +560,12 @@ def show(
raise TypeError("All titles must be strings.")

# get original axis extent for later comparison
x_min_orig, x_max_orig = (np.inf, -np.inf)
y_min_orig, y_max_orig = (np.inf, -np.inf)
ax_x_min, ax_x_max = (np.inf, -np.inf)
ax_y_min, ax_y_max = (np.inf, -np.inf)

if isinstance(ax, Axes) and _mpl_ax_contains_elements(ax):
x_min_orig, x_max_orig = ax.get_xlim()
y_max_orig, y_min_orig = ax.get_ylim() # (0, 0) is top-left
ax_x_min, ax_x_max = ax.get_xlim()
ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left

# handle coordinate system
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
Expand All @@ -568,50 +576,6 @@ def show(
if cs not in sdata.coordinate_systems:
raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}")

# Check if user specified only certain elements to be plotted
cs_contents = _get_cs_contents(sdata)
elements_to_be_rendered = []
for cmd, params in render_cmds.items():
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: # noqa: SIM114
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)
elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: # noqa: SIM114
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)
elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: # noqa: SIM102
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)

extent = _get_extent(
sdata=sdata,
has_images="render_images" in render_cmds,
has_labels="render_labels" in render_cmds,
has_points="render_points" in render_cmds,
has_shapes="render_shapes" in render_cmds,
elements=elements_to_be_rendered,
coordinate_systems=coordinate_systems,
)

# Use extent to filter out coordinate system without the relevant elements
valid_cs = []
for cs in coordinate_systems:
if cs in extent:
valid_cs.append(cs)
else:
logg.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
coordinate_systems = valid_cs

# set up canvas
fig_params, scalebar_params = _prepare_params_plot(
num_panels=len(coordinate_systems),
Expand All @@ -633,32 +597,25 @@ def show(
colorbar=colorbar,
)

cs_contents = _get_cs_contents(sdata)

# go through tree

for i, cs in enumerate(coordinate_systems):
sdata = self._copy()
# properly transform all elements to the current coordinate system
members = cs_contents.query(f"cs == '{cs}'")

if members["has_images"].values[0]:
for key in sdata.images:
sdata.images[key] = _robust_transform(sdata.images[key], cs)

if members["has_labels"].values[0]:
for key in sdata.labels:
sdata.labels[key] = _robust_transform(sdata.labels[key], cs)

if members["has_points"].values[0]:
for key in sdata.points:
sdata.points[key] = _robust_transform(sdata.points[key], cs)

if members["has_shapes"].values[0]:
for key in sdata.shapes:
sdata.shapes[key] = _robust_transform(sdata.shapes[key], cs)

_, has_images, has_labels, has_points, has_shapes = (
cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist()
)
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]

wants_images = False
wants_labels = False
wants_points = False
wants_shapes = False
wanted_elements = []

for cmd, params in render_cmds.items():
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]:
if cmd == "render_images" and has_images:
_render_images(
sdata=sdata,
render_params=params,
Expand All @@ -667,9 +624,18 @@ def show(
fig_params=fig_params,
scalebar_params=scalebar_params,
legend_params=legend_params,
# extent=extent[cs],
)
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]:
wants_images = True
wanted_images = params.elements if params.elements is not None else list(sdata.images.keys())
wanted_elements.extend(
[
image
for image in wanted_images
if cs in set(get_transformation(sdata.images[image], get_all=True).keys())
]
)

elif cmd == "render_shapes" and has_shapes:
_render_shapes(
sdata=sdata,
render_params=params,
Expand All @@ -679,8 +645,17 @@ def show(
scalebar_params=scalebar_params,
legend_params=legend_params,
)
wants_shapes = True
wanted_shapes = params.elements if params.elements is not None else list(sdata.shapes.keys())
wanted_elements.extend(
[
shape
for shape in wanted_shapes
if cs in set(get_transformation(sdata.shapes[shape], get_all=True).keys())
]
)

elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]:
elif cmd == "render_points" and has_points:
_render_points(
sdata=sdata,
render_params=params,
Expand All @@ -690,8 +665,17 @@ def show(
scalebar_params=scalebar_params,
legend_params=legend_params,
)
wants_points = True
wanted_points = params.elements if params.elements is not None else list(sdata.points.keys())
wanted_elements.extend(
[
point
for point in wanted_points
if cs in set(get_transformation(sdata.points[point], get_all=True).keys())
]
)

elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]:
elif cmd == "render_labels" and has_labels:
if sdata.table is not None and isinstance(params.color, str):
colors = sc.get.obs_df(sdata.table, params.color)
if is_categorical_dtype(colors):
Expand All @@ -710,33 +694,46 @@ def show(
scalebar_params=scalebar_params,
legend_params=legend_params,
)
wants_labels = True
wanted_labels = params.elements if params.elements is not None else list(sdata.labels.keys())
wanted_elements.extend(
[
label
for label in wanted_labels
if cs in set(get_transformation(sdata.labels[label], get_all=True).keys())
]
)

if title is not None:
if len(title) == 1:
t = title[0]
else:
try:
t = title[i]
except IndexError as e:
raise IndexError("The number of titles must match the number of coordinate systems.") from e
else:
if title is None:
t = cs
elif len(title) == 1:
t = title[0]
else:
try:
t = title[i]
except IndexError as e:
raise IndexError("The number of titles must match the number of coordinate systems.") from e
ax.set_title(t)
ax.set_aspect("equal")

if any(
[
cs_contents.query(f"cs == '{cs}'")["has_images"][0],
cs_contents.query(f"cs == '{cs}'")["has_labels"][0],
cs_contents.query(f"cs == '{cs}'")["has_points"][0],
cs_contents.query(f"cs == '{cs}'")["has_shapes"][0],
]
):
extent = get_extent(
sdata,
coordinate_system=cs,
has_images=has_images and wants_images,
has_labels=has_labels and wants_labels,
has_points=has_points and wants_points,
has_shapes=has_shapes and wants_shapes,
elements=wanted_elements,
)
cs_x_min, cs_x_max = extent["x"]
cs_y_min, cs_y_max = extent["y"]

if any([has_images, has_labels, has_points, has_shapes]):
# If the axis already has limits, only expand them but not overwrite
x_min = min(x_min_orig, extent[cs][0]) - pad_extent
x_max = max(x_max_orig, extent[cs][1]) + pad_extent
y_min = min(y_min_orig, extent[cs][2]) - pad_extent
y_max = max(y_max_orig, extent[cs][3]) + pad_extent
x_min = min(ax_x_min, cs_x_min) - pad_extent
x_max = max(ax_x_max, cs_x_max) + pad_extent
y_min = min(ax_y_min, cs_y_min) - pad_extent
y_max = max(ax_y_max, cs_y_max) + pad_extent
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_max, y_min) # (0, 0) is top-left

Expand All @@ -747,5 +744,4 @@ def show(
# https://stackoverflow.com/a/64523765
if not hasattr(sys, "ps1"):
plt.show()

return (fig_params.ax if fig_params.axs is None else fig_params.axs) if return_ax else None # shuts up ruff
Loading