Skip to content

Commit

Permalink
feat: add dask_awkward wrapper to Correction and CompoundCorrection (#…
Browse files Browse the repository at this point in the history
…219)

* add awkward wrapper to CompoundCorrection, add dask_awkward wrapper to both

* emit error if dask_awkward version isn't sufficient

also make min awkward/dask_awkward versions more easily configurable

* error on dask.array

* typo

* english

---------

Co-authored-by: Nicholas Smith <nick.smith@cern.ch>
  • Loading branch information
lgray and nsmith- authored Feb 2, 2024
1 parent 5bfd1a3 commit 2ab0000
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ test =
scipy
awkward >=2.2.2;python_version>"3.7"
awkward <2;python_version<="3.7"
dask-awkward;python_version>"3.7"
dask-awkward >=2024.1.1;python_version>"3.7"
dev =
pytest >=4.6
pre-commit
Expand Down
85 changes: 79 additions & 6 deletions src/correctionlib/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import correctionlib._core
import correctionlib.version

_version_two = version.parse("2")
_min_version_ak = version.parse("2.0.0")
_min_version_dak = version.parse("2024.1.1")


def open_auto(filename: str) -> str:
Expand Down Expand Up @@ -58,9 +59,9 @@ def _call_as_numpy(
) -> Any:
import awkward

if version.parse(awkward.__version__) < _version_two:
if version.parse(awkward.__version__) < _min_version_ak:
raise RuntimeError(
f"""imported awkward is version {awkward.__version__} < 2.0.0
f"""imported awkward is version {awkward.__version__} < {str(_min_version_ak)}
If you cannot upgrade, try doing: ak.flatten(arrays) -> result = correction(arrays) -> ak.unflatten(result, counts)
"""
)
Expand Down Expand Up @@ -130,6 +131,49 @@ def _wrap_awkward(
return awkward.transform(tocall, *array_args)


def _call_dask_correction(
correction: Any,
*args: Union["numpy.ndarray[Any, Any]", str, int, float],
):
return _wrap_awkward(correction._base.evalv, *args)


def _wrap_dask_awkward(
correction: Any,
*args: Union["numpy.ndarray[Any, Any]", str, int, float],
) -> Any:
import dask.delayed
import dask_awkward

if version.parse(dask_awkward.__version__) < _min_version_dak:
raise RuntimeError(
f"""imported dask_awkward is version {dask_awkward.__version__} < {str(_min_version_dak)}
This version of dask_awkward includes several useful bugfixes and functionality extensions.
Please upgrade dask_awkward.
"""
)

if not hasattr(correction, "_delayed_correction"):
setattr( # noqa: B010
correction,
"_delayed_correction",
dask.delayed(correction),
)

correction_meta = _wrap_awkward(
correction._base.evalv,
*(arg._meta if isinstance(arg, dask_awkward.Array) else arg for arg in args),
)

return dask_awkward.map_partitions(
_call_dask_correction,
correction._delayed_correction,
*args,
meta=correction_meta,
label=correction._name,
)


class Correction:
"""High-level correction evaluator object
Expand Down Expand Up @@ -174,12 +218,22 @@ def evaluate(
self, *args: Union["numpy.ndarray[Any, Any]", str, int, float]
) -> Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]:
# TODO: create a ufunc with numpy.vectorize in constructor?
if any(str(type(arg)).startswith("<class 'dask.array.") for arg in args):
raise TypeError(
"Correctionlib does not yet handle dask.array collections. "
"If you require this functionality (i.e. you cannot or do "
"not want to use dask_awkward/awkward arrays) please open an "
"issue at https://github.com/cms-nanoAOD/correctionlib/issues."
)
try:
vargs = [
numpy.asarray(arg)
for arg in args
if not isinstance(arg, (str, int, float))
]
except NotImplementedError:
if any(str(type(arg)).startswith("<class 'dask_awkward.") for arg in args):
return _wrap_dask_awkward(self, *args) # type: ignore
except (ValueError, TypeError):
if any(str(type(arg)).startswith("<class 'awkward.") for arg in args):
return _wrap_awkward(self._base.evalv, *args) # type: ignore
Expand Down Expand Up @@ -242,9 +296,28 @@ def evaluate(
self, *args: Union["numpy.ndarray[Any, Any]", str, int, float]
) -> Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]:
# TODO: create a ufunc with numpy.vectorize in constructor?
vargs = [
numpy.asarray(arg) for arg in args if not isinstance(arg, (str, int, float))
]
if any(str(type(arg)).startswith("<class 'dask.array.") for arg in args):
raise TypeError(
"Correctionlib does not yet handle dask.array collections. "
"if you require this functionality (i.e. you cannot or do "
"not want to use dask_awkward/awkward arrays) please open an "
"issue at https://github.com/cms-nanoAOD/correctionlib/issues."
)
try:
vargs = [
numpy.asarray(arg)
for arg in args
if not isinstance(arg, (str, int, float))
]
except NotImplementedError:
if any(str(type(arg)).startswith("<class 'dask_awkward.") for arg in args):
return _wrap_dask_awkward(self, *args) # type: ignore
except (ValueError, TypeError):
if any(str(type(arg)).startswith("<class 'awkward.") for arg in args):
return _wrap_awkward(self._base.evalv, *args) # type: ignore
except Exception as err:
raise err

if vargs:
bargs = numpy.broadcast_arrays(*vargs)
oshape = bargs[0].shape
Expand Down
3 changes: 1 addition & 2 deletions tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def test_highlevel_dask(cset):
x = awkward.unflatten(numpy.ones(6), [3, 2, 1])
dx = dask_awkward.from_awkward(x, 3)

evaluate = dask_awkward.map_partitions(
sf.evaluate,
evaluate = sf.evaluate(
dx,
1.0,
)
Expand Down

0 comments on commit 2ab0000

Please sign in to comment.