From fbaf844aead3e91c505ebca62119611f4f9c1596 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 6 Dec 2024 09:24:04 +0100 Subject: [PATCH] fixup! fixup! RFC: refactor Plotable around typed data structures --- nonos/api/analysis.py | 22 +++++++++++----------- nonos/api/satellite.py | 17 +++++++---------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/nonos/api/analysis.py b/nonos/api/analysis.py index b8b66088..ba4d2fa8 100644 --- a/nonos/api/analysis.py +++ b/nonos/api/analysis.py @@ -44,13 +44,13 @@ class Plotable: def __init__( self, *, - abscissa: NamedArray, - ordinate: NamedArray, - field: Optional[NamedArray] = None, + abscissa: tuple[str, np.ndarray], + ordinate: tuple[str, np.ndarray], + field: Optional[tuple[str, np.ndarray]] = None, ) -> None: - self.abscissa = abscissa - self.ordinate = ordinate - self.field = field + self.abscissa = NamedArray(*abscissa) + self.ordinate = NamedArray(*ordinate) + self.field = None if field is None else NamedArray(*field) if ndim := self.data.ndim > 2: raise TypeError( f"Plotable doesn't support data with dimensionality>2, got {ndim}" @@ -563,8 +563,8 @@ def map( data_view = self.data.view() return Plotable( - abscissa=NamedArray(abscissa_key, abscissa_value), - ordinate=NamedArray(data_key, data_view.squeeze()), + abscissa=(abscissa_key, abscissa_value), + ordinate=(data_key, data_view.squeeze()), ) elif dimension == 2: @@ -611,9 +611,9 @@ def rotate_axes(arr, shift: int): data_view = data_view.T return Plotable( - abscissa=NamedArray(abscissa_key, abscissa_value), - ordinate=NamedArray(ordinate_key, ordinate_value), - field=NamedArray(data_key, data_view), + abscissa=(abscissa_key, abscissa_value), + ordinate=(ordinate_key, ordinate_value), + field=(data_key, data_view), ) else: raise RuntimeError diff --git a/nonos/api/satellite.py b/nonos/api/satellite.py index 7294a08e..e0ac6de6 100644 --- a/nonos/api/satellite.py +++ b/nonos/api/satellite.py @@ -5,7 +5,7 @@ import numpy as np from lick.lick import lick_box -from nonos.api.analysis import Coordinates, GasField, NamedArray, Plotable +from nonos.api.analysis import Coordinates, GasField, Plotable from nonos.loaders import Recipe, loader_from, recipe_from if TYPE_CHECKING: @@ -136,9 +136,9 @@ def plot( color_streamlines: str = "black", ): im = Plotable( - field=NamedArray("background", self.F), - abscissa=NamedArray("x", self.X), - ordinate=NamedArray("y", self.Y), + abscissa=("x", self.X), + ordinate=("y", self.Y), + field=("background", self.F), ).plot( fig, ax, @@ -146,24 +146,21 @@ def plot( vmax=vmax, log=log, cmap=cmap, - filename=None, dpi=500, title=title, shading="nearest", rasterized=True, ) Plotable( - field=NamedArray("lick", self.lick), - abscissa=NamedArray("x", self.X), - ordinate=NamedArray("y", self.Y), + abscissa=("x", self.X), + ordinate=("y", self.Y), + field=("lick", self.lick), ).plot( fig, ax, log=False, cmap="binary_r", - filename=None, dpi=500, - title=None, alpha=alpha, shading="nearest", rasterized=True,