Skip to content

Commit

Permalink
#90 Merge pull request from deshima-dev/astropenguin/issue89
Browse files Browse the repository at this point in the history
Rename io module (io → load)
  • Loading branch information
astropenguin authored Oct 25, 2023
2 parents 5576f70 + 097503b commit 19fac04
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 23 deletions.
4 changes: 2 additions & 2 deletions decode/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
__all__ = ["io", "plot", "select"]
__all__ = ["load", "plot", "select"]
__version__ = "2.1.0"


# submodules
from . import io
from . import load
from . import plot
from . import select
26 changes: 17 additions & 9 deletions decode/io.py → decode/load.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["open_dems"]
__all__ = ["dems"]


# standard library
Expand All @@ -17,31 +17,39 @@
ZARR_SUFFIX = ".zarr"


def open_dems(dems: Union[Path, str], **kwargs: Any) -> xr.DataArray:
"""Open a DEMS file as a DataArray.
def dems(dems: Union[Path, str], /, **options: Any) -> xr.DataArray:
"""Load a DEMS file as a DataArray.
Args:
dems: Path of the DEMS file.
kwargs: Arguments to be passed to ``xarray.open_dataarray``.
Keyword Args:
options: Arguments to be passed to ``xarray.open_dataarray``.
Return:
A DataArray of the opened DEMS file.
Loaded DEMS DataArray.
Raises:
ValueError: Raised if the file type is not supported.
"""
engine: str
suffixes = Path(dems).suffixes

if NETCDF_SUFFIX in suffixes:
engine = kwargs.pop("engine", NETCDF_ENGINE)
options = {
"engine": NETCDF_ENGINE,
**options,
}
elif ZARR_SUFFIX in suffixes:
engine = kwargs.pop("engine", ZARR_ENGINE)
options = {
"chunks": "auto",
"engine": ZARR_ENGINE,
**options,
}
else:
raise ValueError(
f"File type of {dems} is not supported."
"Use netCDF (.nc) or Zarr (.zarr, .zarr.zip)."
)

return xr.open_dataarray(dems, engine=engine, **kwargs)
return xr.open_dataarray(dems, **options)
6 changes: 3 additions & 3 deletions tests/test_io.py → tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


# dependencies
from decode import io
from decode import load
from pytest import mark


Expand All @@ -14,5 +14,5 @@

# test functions
@mark.parametrize("dems", DEMS_ALL)
def test_open_dems(dems: Path) -> None:
io.open_dems(dems)
def test_load_dems(dems: Path) -> None:
load.dems(dems)
4 changes: 2 additions & 2 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

# dependencies
import matplotlib.pyplot as plt
from decode import io, plot
from decode import load, plot


# constants
DEMS_DIR = Path(__file__).parents[1] / "data" / "dems"
DEMS = io.open_dems(DEMS_DIR / "dems_20171111110002.nc.gz")
DEMS = load.dems(DEMS_DIR / "dems_20171111110002.nc.gz")


def test_plot_data_1d_time() -> None:
Expand Down
14 changes: 7 additions & 7 deletions tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,39 @@


# dependencies
from decode import io, select
from decode import load, select


# constants
DEMS_DIR = Path(__file__).parents[1] / "data" / "dems"
DEMS = io.open_dems(DEMS_DIR / "dems_20171111110002.nc.gz")
DEMS = load.dems(DEMS_DIR / "dems_20171111110002.nc.gz")


def test_by_min() -> None:
def test_select_by_min() -> None:
min = -0.5
sel = select.by(DEMS[::100], "lon", min=min)
assert (sel.lon >= min).all()


def test_by_max() -> None:
def test_select_by_max() -> None:
max = +0.5
sel = select.by(DEMS[::100], "lon", max=max)
assert (sel.lon < max).all()


def test_by_range() -> None:
def test_select_by_range() -> None:
min, max = -0.5, +0.5
sel = select.by(DEMS[::100], "lon", min=min, max=max)
assert ((sel.lon >= min) & (sel.lon < max)).all()


def test_by_include() -> None:
def test_select_by_include() -> None:
include = ["SCAN", "TRAN"]
sel = select.by(DEMS[::100], "state", include=include)
assert set(sel.state.values) == set(include)


def test_by_exclude() -> None:
def test_select_by_exclude() -> None:
exclude = ["ACC"]
sel = select.by(DEMS[::100], "state", exclude=exclude)
assert set(sel.state.values).isdisjoint(exclude)

0 comments on commit 19fac04

Please sign in to comment.