diff --git a/nonos/api/analysis.py b/nonos/api/analysis.py index b80b2211..13875464 100644 --- a/nonos/api/analysis.py +++ b/nonos/api/analysis.py @@ -1,6 +1,7 @@ import dataclasses import json import warnings +from collections import deque from collections.abc import ItemsView, KeysView, ValuesView from functools import cached_property from pathlib import Path @@ -543,15 +544,11 @@ def map( else: data_view = self.data.view() - datamoved_tmp = np.moveaxis(data_view, self.shape.index(1), 0) - datamoved = np.moveaxis( - datamoved_tmp[0], datamoved_tmp[0].shape.index(1), 0 - ) dict_plotable = { "abscissa": abscissa_key, "field": data_key, abscissa_key: abscissa_value, - data_key: datamoved[0], + data_key: data_view.squeeze(), } elif dimension == 2: # meshgrid in polar coordinates P, R (if "R", "phi") or R, P (if "phi", "R") @@ -583,21 +580,18 @@ def map( else: data_view = self.data.view() - ordered = meshgrid_conversion["ordered"] - # move the axis of reduction in the front in order to - # perform the operation 3D(i,j,1) -> 2D(i,j) in a general way, - # whatever the position of (i,j,1) - # for that we must be careful to change the order ("ordered") if the data is reversed - # in practice, this is tricky only if the "1" is in the middle: - # 3D(i,1,k) -> 2D(i,k) is not a direct triedre anymore, we need to do 2D(i,k).T = 2D(k,i) - position_of_3d_dimension = self.shape.index(1) - datamoved = np.moveaxis(data_view, position_of_3d_dimension, 0) - if position_of_3d_dimension == 1: - ordered = not ordered - if ordered: - data_value = datamoved[0].T - else: - data_value = datamoved[0] + def rotate_axes(arr, shift: int): + axes_in = tuple(range(arr.ndim)) + axes_out = deque(axes_in) + axes_out.rotate(shift) + return np.moveaxis(arr, axes_in, axes_out) + + # make reduction axis the first axis then drop (squeeze) it, + # while preserving the original (cyclic) order in the other two axes + data_view = rotate_axes(data_view, shift=self.shape.index(1)).squeeze() + + if meshgrid_conversion["ordered"]: + data_view = data_view.T dict_plotable = { "abscissa": abscissa_key, @@ -605,7 +599,7 @@ def map( "field": data_key, abscissa_key: abscissa_value, ordinate_key: ordinate_value, - data_key: data_value, + data_key: data_view, } else: raise RuntimeError