Skip to content

Commit

Permalink
Merge pull request #330 from neutrinoceros/npy/bug/read_without_density
Browse files Browse the repository at this point in the history
TST: fix 3 bugs in `NPYReader` related to reading from directories which lack density files
  • Loading branch information
volodia99 authored Apr 8, 2024
2 parents aae1266 + 1fb4e39 commit cf641b7
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 23 deletions.
44 changes: 28 additions & 16 deletions nonos/_readers/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ class NPYReader(ReaderMixin):
# we accept a leading '_' for backward compatibility
_filename_re = re.compile(
r"^_?(?P<prefix>[\w\.]*)"
r"_(?P<field_name>[A-Z]+)"
r"_(?P<field_name>[A-Z\d]+)"
r"\.(?P<output_number>\d+)"
r"\.npy$"
)
Expand All @@ -663,32 +663,46 @@ def parse_output_number_and_filename(
directory = Path(directory).resolve()
if isinstance(file_or_number, (str, Path)):
file = Path(file_or_number)
if file == Path(file.name):
file = directory / "rho" / file

if (match := NPYReader._filename_re.fullmatch(file.name)) is None:
raise ValueError(f"Filename {file.name!r} is not recognized")
if file == Path(file.name):
file = directory / match.group("field_name").lower() / file
output_number = int(match.group("output_number"))
file_alt = None
else:
output_number = file_or_number
file = directory / "rho" / f"{prefix}_RHO.{output_number:04d}.npy"
file_alt = file.with_name(f"_{file.name}")
all_bin_files = NPYReader.get_bin_files(directory / "any")
_filter_re = re.compile(rf"^_?{prefix}_[A-Z\d]+.{output_number:04d}.npy")
matches = [
file for file in all_bin_files if _filter_re.fullmatch(file.name)
]
if not matches:
raise FileNotFoundError(
"Failed to locate a file matching "
f"{prefix=!r} and {output_number=}"
)
file = matches[0]
file_alt = (
None if file.name.startswith("_") else file.with_name(f"_{file.name}")
)

if not file.is_file():
if file_alt is not None and file_alt.is_file():
# backward compatibility
file = file_alt
else:
raise FileNotFoundError(file)
if file_alt is None:
raise FileNotFoundError(str(file))

if not file_alt.is_file():
raise FileNotFoundError(f"{file} (also tried {file_alt})")
# backward compatibility
file = file_alt

return output_number, file

@staticmethod
def get_bin_files(directory: PathT, /) -> list[Path]:
# return *all* loadable files
# (not just the ones matching a particular prefix)
directory = Path(directory).resolve()
density_file_paths: list[Path] = []
file_paths: list[Path] = []
for subdir in directory.parent.glob("*"):
if not subdir.is_dir():
continue
Expand All @@ -699,11 +713,9 @@ def get_bin_files(directory: PathT, /) -> list[Path]:
continue
if NPYReader._filename_re.fullmatch(file.name) is None:
continue
density_file_paths.append(file)
file_paths.append(file)

if not density_file_paths: # pragma: no cover
raise RuntimeError
return sorted(density_file_paths)
return sorted(file_paths)

@staticmethod
def read(file: PathT, /, **meta) -> BinData:
Expand Down
42 changes: 35 additions & 7 deletions tests/test_interfaces/test_binary_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def initdir(self, tmp_path):

with open(header_dir / "header_azimuthal_average.json", "w") as fh:
json.dump(fake_grid, fh)
with open(header_dir / "header_vertical_at_midplane.json", "w") as fh:
json.dump(fake_grid, fh)

rng = np.random.default_rng(seed=0)

Expand All @@ -78,6 +80,13 @@ def initdir(self, tmp_path):
with open(foo_file, "wb") as fb:
np.save(fb, fake_foo)

vx1_dir = tmp_path / "vx1"
vx1_dir.mkdir()
vx1_file = vx1_dir.joinpath("_vertical_at_midplane_VX1.1200.npy")
fake_vx1 = rng.normal(0, 1, size=8).resize(2, 2, 2)
with open(vx1_file, "wb") as fb:
np.save(fb, fake_vx1)

# sprinkle some "traps" too (files and dirs that *shouldn't* match)
tmp_path.joinpath("trap1.0001").mkdir()
foo_dir.joinpath("trap2.json").touch()
Expand All @@ -87,7 +96,7 @@ def initdir(self, tmp_path):

rho_dir.joinpath("trap4.npy").touch()

real_files = sorted([foo_file, rho_file, rho_file_rr, rho_file_2])
real_files = sorted([foo_file, rho_file, rho_file_rr, rho_file_2, vx1_file])
return tmp_path, real_files

def test_get_bin_files(self, initdir):
Expand All @@ -103,18 +112,37 @@ def test_get_bin_files(self, initdir):
assert results == expected_files

@pytest.mark.parametrize(
"file_or_number",
["azimuthal_average_RHO.0001.npy", 1],
"file_or_number, prefix, expected_output_number, expected_filename",
[
(
"azimuthal_average_RHO.0001.npy",
"azimuthal_average",
1,
("rho", "azimuthal_average_RHO.0001.npy"),
),
(1, "azimuthal_average", 1, ("rho", "azimuthal_average_RHO.0001.npy")),
("__FOO.0456.npy", "", 456, ("foo", "__FOO.0456.npy")),
(456, "", 456, ("foo", "__FOO.0456.npy")),
],
)
def test_parse_output_number_and_filename(self, initdir, file_or_number):
def test_parse_output_number_and_filename(
self,
initdir,
file_or_number,
prefix,
expected_output_number,
expected_filename,
):
tmp_path, _files = initdir
directory = tmp_path.resolve()

output_number, filename = NPYReader.parse_output_number_and_filename(
file_or_number, directory=directory, prefix="azimuthal_average"
file_or_number,
directory=directory,
prefix=prefix,
)
assert output_number == 1
assert filename == directory / "rho" / "azimuthal_average_RHO.0001.npy"
assert output_number == expected_output_number
assert filename == directory.joinpath(*expected_filename)

def test_parse_output_number_and_filename_invalid_file(self, initdir):
tmp_path, _files = initdir
Expand Down

0 comments on commit cf641b7

Please sign in to comment.