support multiargument dask-awkward
douglasdavis committed Mar 22, 2022
1 parent 0fa081d commit 3ba22e9
Showing 1 changed file with 73 additions and 59 deletions.
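The change teaches dask-histogram's partitioned-fill machinery to accept more than one dask-awkward collection at a time, so multidimensional histograms can be filled directly from awkward inputs. A minimal sketch of the new usage (the input construction and axis choices are illustrative, not part of the commit; from_awkward is assumed from dask-awkward's public API):

    import awkward as ak
    import boost_histogram as bh
    import dask_awkward as dak
    from dask_histogram.core import factory

    # Two-partition awkward collections, one per histogram dimension.
    x = dak.from_awkward(ak.Array([1.0, 2.0, 3.0, 4.0]), npartitions=2)
    y = dak.from_awkward(ak.Array([0.5, 1.5, 2.5, 3.5]), npartitions=2)

    axes = (bh.axis.Regular(10, 0, 5), bh.axis.Regular(10, 0, 5))
    h = factory(x, y, axes=axes)  # multiple dask-awkward arguments
    print(h.compute())            # AggHistogram -> boost_histogram.Histogram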
src/dask_histogram/core.py (132 changes: 73 additions & 59 deletions)
@@ -199,6 +199,10 @@ def _blocked_dak(data: Any, *, histref: bh.Histogram | None = None) -> bh.Histogram:
     return clone(histref).fill(data)
 
 
+def _blocked_dak_ma(*data: Any, histref: bh.Histogram | None = None) -> bh.Histogram:
+    return clone(histref).fill(*data)
+
+
 def optimize(
     dsk: Mapping,
     keys: Hashable | list[Hashable] | set[Hashable],
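Per partition, the new _blocked_dak_ma helper clones the (empty) reference histogram and fills all coordinate arrays in a single call. Conceptually, for one block (plain NumPy stand-ins; clone() is the structure-only copy already used throughout this module):

    import boost_histogram as bh
    import numpy as np

    histref = bh.Histogram(bh.axis.Regular(10, 0, 5), bh.axis.Regular(10, 0, 5))
    xp = np.array([1.0, 2.0, 3.0, 4.0])
    yp = np.array([0.5, 1.5, 2.5, 3.5])

    part = histref.copy()  # stand-in for clone(histref): same axes, no data
    part.reset()
    part.fill(xp, yp)      # what clone(histref).fill(*data) does per block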
@@ -210,12 +214,10 @@ def optimize(
 
     if not isinstance(dsk, HighLevelGraph):
         dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
-    else:
-        # Perform Blockwise optimizations for HLG input
-        dsk = optimize_blockwise(dsk, keys=keys)
-        dsk = fuse_roots(dsk, keys=keys)  # type: ignore
-        dsk = dsk.cull(set(keys))  # type: ignore
 
+    dsk = optimize_blockwise(dsk, keys=keys)
+    dsk = fuse_roots(dsk, keys=keys)  # type: ignore
+    dsk = dsk.cull(set(keys))  # type: ignore
     return dsk
 
 
@@ -334,15 +336,8 @@ def __str__(self) -> str:
 
     __repr__ = __str__
 
-    @property
-    def _args(self) -> tuple[HighLevelGraph, str, bh.Histogram]:
-        return (self.dask, self.name, self.histref)
-
-    def __getstate__(self) -> tuple[HighLevelGraph, str, bh.Histogram]:
-        return self._args
-
-    def __setstate__(self, state: tuple[HighLevelGraph, str, bh.Histogram]) -> None:
-        self._dask, self._name, self._histref = state
+    def __reduce__(self):
+        return (AggHistogram, (self._dask, self._name, self._histref))
 
     def to_dask_array(
         self, flow: bool = False, dd: bool = False
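Replacing the __getstate__/__setstate__ pair with __reduce__ makes pickle rebuild the collection by calling the constructor with (dask, name, histref), so __init__-side invariants hold after a round trip. A quick sketch, reusing h from the first example:

    import pickle

    h2 = pickle.loads(pickle.dumps(h))  # rebuilt as AggHistogram(dask, name, histref)
    assert h2.name == h.name
    assert h2.compute() == h.compute()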
Expand Down Expand Up @@ -375,9 +370,15 @@ def to_boost(self) -> bh.Histogram:
"""
return self.compute()

def to_delayed(self) -> Delayed:
dsk = self.__dask_graph__()
return Delayed(self.name, dsk, layer=self._layer)
def to_delayed(self, optimize_graph: bool = True) -> Delayed:
keys = self.__dask_keys__()
graph = self.__dask_graph__()
layer = self.__dask_layers__()[0]
if optimize_graph:
graph = self.__dask_optimize__(graph, keys)
layer = f"delayed-{self.name}"
graph = HighLevelGraph.from_collections(layer, graph, dependencies=())
return Delayed(keys[0], graph, layer=layer)

def values(self, flow: bool = False) -> NDArray[Any]:
return self.to_boost().values(flow=flow)
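The reworked to_delayed follows the pattern of dask's other collections: optionally optimize the graph against the collection's keys, rewrap it in a fresh single-layer HighLevelGraph, and return a Delayed pointing at the single output key. For example:

    d = h.to_delayed()                        # graph optimized by default
    final = d.compute()                       # a boost_histogram.Histogram
    raw = h.to_delayed(optimize_graph=False)  # keep the graph as-is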
Expand Down Expand Up @@ -454,7 +455,7 @@ def __init__(
self, dsk: HighLevelGraph, name: str, npartitions: int, histref: bh.Histogram
) -> None:
self._dask: HighLevelGraph = dsk
self._name = name
self._name: str = name
self._npartitions: int = npartitions
self._histref: bh.Histogram = histref

@@ -505,27 +506,36 @@ def __str__(self) -> str:
 
     __repr__ = __str__
 
-    @property
-    def _args(self) -> tuple[HighLevelGraph, str, int, bh.Histogram]:
-        return (self.dask, self.name, self.npartitions, self.histref)
-
-    def __getstate__(self) -> tuple[HighLevelGraph, str, int, bh.Histogram]:
-        return self._args
-
-    def __setstate__(
-        self, state: tuple[HighLevelGraph, str, int, bh.Histogram]
-    ) -> None:
-        self._dask, self._name, self._npartitions, self._histref = state
+    def __reduce__(self):
+        return (
+            PartitionedHistogram,
+            (
+                self._dask,
+                self._name,
+                self._npartitions,
+                self._histref,
+            ),
+        )
 
     @property
     def histref(self) -> bh.Histogram:
         """boost_histogram.Histogram: reference histogram."""
         return self._histref
 
-    def to_agg(self, split_every: int | None = None) -> AggHistogram:
+    def collapse(self, split_every: int | None = None) -> AggHistogram:
         """Translate into a reduced aggregated histogram."""
         return _reduction(self, split_every=split_every)
+
+    def to_delayed(self, optimize_graph: bool = True) -> list[Delayed]:
+        keys = self.__dask_keys__()
+        graph = self.__dask_graph__()
+        layer = self.__dask_layers__()[0]
+        if optimize_graph:
+            graph = self.__dask_optimize__(graph, keys)
+            layer = f"delayed-{self.name}"
+            graph = HighLevelGraph.from_collections(layer, graph, dependencies=())
+        return [Delayed(k, graph, layer=layer) for k in keys]
 
 
 def _reduction(
     ph: PartitionedHistogram,
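PartitionedHistogram gets the matching treatment: to_delayed hands back one Delayed per partition, and the former to_agg is renamed collapse, which funnels into the _reduction tree reduce. A sketch reusing x, y, and axes from the first example:

    ph = factory(x, y, axes=axes, keep_partitioned=True)
    parts = ph.to_delayed()           # list[Delayed], one per partition
    agg = ph.collapse(split_every=4)  # tree-reduce into an AggHistogram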
@@ -607,11 +617,14 @@ def _partitioned_histogram(
     sample: DaskCollection | None = None,
     split_every: int | None = None,
 ) -> PartitionedHistogram:
-    name = f"hist-on-block-{tokenize(data, histref, weights, sample)}"
+    name = f"hist-on-block-{tokenize(data, histref, weights, sample, split_every)}"
     data_is_df = is_dataframe_like(data[0])
+    data_is_dak = is_awkward_like(data[0])
     _weight_sample_check(*data, weights=weights)
-    if len(data) == 1 and hasattr(data[0], "_typetracer"):
-        from dask_awkward.core import partitionwise_layer as pwlayer
+
+    # Single awkward array object.
+    if len(data) == 1 and data_is_dak:
+        from dask_awkward.core import partitionwise_layer as dak_pwl
 
         x = data[0]
         if weights is not None and sample is not None:
@@ -621,7 +634,9 @@ def _partitioned_histogram(
         elif weights is None and sample is not None:
             raise NotImplementedError()
         else:
-            g = pwlayer(_blocked_dak, name, x, histref=histref)
+            g = dak_pwl(_blocked_dak, name, x, histref=histref)
+
+    # Single object, not a dataframe
     elif len(data) == 1 and not data_is_df:
         x = data[0]
         if weights is not None and sample is not None:
@@ -634,6 +649,8 @@ def _partitioned_histogram(
             g = partitionwise(_blocked_sa_s, name, x, sample, histref=histref)
         else:
             g = partitionwise(_blocked_sa, name, x, histref=histref)
+
+    # Single object, is a dataframe
     elif len(data) == 1 and data_is_df:
         x = data[0]
         if weights is not None and sample is not None:
@@ -646,8 +663,20 @@ def _partitioned_histogram(
             g = partitionwise(_blocked_df_s, name, x, sample, histref=histref)
         else:
             g = partitionwise(_blocked_df, name, x, histref=histref)
+
+    # Multiple objects
     else:
-        if weights is not None and sample is not None:
+
+        # Awkward array collection detected as first argument
+        if data_is_dak:
+            from dask_awkward.core import partitionwise_layer as dak_pwl
+
+            if weights is None and sample is None:
+                g = dak_pwl(_blocked_dak_ma, name, *data, histref=histref)
+            else:
+                raise NotImplementedError()
+        # Not an awkward array collection
+        elif weights is not None and sample is not None:
             g = partitionwise(
                 _blocked_ma_w_s, name, *data, weights, sample, histref=histref
             )
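Note the dispatch: the first argument selects the path, and in this commit the multi-argument awkward branch covers only unweighted, unsampled fills; anything else raises NotImplementedError. For example (w is an illustrative weights collection built like x and y above):

    w = dak.from_awkward(ak.Array([1.0, 1.0, 1.0, 1.0]), npartitions=2)
    factory(x, y, axes=axes, weights=w)  # raises NotImplementedError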
@@ -663,22 +692,6 @@ def _partitioned_histogram(
     return PartitionedHistogram(hlg, name, data[0].npartitions, histref=histref)
 
 
-def _reduced_histogram(
-    *data: DaskCollection,
-    histref: bh.Histogram,
-    weights: DaskCollection | None = None,
-    sample: DaskCollection | None = None,
-    split_every: int | None = None,
-) -> AggHistogram:
-    ph = _partitioned_histogram(
-        *data,
-        histref=histref,
-        weights=weights,
-        sample=sample,
-    )
-    return ph.to_agg(split_every=split_every)
-
-
 def to_dask_array(
     agghist: AggHistogram,
     flow: bool = False,
@@ -888,11 +901,12 @@ def factory(
     if storage is None:
         storage = bh.storage.Double()
     histref = bh.Histogram(*axes, storage=storage)  # type: ignore
-    f = _partitioned_histogram if keep_partitioned else _reduced_histogram
-    return f(  # type: ignore
-        *data,
-        histref=histref,
-        weights=weights,
-        sample=sample,
-        split_every=split_every,
-    )
+
+    ph = _partitioned_histogram(*data, histref=histref, weights=weights, sample=sample)
+    if keep_partitioned:
+        return ph
+    return ph.collapse(split_every=split_every)
+
+
+def is_awkward_like(x: Any) -> bool:
+    return is_dask_collection(x) and hasattr(x, "_typetracer")
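The new is_awkward_like predicate duck-types dask-awkward collections via the _typetracer attribute rather than importing dask_awkward at module scope, which keeps the dependency optional. Illustration, with x from the first sketch:

    from dask_histogram.core import is_awkward_like

    is_awkward_like(x)        # True for a dask-awkward Array
    is_awkward_like([1, 2])   # False: not a dask collection at all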
