From 4a1d6c1483169a2f07576e19433fffffca50fbbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 5 Dec 2024 19:20:19 +0100 Subject: [PATCH] RFC: refactor Plotable around typed data structures --- nonos/api/analysis.py | 111 ++++++++++++++++++++++------------------- nonos/api/satellite.py | 30 +++++------ nonos/main.py | 21 +++----- 3 files changed, 79 insertions(+), 83 deletions(-) diff --git a/nonos/api/analysis.py b/nonos/api/analysis.py index 34aa0fd5..a02805b3 100644 --- a/nonos/api/analysis.py +++ b/nonos/api/analysis.py @@ -3,6 +3,7 @@ import warnings from collections import deque from collections.abc import ItemsView, KeysView, ValuesView +from dataclasses import dataclass from functools import cached_property from pathlib import Path from shutil import copyfile @@ -29,17 +30,42 @@ from matplotlib.figure import Figure +@dataclass(frozen=True, eq=False) +class NamedArray: + # TODO: use slots=True in @dataclass when Python 3.9 is dropped + __slots__ = ["name", "data"] + name: str + data: np.ndarray + + +@dataclass(frozen=True, eq=False) +class PlotableData: + # TODO: use slots=True in @dataclass when Python 3.9 is dropped + # (defining __slots__ manually isn't compatible with setting a default value) + abscissa: NamedArray + ordinate: NamedArray + field: Optional[NamedArray] = None + + class Plotable: - def __init__(self, dict_plotable: dict) -> None: - self.dict_plotable = dict_plotable - self.data = self.dict_plotable[self.dict_plotable["field"]] - self.dimension = len(self.data.shape) - if self.dimension > 2: + def __init__(self, plotable_data: "PlotableData", /) -> None: + self.plotable_data = plotable_data + if ndim := self.data.ndim > 2: raise TypeError( - "Plotable doesn't support data with dimensionality>2, " - f"got {self.dimension}" + f"Plotable doesn't support data with dimensionality>2, got {ndim}" ) + @property + def data(self) -> np.ndarray: + if self.plotable_data.field is not None: + arr = self.plotable_data.field.data + assert arr.ndim == 2 + else: + arr = self.plotable_data.ordinate.data + assert arr.ndim == 1 + + return arr + def plot( self, fig: "Figure", @@ -60,18 +86,20 @@ def plot( "The nbin parameter has no effect and is deprecated", stacklevel=2, ) + data = self.data if unit_conversion is not None: data = data * unit_conversion if log: data = np.log10(data) + akey = self.plotable_data.abscissa.name + aval = self.plotable_data.abscissa.data + okey = self.plotable_data.ordinate.name + oval = self.plotable_data.ordinate.data + artist: Artist - if self.dimension == 2: - self.akey = self.dict_plotable["abscissa"] - self.okey = self.dict_plotable["ordinate"] - self.avalue = self.dict_plotable[self.akey] - self.ovalue = self.dict_plotable[self.okey] + if data.ndim == 2: kw = {} if (norm := kwargs.get("norm")) is not None: if "vmin" in kwargs: @@ -83,21 +111,13 @@ def plot( vmax = kwargs.pop("vmax") if "vmax" in kwargs else np.nanmax(data) kw.update({"vmin": vmin, "vmax": vmax}) - artist = im = ax.pcolormesh( - self.avalue, - self.ovalue, - data, - cmap=cmap, - **kwargs, - **kw, - ) + artist = im = ax.pcolormesh(aval, oval, data, cmap=cmap, **kwargs, **kw) ax.set( - xlim=(self.avalue.min(), self.avalue.max()), - ylim=(self.ovalue.min(), self.ovalue.max()), + xlim=(aval.min(), aval.max()), + ylim=(oval.min(), oval.max()), + xlabel=akey, + ylabel=okey, ) - - ax.set_xlabel(self.akey) - ax.set_ylabel(self.okey) if title is not None: from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -122,24 +142,19 @@ def plot( trf, subs=list(range(1, int(trf.base))) ) cb_axis.set_minor_locator(locator) - elif self.dimension == 1: + elif data.ndim == 1: vmin = kwargs.pop("vmin") if "vmin" in kwargs else np.nanmin(data) vmax = kwargs.pop("vmax") if "vmax" in kwargs else np.nanmax(data) - self.akey = self.dict_plotable["abscissa"] - self.avalue = self.dict_plotable[self.akey] if "norm" in kwargs: logger.info("norm has no meaning in 1D.") kwargs.pop("norm") - artist = ax.plot(self.avalue, data, **kwargs)[0] - ax.set_ylim(ymin=vmin) - ax.set_ylim(ymax=vmax) - ax.set_xlabel(self.akey) + artist = ax.plot(aval, data, **kwargs)[0] + ax.set(ylim=(vmin, vmax), xlabel=akey) if title is not None: ax.set_ylabel(title) else: raise TypeError( - "Plotable doesn't support data with dimensionality>2, " - f"got {self.dimension}" + f"Plotable doesn't support data with dimensionality>2, got {data.ndim}" ) if filename is not None: fig.savefig(f"{filename}.{fmt}", bbox_inches="tight", dpi=dpi) @@ -546,12 +561,11 @@ def map( else: data_view = self.data.view() - dict_plotable = { - "abscissa": abscissa_key, - "field": data_key, - abscissa_key: abscissa_value, - data_key: data_view.squeeze(), - } + plotable_data = PlotableData( + abscissa=NamedArray(abscissa_key, abscissa_value), + ordinate=NamedArray(data_key, data_view.squeeze()), + ) + elif dimension == 2: # meshgrid in polar coordinates P, R (if "R", "phi") or R, P (if "phi", "R") # idem for all combinations of R,phi,z @@ -595,20 +609,15 @@ def rotate_axes(arr, shift: int): if meshgrid_conversion["ordered"]: data_view = data_view.T - dict_plotable = { - "abscissa": abscissa_key, - "ordinate": ordinate_key, - "field": data_key, - abscissa_key: abscissa_value, - ordinate_key: ordinate_value, - data_key: data_view, - } + plotable_data = PlotableData( + abscissa=NamedArray(abscissa_key, abscissa_value), + ordinate=NamedArray(ordinate_key, ordinate_value), + field=NamedArray(data_key, data_view), + ) else: raise RuntimeError - assert dict_plotable[data_key].ndim == dimension - - return Plotable(dict_plotable) + return Plotable(plotable_data) def save( self, diff --git a/nonos/api/satellite.py b/nonos/api/satellite.py index ce46454c..66d84f97 100644 --- a/nonos/api/satellite.py +++ b/nonos/api/satellite.py @@ -5,7 +5,7 @@ import numpy as np from lick.lick import lick_box -from nonos.api.analysis import Coordinates, GasField, Plotable +from nonos.api.analysis import Coordinates, GasField, NamedArray, Plotable, PlotableData from nonos.loaders import Recipe, loader_from, recipe_from if TYPE_CHECKING: @@ -135,23 +135,19 @@ def plot( density_streamlines: Optional[float] = None, color_streamlines: str = "black", ): - dict_background = {} - dict_background["field"] = "background" - dict_background["abscissa"] = "x" - dict_background["ordinate"] = "y" - dict_background[dict_background["field"]] = self.F - dict_background[dict_background["abscissa"]] = self.X - dict_background[dict_background["ordinate"]] = self.Y + background_data = PlotableData( + field=NamedArray("background", self.F), + abscissa=NamedArray("x", self.X), + ordinate=NamedArray("y", self.Y), + ) - dict_lick = {} - dict_lick["field"] = "lick" - dict_lick["abscissa"] = "x" - dict_lick["ordinate"] = "y" - dict_lick[dict_lick["field"]] = self.lick - dict_lick[dict_lick["abscissa"]] = self.X - dict_lick[dict_lick["ordinate"]] = self.Y + foreground_data = PlotableData( + field=NamedArray("lick", self.lick), + abscissa=NamedArray("x", self.X), + ordinate=NamedArray("y", self.Y), + ) - im = Plotable(dict_background).plot( + im = Plotable(background_data).plot( fig, ax, vmin=vmin, @@ -164,7 +160,7 @@ def plot( shading="nearest", rasterized=True, ) - Plotable(dict_lick).plot( + Plotable(foreground_data).plot( fig, ax, log=False, diff --git a/nonos/main.py b/nonos/main.py index 5cb03543..a1575c1a 100644 --- a/nonos/main.py +++ b/nonos/main.py @@ -195,9 +195,9 @@ def process_field( if "cmap" in plot_kwargs: plot_kwargs.pop("cmap") - dsop.map(plane[0], rotate_with=planet_file).plot(fig, ax, **plot_kwargs) - akey = dsop.map(plane[0], rotate_with=planet_file).dict_plotable["abscissa"] - avalue = dsop.map(plane[0], rotate_with=planet_file).dict_plotable[akey] + plotable = dsop.map(plane[0], rotate_with=planet_file) + plotable.plot(fig, ax, **plot_kwargs) + avalue = plotable.plotable_data.abscissa.data extent = parse_range(extent, dim=dim) extent = range_converter(extent, abscissa=avalue, ordinate=np.zeros(2)) ax.set_xlim(extent[0], extent[1]) @@ -205,18 +205,9 @@ def process_field( dsop.map(plane[0], plane[1], rotate_with=planet_file).plot( fig, ax, **plot_kwargs ) - akey = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[ - "abscissa" - ] - okey = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[ - "ordinate" - ] - avalue = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[ - akey - ] - ovalue = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[ - okey - ] + plot_data = dsop.map(plane[0], plane[1], rotate_with=planet_file).plotable_data + avalue = plot_data.abscissa.data + ovalue = plot_data.ordinate.data extent = parse_range(extent, dim=dim) extent = range_converter( extent,