Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Report all unsupported operations for a query in cudf.polars #16960

Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from cudf_polars._version import __git_commit__, __version__
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir
from cudf_polars.dsl.translate import Translator

# Check we have a supported polars version
from cudf_polars.utils.versions import _ensure_polars_version
Expand All @@ -22,7 +22,7 @@

__all__: list[str] = [
"execute_with_cudf",
"translate_ir",
"Translator",
"__git_commit__",
"__version__",
]
32 changes: 24 additions & 8 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import rmm
from rmm._cuda import gpu

from cudf_polars.dsl.translate import translate_ir
from cudf_polars.dsl.translate import Translator

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -180,14 +180,30 @@ def execute_with_cudf(
)
try:
with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
nt.set_udf(
partial(
_callback,
translate_ir(nt),
device=device,
memory_resource=memory_resource,
translator = Translator(nt)
ir = translator.translate_ir()
ir_translation_errors = translator.errors
if len(ir_translation_errors):
# TODO: Display these errors in user-friendly way.
# tracked in https://github.com/rapidsai/cudf/issues/17051
unique_errors = sorted(set(ir_translation_errors), key=str)
error_message = "Query contained unsupported operations"
verbose_error_message = (
f"{error_message}\nThe errors were:\n{unique_errors}"
)
unsupported_ops_exception = NotImplementedError(
error_message, unique_errors
)
if bool(int(os.environ.get("POLARS_VERBOSE", 0))):
warnings.warn(verbose_error_message, UserWarning, stacklevel=2)
if raise_on_fail:
raise unsupported_ops_exception
else:
nt.set_udf(
partial(
_callback, ir, device=device, memory_resource=memory_resource
)
)
)
except exception as e:
if bool(int(os.environ.get("POLARS_VERBOSE", 0))):
warnings.warn(
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cudf_polars.dsl.expressions.base import (
AggInfo,
Col,
ErrorExpr,
Expr,
NamedExpr,
)
Expand All @@ -35,6 +36,7 @@

__all__ = [
"Expr",
"ErrorExpr",
"NamedExpr",
"Literal",
"LiteralColumn",
Expand Down
11 changes: 11 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ def collect_agg(self, *, depth: int) -> AggInfo:
) # pragma: no cover; check_agg trips first


class ErrorExpr(Expr):
__slots__ = ("error",)
_non_child = ("dtype", "error")
error: str

def __init__(self, dtype: plc.DataType, error: str) -> None:
self.dtype = dtype
self.error = error
self.children = ()


class NamedExpr:
# NamedExpr does not inherit from Expr since it does not appear
# when evaluating expressions themselves, only when constructing
Expand Down
17 changes: 17 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

__all__ = [
"IR",
"ErrorNode",
"PythonScan",
"Scan",
"Cache",
Expand Down Expand Up @@ -210,6 +211,22 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
)


class ErrorNode(IR):
"""Represents an error translating the IR."""

__slots__ = ("error",)
_non_child = (
"schema",
"error",
)
error: str
"""The error."""

def __init__(self, schema: Schema, error: str):
self.schema = schema
self.error = error


class PythonScan(IR):
"""Representation of input from a python function."""

Expand Down
Loading
Loading