Skip to content

Commit

Permalink
[FEA] Report all unsupported operations for a query in cudf.polars (#…
Browse files Browse the repository at this point in the history
…16960)

Closes #16690. The purpose of this PR is to list all of the unique operations that are unsupported by `cudf.polars` when running a query. 

1. Question: How to traverse the tree to report the error nodes? Should this be done upstream in Polars?
2. Instead of traversing the query afterwards, we should probably catch each unsupported feature as we translate the IR.

Authors:
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #16960
  • Loading branch information
Matt711 authored Nov 12, 2024
1 parent 202c231 commit 043bcbd
Show file tree
Hide file tree
Showing 13 changed files with 297 additions and 211 deletions.
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from cudf_polars._version import __git_commit__, __version__
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir
from cudf_polars.dsl.translate import Translator

# Check we have a supported polars version
from cudf_polars.utils.versions import _ensure_polars_version
Expand All @@ -22,7 +22,7 @@

__all__: list[str] = [
"execute_with_cudf",
"translate_ir",
"Translator",
"__git_commit__",
"__version__",
]
32 changes: 24 additions & 8 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import rmm
from rmm._cuda import gpu

from cudf_polars.dsl.translate import translate_ir
from cudf_polars.dsl.translate import Translator

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -180,14 +180,30 @@ def execute_with_cudf(
)
try:
with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
nt.set_udf(
partial(
_callback,
translate_ir(nt),
device=device,
memory_resource=memory_resource,
translator = Translator(nt)
ir = translator.translate_ir()
ir_translation_errors = translator.errors
if len(ir_translation_errors):
# TODO: Display these errors in user-friendly way.
# tracked in https://github.com/rapidsai/cudf/issues/17051
unique_errors = sorted(set(ir_translation_errors), key=str)
error_message = "Query contained unsupported operations"
verbose_error_message = (
f"{error_message}\nThe errors were:\n{unique_errors}"
)
unsupported_ops_exception = NotImplementedError(
error_message, unique_errors
)
if bool(int(os.environ.get("POLARS_VERBOSE", 0))):
warnings.warn(verbose_error_message, UserWarning, stacklevel=2)
if raise_on_fail:
raise unsupported_ops_exception
else:
nt.set_udf(
partial(
_callback, ir, device=device, memory_resource=memory_resource
)
)
)
except exception as e:
if bool(int(os.environ.get("POLARS_VERBOSE", 0))):
warnings.warn(
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AggInfo,
Col,
ColRef,
ErrorExpr,
Expr,
NamedExpr,
)
Expand All @@ -36,6 +37,7 @@

__all__ = [
"Expr",
"ErrorExpr",
"NamedExpr",
"Literal",
"LiteralColumn",
Expand Down
11 changes: 11 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ def collect_agg(self, *, depth: int) -> AggInfo:
) # pragma: no cover; check_agg trips first


class ErrorExpr(Expr):
__slots__ = ("error",)
_non_child = ("dtype", "error")
error: str

def __init__(self, dtype: plc.DataType, error: str) -> None:
self.dtype = dtype
self.error = error
self.children = ()


class NamedExpr:
# NamedExpr does not inherit from Expr since it does not appear
# when evaluating expressions themselves, only when constructing
Expand Down
19 changes: 18 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

__all__ = [
"IR",
"ErrorNode",
"PythonScan",
"Scan",
"Cache",
Expand Down Expand Up @@ -212,6 +213,22 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
)


class ErrorNode(IR):
"""Represents an error translating the IR."""

__slots__ = ("error",)
_non_child = (
"schema",
"error",
)
error: str
"""The error."""

def __init__(self, schema: Schema, error: str):
self.schema = schema
self.error = error


class PythonScan(IR):
"""Representation of input from a python function."""

Expand Down Expand Up @@ -1532,7 +1549,7 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR):
raise NotImplementedError(
"Unpivot cannot cast all input columns to "
f"{self.schema[value_name].id()}"
)
) # pragma: no cover
self.options = (
tuple(indices),
tuple(pivotees),
Expand Down
Loading

0 comments on commit 043bcbd

Please sign in to comment.