NOTE: reconstructed from a whitespace-collapsed copy of this patch. Line breaks,
indentation and blank context lines were inferred from the hunk counts; run
`git apply --check` before relying on it. Added lines in src/akimbo/mixin.py
were edited afterwards, so the `index` blob ids no longer describe the result.

diff --git a/src/akimbo/apply_tree.py b/src/akimbo/apply_tree.py
index 9aed44d..a6531a7 100644
--- a/src/akimbo/apply_tree.py
+++ b/src/akimbo/apply_tree.py
@@ -8,11 +8,22 @@
 import pyarrow as pa
 
 
+def match_any(*layout, **_):
+    return True
+
+
 def leaf(*layout, **_):
     """True for the lowest elements of any akwward layout tree"""
     return layout[0].is_leaf
 
 
+def numeric(*layout, **_):
+    return layout[0].is_leaf and layout[0].parameters.get("__array__", None) not in {
+        "string",
+        "char",
+    }
+
+
 def run_with_transform(
     arr: ak.Array,
     op,
@@ -24,6 +35,8 @@ def run_with_transform(
     **kw,
 ) -> ak.Array:
     def func(layout, **kwargs):
+        from akimbo.utils import match_string
+
         if not isinstance(layout, tuple):
             layout = (layout,)
         if all(match(lay, **(match_kwargs or {})) for lay in layout):
@@ -34,11 +47,23 @@ def func(layout, **kwargs):
             elif inmode == "numpy":
                 # works on numpy/cupy contents
                 out = op(*(lay.data for lay in layout), **kw, **(match_kwargs or {}))
-            else:
+            elif inmode == "ak":
                 out = op(*layout, **kw, **(match_kwargs or {}))
-            return outtype(out) if callable(outtype) else out
+            else:
+                out = op(
+                    *(ak.Array(lay) for lay in layout), **kw, **(match_kwargs or {})
+                )
+            if callable(outtype):
+                return outtype(out)
+            elif isinstance(out, ak.Array):
+                return out.layout
+            else:
+                return out
+        if match_string(*layout):
+            # non-string op may fail to descend into string
+            return layout[0]
 
-    return ak.transform(func, arr, *others)
+    return ak.transform(func, arr, *others, allow_records=True)
 
 
 def dec(
diff --git a/src/akimbo/mixin.py b/src/akimbo/mixin.py
index 82faf73..c123ab6 100644
--- a/src/akimbo/mixin.py
+++ b/src/akimbo/mixin.py
@@ -7,7 +7,8 @@
 import awkward as ak
 import pyarrow.compute as pc
 
-from akimbo.apply_tree import dec
+from akimbo.apply_tree import dec, match_any, numeric, run_with_transform
+from akimbo.utils import to_ak_layout
 
 methods = [
     _ for _ in (dir(ak)) if not _.startswith(("_", "ak_")) and not _[0].isupper()
@@ -179,6 +180,9 @@ def apply(self, fn: Callable, where=None, **kwargs):
 
         The function should take an ak array as input and produce an
         ak array or scalar.
+
+        Unlike ``transform``, the function takes and returns ak.Array instances
+        and acts on a whole schema tree.
         """
         if where:
             bits = tuple(where.split("."))
@@ -190,6 +194,44 @@ def apply(self, fn: Callable, where=None, **kwargs):
             final = fn(self.array)
         return self.to_output(final)
 
+    def transform(
+        self, fn: Callable, *others, where=None, match=match_any, inmode="ak", **kwargs
+    ):
+        """Apply an arbitrary function to selected parts of the data tree
+
+        This process walks through the data's schema tree, and applies the given
+        function only on the matching nodes.
+
+        Parameters
+        ----------
+        fn: the operation you want to perform. Typically unary or binary, and may take
+            extra kwargs
+        others: extra arguments, perhaps other akimbo series
+        where: path in the schema tree to apply this
+        match: when walking the schema, this determines if a node should be processed;
+            it will be a function taking one or more ak.contents classes.
+            akimbo.apply_tree contains the convenience matchers match_any, leaf and
+            numeric; more matchers can be found in the string and datetime modules
+        inmode: data should be passed to the given function as:
+            "arrow" | "numpy" (includes cupy) | "ak" layout | "array" high-level ak.Array
+        kwargs: passed to the operation, except those that are taken by ``run_with_transform``.
+        """
+        if where:
+            bits = tuple(where.split("."))
+            arr = self.array
+            part = arr.__getitem__(bits)
+            # TODO: apply ``where`` to any arrays in others
+            # other = [to_ak_layout(ar) for ar in others]
+            out = run_with_transform(
+                part, fn, match=match, others=others, inmode=inmode, **kwargs
+            )
+            final = ak.with_field(arr, out, where=where)
+        else:
+            final = run_with_transform(
+                self.array, fn, match=match, others=others, inmode=inmode, **kwargs
+            )
+        return self.to_output(final)
+
     def __getitem__(self, item):
         out = self.array.__getitem__(item)
         return self.to_output(out)
@@ -331,12 +373,33 @@ def join(
     def _create_op(cls, op):
         """Make functions to perform all the arithmetic, logical and comparison ops"""
 
-        def run(self, *args, **kwargs):
-            ar2 = (ar.ak.array if hasattr(ar, "ak") else ar for ar in args)
-            ar3 = (ar.array if isinstance(ar, cls) else ar for ar in ar2)
-            return self.to_output(op(self.array, *ar3, **kwargs))
+        def op2(*args, extra=None, **kw):
+            args = list(args) + list(extra or [])
+            return op(*args, **kw)
+
+        def f(self, *args, **kw):
+            # TODO: test here is for literals, but really we want "don't know how to
+            # array that" condition
+            # NOTE: a list, not a generator - op2 runs once per matched node, and a
+            # generator would already be exhausted on the second call
+            extra = [_ for _ in args if isinstance(_, (str, int, float))]
+            args = (
+                to_ak_layout(_) for _ in args if not isinstance(_, (str, int, float))
+            )
+            out = self.transform(
+                op2,
+                *args,
+                match=numeric,
+                inmode="numpy",
+                extra=extra,
+                outtype=ak.contents.NumpyArray,
+                **kw,
+            )
+            if isinstance(self._obj, self.dataframe_type):
+                return out.ak.unmerge()
+            return out
 
-        return run
+        return f
 
     def __getattr__(self, item):
         arr = self.array
diff --git a/src/akimbo/strings.py b/src/akimbo/strings.py
index 6e7fcb2..1460154 100644
--- a/src/akimbo/strings.py
+++ b/src/akimbo/strings.py
@@ -8,10 +8,7 @@
 
 from akimbo.apply_tree import dec
 from akimbo.mixin import Accessor
-
-
-def match_string(*layout):
-    return layout[0].is_list and layout[0].parameter("__array__") == "string"
+from akimbo.utils import match_string
 
 
 def _encode(layout):
@@ -53,14 +50,21 @@ def _decode(layout):
 
 # make sensible defaults for strptime
 strptime = functools.wraps(pc.strptime)(
-    lambda *args, format="%FT%T", unit="s", error_is_null=True, **kw:
-        pc.strptime(*args, format=format, unit=unit, error_is_null=error_is_null)
+    lambda *args, format="%FT%T", unit="s", error_is_null=True, **kw: pc.strptime(
+        *args, format=format, unit=unit, error_is_null=error_is_null
+    )
 )
 
 
 class StringAccessor:
     """String operations on nested/var-length data"""
 
+    # TODO: implement dunder add (concat strings) and mul (repeat strings)
+    # - s.ak.str + "suffix" (and arguments swapped)
+    # - s.ak.str + s2.ak.str (with matching schemas)
+    # - s.ak.str * N (and arguments swapped)
+    # - s.ak.str * s (where each string maps to integers for variable repeats)
+
     def __init__(self, accessor):
         self.accessor = accessor
 
diff --git a/src/akimbo/utils.py b/src/akimbo/utils.py
index cfda842..a5f1b89 100644
--- a/src/akimbo/utils.py
+++ b/src/akimbo/utils.py
@@ -1,3 +1,8 @@
+from __future__ import annotations
+
+import awkward as ak
+
+
 class NoAttributes:
     """Allows importing akimbo.cudf even if cudf isn't installed
 
@@ -20,3 +25,18 @@ def __call__(self, *args, **kwargs):
     __name__ = "DummyAttributesObject"
     __doc__ = None
     __annotations__ = None
+
+
+def to_ak_layout(ar):
+    if hasattr(ar, "ak"):
+        return ar.ak.array
+    elif hasattr(ar, "array"):
+        return ar.array
+    elif isinstance(ar, (ak.Array)):
+        return ar
+    else:
+        return ak.Array(ak.to_layout(ar))
+
+
+def match_string(*layout):
+    return layout[0].is_list and layout[0].parameter("__array__") == "string"
diff --git a/tests/test_pandas.py b/tests/test_pandas.py
index 13ae2b4..2580e59 100644
--- a/tests/test_pandas.py
+++ b/tests/test_pandas.py
@@ -49,6 +49,58 @@ def test_ufunc():
     assert (s.ak + s.ak).tolist() == [[2, 4, 6], [8, 10], [12]]
     assert (s.ak + s).tolist() == [[2, 4, 6], [8, 10], [12]]
 
+    s = pd.DataFrame({"a": s})
+    assert (s.ak + 1).a.tolist() == [[2, 3, 4], [5, 6], [7]]
+
+    assert (s.ak + s.ak).a.tolist() == [[2, 4, 6], [8, 10], [12]]
+    assert (s.ak + s).a.tolist() == [[2, 4, 6], [8, 10], [12]]
+
+
+def test_manual_ufunc():
+    from akimbo.apply_tree import numeric
+
+    df = pd.DataFrame(
+        {"a": [["hey", "hi", "ho"], [None], ["blah"]], "b": [[1, 2, 3], [4, 5], [6]]}
+    )
+    df2 = df.ak.transform(
+        lambda x: x + 1, match=numeric, inmode="numpy", outtype=ak.contents.NumpyArray
+    )
+    expected = [
+        {"a": ["hey", "hi", "ho"], "b": [2, 3, 4]},
+        {"a": [None], "b": [5, 6]},
+        {"a": ["blah"], "b": [7]},
+    ]
+    assert df2.tolist() == expected
+
+
+def test_mixed_ufunc():
+    # ufuncs are numeric only by default, doesn't touch strings
+    df = pd.DataFrame(
+        {"a": [["hey", "hi", "ho"], [None], ["blah"]], "b": [[1, 2, 3], [4, 5], [6]]}
+    )
+    df2 = df.ak + 1
+    expected = [
+        {"a": ["hey", "hi", "ho"], "b": [2, 3, 4]},
+        {"a": [None], "b": [5, 6]},
+        {"a": ["blah"], "b": [7]},
+    ]
+    assert df2.ak.tolist() == expected
+
+    df2 = df.ak * 2
+    expected = [
+        {"a": ["hey", "hi", "ho"], "b": [2, 4, 6]},
+        {"a": [None], "b": [8, 10]},
+        {"a": ["blah"], "b": [12]},
+    ]
+    assert df2.ak.tolist() == expected
+    df2 = 2 * df.ak
+    assert df2.ak.tolist() == expected
+
+    df2 = df.ak == df.ak
+    expected = [[True, True, True], [True, True], [True]]
+    assert df2["b"].tolist() == expected
+    assert df2["a"].tolist() == df["a"].tolist()
+
 
 def test_to_autoarrow():
     a = [[1, 2, 3], [4, 5], [6]]
diff --git a/tests/test_polars.py b/tests/test_polars.py
index 56b5d93..e4180fc 100644
--- a/tests/test_polars.py
+++ b/tests/test_polars.py
@@ -54,3 +54,7 @@ def test_ufunc():
 
     s2 = np.add(s.ak, 1)
     assert s2.to_list() == [[2, 3, 4], [], [5, 6]]
+
+    df = pl.DataFrame({"a": s})
+    df2 = df.ak + 1
+    assert df2["a"].to_list() == [[2, 3, 4], [], [5, 6]]