Skip to content

Commit

Permalink
bug-fix: save() all field cube (when no operation)
Browse files Browse the repository at this point in the history
  • Loading branch information
volodia99 committed Feb 18, 2024
1 parent 8692f09 commit b93b53e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
9 changes: 5 additions & 4 deletions nonos/api/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,13 +600,14 @@ def map(
return Plotable(dict_plotable)

def save(self, directory="", header_only=False) -> None:
operation = "_" if self.operation=="" else self.operation
if not header_only:
if not os.path.exists(os.path.join(directory, self.field.lower())):
os.makedirs(os.path.join(directory, self.field.lower()))
filename = os.path.join(
directory,
self.field.lower(),
f"{self.operation}_{self.field}.{self.on:04d}.npy",
f"{operation}_{self.field}.{self.on:04d}.npy",
)
if Path(filename).is_file():
logger.info("{} already exists", filename)
Expand All @@ -616,19 +617,19 @@ def save(self, directory="", header_only=False) -> None:

group_of_files = list(
glob.glob1(
os.path.join(directory, self.field.lower()), f"{self.operation}*"
os.path.join(directory, self.field.lower()), f"{operation}*"
)
)
header_file = list(
glob.glob1(
os.path.join(directory, "header"), f"header{self.operation}.json"
os.path.join(directory, "header"), f"header{operation}.json"
)
)
if (len(group_of_files) > 0 and len(header_file) == 0) or header_only:
if not os.path.exists(os.path.join(directory, "header")):
os.makedirs(os.path.join(directory, "header"))
headername = os.path.join(
directory, "header", f"header{self.operation}.json"
directory, "header", f"header{operation}.json"
)
if Path(headername).is_file():
logger.info("{} already exists", headername)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from shutil import copytree

import pytest
import numpy as np

from nonos.api import GasDataSet

Expand Down Expand Up @@ -29,6 +30,21 @@ def test_roundtrip_simple(test_data_dir, tmp_path):
assert len(list(dsnpy.keys())) == 1


def test_roundtrip_no_operation_all_field(test_data_dir, tmp_path):
copytree(test_data_dir / "idefix_spherical_planet3d", tmp_path / "mydir")

os.chdir(tmp_path / "mydir")
ds = GasDataSet(500)
assert len(list(ds.keys())) == 7

gf = ds["RHO"]

gf.save()
dsnpy = GasDataSet.from_npy(500, operation="")
assert len(list(dsnpy.keys())) == 1
np.testing.assert_array_almost_equal(ds["RHO"].data,dsnpy["RHO"].data)


def test_roundtrip_other_dir(test_data_dir, tmp_path):
os.chdir(test_data_dir / "idefix_spherical_planet3d")
gf = GasDataSet(500)["RHO"].azimuthal_average()
Expand Down

0 comments on commit b93b53e

Please sign in to comment.