Skip to content

Commit

Permalink
RFC: refactor Plotable around typed data structures
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Dec 5, 2024
1 parent f73e566 commit 606a7f1
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 83 deletions.
112 changes: 61 additions & 51 deletions nonos/api/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from collections import deque
from collections.abc import ItemsView, KeysView, ValuesView
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from shutil import copyfile
Expand All @@ -29,17 +30,43 @@
from matplotlib.figure import Figure


@dataclass(frozen=True, eq=False)
class NamedArray:
# TODO: use slots=True in @dataclass when Python 3.9 is dropped
__slots__ = ["name", "data"]
name: str
data: np.ndarray


@dataclass(frozen=True, eq=False)
class PlotableData:
# TODO: use slots=True in @dataclass when Python 3.9 is dropped
# (defining __slots__ manually isn't compatible with setting a default value)
abscissa: NamedArray
ordinate: NamedArray
field: Optional[NamedArray] = None


class Plotable:
def __init__(self, dict_plotable: dict) -> None:
self.dict_plotable = dict_plotable
self.data = self.dict_plotable[self.dict_plotable["field"]]
self.dimension = len(self.data.shape)
if self.dimension > 2:
def __init__(self, plotable_data: "PlotableData", /) -> None:
self.plotable_data = plotable_data
ndim = self.data.ndim
if ndim > 2:
raise TypeError(
"Plotable doesn't support data with dimensionality>2, "
f"got {self.dimension}"
f"Plotable doesn't support data with dimensionality>2, got {ndim}"
)

@property
def data(self) -> np.ndarray:
if self.plotable_data.field is not None:
arr = self.plotable_data.field.data
assert arr.ndim == 2
else:
arr = self.plotable_data.ordinate.data
assert arr.ndim == 1

return arr

