Refactor Dask cuDF legacy code #17205
Changes from all commits
```diff
@@ -1,21 +1,19 @@
 # Copyright (c) 2018-2024, NVIDIA CORPORATION.
 
-from dask import config
-
-# For dask>2024.2.0, we can silence the loud deprecation
-# warning before importing `dask.dataframe` (this won't
-# do anything for dask==2024.2.0)
-config.set({"dataframe.query-planning-warning": False})
+import warnings
+from importlib import import_module
 
-import dask.dataframe as dd  # noqa: E402
+from dask import config
+import dask.dataframe as dd
 from dask.dataframe import from_delayed  # noqa: E402
 
 import cudf  # noqa: E402
 
 from . import backends  # noqa: E402, F401
 from ._version import __git_commit__, __version__  # noqa: E402, F401
-from .core import concat, from_cudf, from_dask_dataframe  # noqa: E402
-from .expr import QUERY_PLANNING_ON  # noqa: E402
+from .core import concat, from_cudf, DataFrame, Index, Series  # noqa: F401
+
+QUERY_PLANNING_ON = dd.DASK_EXPR_ENABLED
 
 
 def read_csv(*args, **kwargs):
```
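One practical consequence of the change above: `QUERY_PLANNING_ON` is no longer imported from the removed `.expr` module, but derived directly from `dask.dataframe`. A minimal sketch of how downstream code can observe the flag (assuming `dask_cudf` and a dask version exposing `DASK_EXPR_ENABLED` are installed):

```python
# Minimal sketch: QUERY_PLANNING_ON now simply mirrors
# dask.dataframe's DASK_EXPR_ENABLED flag.
import dask.dataframe as dd
import dask_cudf

assert dask_cudf.QUERY_PLANNING_ON == dd.DASK_EXPR_ENABLED
print("query planning on:", dask_cudf.QUERY_PLANNING_ON)
```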
```diff
@@ -38,34 +36,51 @@ def read_parquet(*args, **kwargs):
     return dd.read_parquet(*args, **kwargs)
 
 
-def raise_not_implemented_error(attr_name):
+def _deprecated_api(old_api, new_api=None, rec=None):
     def inner_func(*args, **kwargs):
+        if new_api:
+            # Use alternative
+            msg = f"{old_api} is now deprecated. "
+            msg += rec or f"Please use {new_api} instead."
+            warnings.warn(msg, FutureWarning)
+            new_attr = new_api.split(".")
+            module = import_module(".".join(new_attr[:-1]))
+            return getattr(module, new_attr[-1])(*args, **kwargs)
+
+        # No alternative - raise an error
         raise NotImplementedError(
-            f"Top-level {attr_name} API is not available for dask-expr."
+            f"{old_api} is no longer supported. " + (rec or "")
         )
 
     return inner_func
```

Review comment on lines 41 to 55: Note: I decided to expand this utility to help deal with the many …
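A rough sketch of how a wrapper produced by `_deprecated_api` behaves (not from the PR; `math.sqrt` stands in for a real `dask_cudf._legacy.*` target so the example is self-contained, and `mypkg.*` names are placeholders):

```python
import warnings

# Hypothetical redirect: "mypkg.sqrt" is a placeholder old name;
# the target is resolved lazily via importlib.import_module.
deprecated_sqrt = _deprecated_api("mypkg.sqrt", new_api="math.sqrt")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert deprecated_sqrt(4.0) == 2.0  # call is forwarded to math.sqrt
    assert issubclass(caught[-1].category, FutureWarning)

# With no new_api, the wrapper raises instead of forwarding:
removed = _deprecated_api("mypkg.gone", rec="Use something else.")
try:
    removed()
except NotImplementedError as err:
    print(err)  # mypkg.gone is no longer supported. Use something else.
```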
```diff
 if QUERY_PLANNING_ON:
-    from .expr._collection import DataFrame, Index, Series
+    from ._expr.expr import _patch_dask_expr
     from . import io  # noqa: F401
 
-    groupby_agg = raise_not_implemented_error("groupby_agg")
+    groupby_agg = _deprecated_api("dask_cudf.groupby_agg")
     read_text = DataFrame.read_text
-    to_orc = raise_not_implemented_error("to_orc")
+    _patch_dask_expr()
 
 else:
-    from .core import DataFrame, Index, Series  # noqa: F401
-    from .groupby import groupby_agg  # noqa: F401
-    from .io import read_text, to_orc  # noqa: F401
+    from ._legacy.groupby import groupby_agg  # noqa: F401
+    from ._legacy.io import read_text  # noqa: F401
+    from . import io  # noqa: F401
 
 
+to_orc = _deprecated_api(
+    "dask_cudf.to_orc",
+    new_api="dask_cudf._legacy.io.to_orc",
+    rec="Please use DataFrame.to_orc instead.",
+)
 
 
 __all__ = [
     "DataFrame",
     "Series",
     "Index",
     "from_cudf",
     "from_dask_dataframe",
     "concat",
     "from_delayed",
 ]
```

Review comment on the `to_orc` redirect: The legacy to_orc API is actually "fine" (it's compatible with query-planning stuff), but we might as well discourage this kind of usage.
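For callers, the shim above means the old top-level function still works but warns. A hedged sketch of the before/after (paths and data are placeholders, not from the PR; the write calls are left commented out to avoid filesystem side effects):

```python
import cudf
import dask_cudf

ddf = dask_cudf.from_cudf(cudf.DataFrame({"a": [1, 2, 3]}), npartitions=2)

# Deprecated path: emits a FutureWarning recommending DataFrame.to_orc,
# then forwards to dask_cudf._legacy.io.to_orc.
# dask_cudf.to_orc(ddf, "output_dir")

# Recommended path, per the `rec` message above:
# ddf.to_orc("output_dir")
```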
New file (presumably the new `_expr` package `__init__.py`, given the `from ._expr.expr import _patch_dask_expr` import above):

```diff
@@ -0,0 +1 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
```
New file (presumably `dask_cudf/_expr/expr.py`), added in full (`@@ -0,0 +1,210 @@`):

```python
# Copyright (c) 2024, NVIDIA CORPORATION.
import functools

import dask_expr._shuffle as _shuffle_module
from dask_expr import new_collection
from dask_expr._cumulative import CumulativeBlockwise
from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns
from dask_expr._reductions import Reduction, Var

from dask.dataframe.core import (
    is_dataframe_like,
    make_meta,
    meta_nonempty,
)
from dask.dataframe.dispatch import is_categorical_dtype
from dask.typing import no_default

import cudf

##
## Custom expressions
##


class RenameAxisCudf(RenameAxis):
    # TODO: Remove this after rename_axis is supported in cudf
    # (See: https://github.com/rapidsai/cudf/issues/16895)
    @staticmethod
    def operation(df, index=no_default, **kwargs):
        if index != no_default:
            df.index.name = index
            return df
        raise NotImplementedError(
            "Only `index` is supported for the cudf backend"
        )


class ToCudfBackend(Elemwise):
    # TODO: Inherit from ToBackend when rapids-dask-dependency
    # is pinned to dask>=2024.8.1
    _parameters = ["frame", "options"]
    _projection_passthrough = True
    _filter_passthrough = True
    _preserves_partitioning_information = True

    @staticmethod
    def operation(df, options):
        from dask_cudf.backends import to_cudf_dispatch

        return to_cudf_dispatch(df, **options)

    def _simplify_down(self):
        if isinstance(
            self.frame._meta, (cudf.DataFrame, cudf.Series, cudf.Index)
        ):
            # We already have cudf data
            return self.frame


##
## Custom expression patching
##


# This can be removed after cudf#15176 is addressed.
# See: https://github.com/rapidsai/cudf/issues/15176
class PatchCumulativeBlockwise(CumulativeBlockwise):
    @property
    def _args(self) -> list:
        return self.operands[:1]

    @property
    def _kwargs(self) -> dict:
        # Must pass axis and skipna as kwargs in cudf
        return {"axis": self.axis, "skipna": self.skipna}


# The upstream Var code uses `Series.values`, and relies on numpy
# for most of the logic. Unfortunately, cudf -> cupy conversion
# is not supported for data containing null values. Therefore,
# we must implement our own version of Var for now. This logic
# is mostly copied from dask-cudf.


class VarCudf(Reduction):
    # Uses the parallel version of Welford's online algorithm (Chan '79)
    # (http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf)
    _parameters = [
        "frame",
        "skipna",
        "ddof",
        "numeric_only",
        "split_every",
    ]
    _defaults = {
        "skipna": True,
        "ddof": 1,
        "numeric_only": False,
        "split_every": False,
    }

    @functools.cached_property
    def _meta(self):
        return make_meta(
            meta_nonempty(self.frame._meta).var(
                skipna=self.skipna, numeric_only=self.numeric_only
            )
        )

    @property
    def chunk_kwargs(self):
        return dict(skipna=self.skipna, numeric_only=self.numeric_only)

    @property
    def combine_kwargs(self):
        return {}

    @property
    def aggregate_kwargs(self):
        return dict(ddof=self.ddof)

    @classmethod
    def reduction_chunk(cls, x, skipna=True, numeric_only=False):
        kwargs = {"numeric_only": numeric_only} if is_dataframe_like(x) else {}
        if skipna or numeric_only:
            n = x.count(**kwargs)
            kwargs["skipna"] = skipna
            avg = x.mean(**kwargs)
        else:
            # Not skipping nulls, so might as well
            # avoid the full `count` operation
            n = len(x)
            kwargs["skipna"] = skipna
            avg = x.sum(**kwargs) / n
        if numeric_only:
            # Workaround for cudf bug
            # (see: https://github.com/rapidsai/cudf/issues/13731)
            x = x[n.index]
        m2 = ((x - avg) ** 2).sum(**kwargs)
        return n, avg, m2

    @classmethod
    def reduction_combine(cls, parts):
        n, avg, m2 = parts[0]
        for i in range(1, len(parts)):
            n_a, avg_a, m2_a = n, avg, m2
            n_b, avg_b, m2_b = parts[i]
            n = n_a + n_b
            avg = (n_a * avg_a + n_b * avg_b) / n
            delta = avg_b - avg_a
            m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
        return n, avg, m2

    @classmethod
    def reduction_aggregate(cls, vals, ddof=1):
        vals = cls.reduction_combine(vals)
        n, _, m2 = vals
        return m2 / (n - ddof)


def _patched_var(
    self,
    axis=0,
    skipna=True,
    ddof=1,
    numeric_only=False,
    split_every=False,
):
    if axis == 0:
        if hasattr(self._meta, "to_pandas"):
            return VarCudf(self, skipna, ddof, numeric_only, split_every)
        else:
            return Var(self, skipna, ddof, numeric_only, split_every)
    elif axis == 1:
        return VarColumns(self, skipna, ddof, numeric_only)
    else:
        raise ValueError(f"axis={axis} not supported. Please specify 0 or 1")


# Temporary work-around for missing cudf + categorical support
# See: https://github.com/rapidsai/cudf/issues/11795
# TODO: Fix RepartitionQuantiles and remove this in cudf>24.06

_original_get_divisions = _shuffle_module._get_divisions


def _patched_get_divisions(frame, other, *args, **kwargs):
    # NOTE: The following two lines contain the "patch"
    # (we simply convert the partitioning column to pandas)
    if is_categorical_dtype(other._meta.dtype) and hasattr(
        other.frame._meta, "to_pandas"
    ):
        other = new_collection(other).to_backend("pandas")._expr

    # Call "original" function
    return _original_get_divisions(frame, other, *args, **kwargs)


_PATCHED = False


def _patch_dask_expr():
    global _PATCHED

    if not _PATCHED:
        CumulativeBlockwise._args = PatchCumulativeBlockwise._args
        CumulativeBlockwise._kwargs = PatchCumulativeBlockwise._kwargs
        Expr.var = _patched_var
        _shuffle_module._get_divisions = _patched_get_divisions
        _PATCHED = True
```
Review comment: No longer relevant.
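As a sanity check on the Chan '79 merge formula used in `VarCudf.reduction_combine`, here is a small self-contained sketch (plain Python, not part of the PR): merging per-partition `(n, avg, m2)` triples gives the same variance as a single computation over the concatenated data.

```python
# Sketch: verify the pairwise merge used in VarCudf.reduction_combine.
def stats(xs):
    # Per-partition triple: count, mean, and sum of squared deviations (m2)
    n = len(xs)
    avg = sum(xs) / n
    m2 = sum((x - avg) ** 2 for x in xs)
    return n, avg, m2

data_a, data_b = [1.0, 2.0, 3.0], [10.0, 20.0]
n_a, avg_a, m2_a = stats(data_a)
n_b, avg_b, m2_b = stats(data_b)

# Chan '79 merge -- identical to the loop body in reduction_combine
n = n_a + n_b
avg = (n_a * avg_a + n_b * avg_b) / n
delta = avg_b - avg_a
m2 = m2_a + m2_b + delta**2 * n_a * n_b / n

# Same result as computing the variance over all of the data at once
ddof = 1
n_all, _, m2_all = stats(data_a + data_b)
assert abs(m2 / (n - ddof) - m2_all / (n_all - ddof)) < 1e-12
```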