Skip to content

Commit

Permalink
Merge pull request #336 from neutrinoceros/bug/avoid_mutating_state_i…
Browse files Browse the repository at this point in the history
…n_map

BUG: avoid mutating `GasField`'s state in `GasField.map` when rotation is requested
  • Loading branch information
neutrinoceros authored May 6, 2024
2 parents 194324e + cf31172 commit 43aded1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
19 changes: 11 additions & 8 deletions nonos/api/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,16 +533,17 @@ def map(
):
ipi = find_nearest(phicoord, 2 * np.pi)
if self.native_geometry == "polar":
self.data = np.roll(self.data, -ipi + 1, axis=1)
data_view = np.roll(self.data, -ipi + 1, axis=1)
elif self.native_geometry == "spherical":
self.data = np.roll(self.data, -ipi + 1, axis=2)
data_view = np.roll(self.data, -ipi + 1, axis=2)
else:
raise NotImplementedError(
f"geometry flag '{self.native_geometry}' not implemented yet if corotation"
)
self._rotate_by = rotate_by
else:
data_view = self.data.view()

datamoved_tmp = np.moveaxis(self.data, self.shape.index(1), 0)
datamoved_tmp = np.moveaxis(data_view, self.shape.index(1), 0)
datamoved = np.moveaxis(
datamoved_tmp[0], datamoved_tmp[0].shape.index(1), 0
)
Expand Down Expand Up @@ -572,14 +573,16 @@ def map(
):
ipi = find_nearest(phicoord, 2 * np.pi)
if self.native_geometry == "polar":
self.data = np.roll(self.data, -ipi + 1, axis=1)
data_view = np.roll(self.data, -ipi + 1, axis=1)
elif self.native_geometry == "spherical":
self.data = np.roll(self.data, -ipi + 1, axis=2)
data_view = np.roll(self.data, -ipi + 1, axis=2)
else:
raise NotImplementedError(
f"geometry flag '{self.native_geometry}' not implemented yet if corotation"
)
self._rotate_by = rotate_by
else:
data_view = self.data.view()

ordered = meshgrid_conversion["ordered"]
# move the axis of reduction in the front in order to
# perform the operation 3D(i,j,1) -> 2D(i,j) in a general way,
Expand All @@ -588,7 +591,7 @@ def map(
# in practice, this is tricky only if the "1" is in the middle:
# 3D(i,1,k) -> 2D(i,k) is not a direct triedre anymore, we need to do 2D(i,k).T = 2D(k,i)
position_of_3d_dimension = self.shape.index(1)
datamoved = np.moveaxis(self.data, position_of_3d_dimension, 0)
datamoved = np.moveaxis(data_view, position_of_3d_dimension, 0)
if position_of_3d_dimension == 1:
ordered = not ordered
if ordered:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy as np
import numpy.testing as npt
import pytest

from nonos.api import GasDataSet, file_analysis
Expand Down Expand Up @@ -74,3 +75,12 @@ def test_find_rhill(test_data_dir):
rp = ds["RHO"].find_rp()
rhill = ds["RHO"].find_rhill()
assert rhill < rp


def test_field_map_no_mutation(test_data_dir):
ds = GasDataSet(500, directory=test_data_dir / "idefix_spherical_planet3d")
f = ds["RHO"].radial_at_r(1.0).vertical_at_midplane()
d0 = f.data.copy()
f.map("phi", rotate_by=1.0)
d1 = f.data.copy()
npt.assert_array_equal(d1, d0)

0 comments on commit 43aded1

Please sign in to comment.