def plot(
self,
fig: "Figure",
Expand All @@ -60,18 +87,20 @@ def plot(
"The nbin parameter has no effect and is deprecated",
stacklevel=2,
)

data = self.data
if unit_conversion is not None:
data = data * unit_conversion
if log:
data = np.log10(data)

akey = self.plotable_data.abscissa.name
aval = self.plotable_data.abscissa.data
okey = self.plotable_data.ordinate.name
oval = self.plotable_data.ordinate.data

artist: Artist
if self.dimension == 2:
self.akey = self.dict_plotable["abscissa"]
self.okey = self.dict_plotable["ordinate"]
self.avalue = self.dict_plotable[self.akey]
self.ovalue = self.dict_plotable[self.okey]
if data.ndim == 2:
kw = {}
if (norm := kwargs.get("norm")) is not None:
if "vmin" in kwargs:
Expand All @@ -83,21 +112,13 @@ def plot(
vmax = kwargs.pop("vmax") if "vmax" in kwargs else np.nanmax(data)
kw.update({"vmin": vmin, "vmax": vmax})

artist = im = ax.pcolormesh(
self.avalue,
self.ovalue,
data,
cmap=cmap,
**kwargs,
**kw,
)
artist = im = ax.pcolormesh(aval, oval, data, cmap=cmap, **kwargs, **kw)
ax.set(
xlim=(self.avalue.min(), self.avalue.max()),
ylim=(self.ovalue.min(), self.ovalue.max()),
xlim=(aval.min(), aval.max()),
ylim=(oval.min(), oval.max()),
xlabel=akey,
ylabel=okey,
)

ax.set_xlabel(self.akey)
ax.set_ylabel(self.okey)
if title is not None:
from mpl_toolkits.axes_grid1 import make_axes_locatable

Expand All @@ -122,24 +143,19 @@ def plot(
trf, subs=list(range(1, int(trf.base)))
)
cb_axis.set_minor_locator(locator)
elif self.dimension == 1:
elif data.ndim == 1:
vmin = kwargs.pop("vmin") if "vmin" in kwargs else np.nanmin(data)
vmax = kwargs.pop("vmax") if "vmax" in kwargs else np.nanmax(data)
self.akey = self.dict_plotable["abscissa"]
self.avalue = self.dict_plotable[self.akey]
if "norm" in kwargs:
logger.info("norm has no meaning in 1D.")
kwargs.pop("norm")
artist = ax.plot(self.avalue, data, **kwargs)[0]
ax.set_ylim(ymin=vmin)
ax.set_ylim(ymax=vmax)
ax.set_xlabel(self.akey)
artist = ax.plot(aval, data, **kwargs)[0]
ax.set(ylim=(vmin, vmax), xlabel=akey)
if title is not None:
ax.set_ylabel(title)
else:
raise TypeError(
"Plotable doesn't support data with dimensionality>2, "
f"got {self.dimension}"
f"Plotable doesn't support data with dimensionality>2, got {data.ndim}"
)
if filename is not None:
fig.savefig(f"{filename}.{fmt}", bbox_inches="tight", dpi=dpi)
Expand Down Expand Up @@ -546,12 +562,11 @@ def map(
else:
data_view = self.data.view()

dict_plotable = {
"abscissa": abscissa_key,
"field": data_key,
abscissa_key: abscissa_value,
data_key: data_view.squeeze(),
}
plotable_data = PlotableData(
abscissa=NamedArray(abscissa_key, abscissa_value),
ordinate=NamedArray(data_key, data_view.squeeze()),
)

elif dimension == 2:
# meshgrid in polar coordinates P, R (if "R", "phi") or R, P (if "phi", "R")
# idem for all combinations of R,phi,z
Expand Down Expand Up @@ -595,20 +610,15 @@ def rotate_axes(arr, shift: int):
if meshgrid_conversion["ordered"]:
data_view = data_view.T

dict_plotable = {
"abscissa": abscissa_key,
"ordinate": ordinate_key,
"field": data_key,
abscissa_key: abscissa_value,
ordinate_key: ordinate_value,
data_key: data_view,
}
plotable_data = PlotableData(
abscissa=NamedArray(abscissa_key, abscissa_value),
ordinate=NamedArray(ordinate_key, ordinate_value),
field=NamedArray(data_key, data_view),
)
else:
raise RuntimeError

assert dict_plotable[data_key].ndim == dimension

return Plotable(dict_plotable)
return Plotable(plotable_data)

def save(
self,
Expand Down
30 changes: 13 additions & 17 deletions nonos/api/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from lick.lick import lick_box

from nonos.api.analysis import Coordinates, GasField, Plotable
from nonos.api.analysis import Coordinates, GasField, NamedArray, Plotable, PlotableData
from nonos.loaders import Recipe, loader_from, recipe_from

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,23 +135,19 @@ def plot(
density_streamlines: Optional[float] = None,
color_streamlines: str = "black",
):
dict_background = {}
dict_background["field"] = "background"
dict_background["abscissa"] = "x"
dict_background["ordinate"] = "y"
dict_background[dict_background["field"]] = self.F
dict_background[dict_background["abscissa"]] = self.X
dict_background[dict_background["ordinate"]] = self.Y
background_data = PlotableData(
field=NamedArray("background", self.F),
abscissa=NamedArray("x", self.X),
ordinate=NamedArray("y", self.Y),
)

dict_lick = {}
dict_lick["field"] = "lick"
dict_lick["abscissa"] = "x"
dict_lick["ordinate"] = "y"
dict_lick[dict_lick["field"]] = self.lick
dict_lick[dict_lick["abscissa"]] = self.X
dict_lick[dict_lick["ordinate"]] = self.Y
foreground_data = PlotableData(
field=NamedArray("lick", self.lick),
abscissa=NamedArray("x", self.X),
ordinate=NamedArray("y", self.Y),
)

im = Plotable(dict_background).plot(
im = Plotable(background_data).plot(
fig,
ax,
vmin=vmin,
Expand All @@ -164,7 +160,7 @@ def plot(
shading="nearest",
rasterized=True,
)
Plotable(dict_lick).plot(
Plotable(foreground_data).plot(
fig,
ax,
log=False,
Expand Down
21 changes: 6 additions & 15 deletions nonos/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,28 +195,19 @@ def process_field(
if "cmap" in plot_kwargs:
plot_kwargs.pop("cmap")

dsop.map(plane[0], rotate_with=planet_file).plot(fig, ax, **plot_kwargs)
akey = dsop.map(plane[0], rotate_with=planet_file).dict_plotable["abscissa"]
avalue = dsop.map(plane[0], rotate_with=planet_file).dict_plotable[akey]
plotable = dsop.map(plane[0], rotate_with=planet_file)
plotable.plot(fig, ax, **plot_kwargs)
avalue = plotable.plotable_data.abscissa.data
extent = parse_range(extent, dim=dim)
extent = range_converter(extent, abscissa=avalue, ordinate=np.zeros(2))
ax.set_xlim(extent[0], extent[1])
elif dim == 2:
dsop.map(plane[0], plane[1], rotate_with=planet_file).plot(
fig, ax, **plot_kwargs
)
akey = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[
"abscissa"
]
okey = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[
"ordinate"
]
avalue = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[
akey
]
ovalue = dsop.map(plane[0], plane[1], rotate_with=planet_file).dict_plotable[
okey
]
plot_data = dsop.map(plane[0], plane[1], rotate_with=planet_file).plotable_data
avalue = plot_data.abscissa.data
ovalue = plot_data.ordinate.data
extent = parse_range(extent, dim=dim)
extent = range_converter(
extent,
Expand Down

0 comments on commit 606a7f1

Please sign in to comment.