
Add conversion from cudf-polars expressions to libcudf ast for parquet filters (#17141)

Previously, we always applied parquet filters by post-filtering. This negates much of the potential gain from having filters available at read time, namely discarding row groups that cannot match. To fix this, we use the new visitor system of #17016 to implement conversion from cudf-polars expressions to pylibcudf expressions.

We must distinguish two types of expressions: those that we can evaluate via `cudf::compute_column`, and the more restricted set that the parquet reader understands. This is handled by a piece of state that tracks the intended usage. The former, more general, style will be useful when we implement inequality joins.

While here, we extend pylibcudf expressions to handle all supported literal types and expose `compute_column`, so that we can test the correctness of the broader (non-parquet) implementation.
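For illustration, here is a minimal sketch of the read-time path this enables (the file name and predicate are placeholders, and availability of the polars GPU engine is assumed):

```python
import polars as pl

# The predicate below is converted to a libcudf AST and handed to the
# parquet reader, which can consult row-group statistics and skip row
# groups in which no row can match, instead of reading everything and
# post-filtering.
q = pl.scan_parquet("data.parquet").filter(pl.col("a") > 10)
result = q.collect(engine="gpu")
```

Note that the filter is only pushed into the read when no row index is requested (see the `ir.py` change below).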

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #17141
wence- authored Oct 30, 2024
1 parent 0b9277b commit 7157de7
Showing 15 changed files with 552 additions and 37 deletions.
24 changes: 5 additions & 19 deletions python/cudf/cudf/_lib/transform.pyx
@@ -7,20 +7,11 @@ from cudf.core._internals.expressions import parse_expression
from cudf.core.buffer import acquire_spill_lock, as_buffer
from cudf.utils import cudautils

from cython.operator cimport dereference
from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

cimport pylibcudf.libcudf.transform as libcudf_transform
from pylibcudf cimport transform as plc_transform
from pylibcudf.expressions cimport Expression
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.expressions cimport expression
from pylibcudf.libcudf.table.table_view cimport table_view
from pylibcudf.libcudf.types cimport size_type

from cudf._lib.column cimport Column
from cudf._lib.utils cimport table_view_from_columns

import pylibcudf as plc

@@ -121,13 +112,8 @@ def compute_column(list columns, tuple column_names, expr: str):

# At the end, all the stack contains is the expression to evaluate.
cdef Expression cudf_expr = visitor.expression
cdef table_view tbl = table_view_from_columns(columns)
cdef unique_ptr[column] col
with nogil:
col = move(
libcudf_transform.compute_column(
tbl,
<expression &> dereference(cudf_expr.c_obj.get())
)
)
return Column.from_unique_ptr(move(col))
result = plc_transform.compute_column(
plc.Table([col.to_pylibcudf(mode="read") for col in columns]),
cudf_expr,
)
return Column.from_pylibcudf(result)
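Since `compute_column` is now exposed in pylibcudf, a standalone evaluation looks roughly like the following sketch (the arrow round-trip for constructing inputs is an assumption; the expression and transform APIs match those used in the diff above):

```python
import pyarrow as pa
import pylibcudf as plc
from pylibcudf import expressions as plc_expr
from pylibcudf import transform as plc_transform

# Build a single-column table from arrow data.
tbl = plc.interop.from_arrow(pa.table({"a": [1, 2, 3]}))

# AST for `a + 1`: the column is referenced by index, and the literal
# is built from a scalar, as in the diff above.
expr = plc_expr.Operation(
    plc_expr.ASTOperator.ADD,
    plc_expr.ColumnReference(0),
    plc_expr.Literal(plc.interop.from_arrow(pa.scalar(1, type=pa.int64()))),
)

result = plc_transform.compute_column(tbl, expr)  # a pylibcudf Column
```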
9 changes: 9 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -28,6 +28,7 @@
import cudf_polars.dsl.expr as expr
from cudf_polars.containers import Column, DataFrame
from cudf_polars.dsl.nodebase import Node
from cudf_polars.dsl.to_ast import to_parquet_filter
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
@@ -418,9 +419,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
colnames[0],
)
elif self.typ == "parquet":
filters = None
if self.predicate is not None and self.row_index is None:
# Can't apply filters during read if we have a row index.
filters = to_parquet_filter(self.predicate.value)
tbl_w_meta = plc.io.parquet.read_parquet(
plc.io.SourceInfo(self.paths),
columns=with_columns,
filters=filters,
nrows=n_rows,
skip_rows=self.skip_rows,
)
@@ -429,6 +435,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
# TODO: consider nested column names?
tbl_w_meta.column_names(include_children=False),
)
if filters is not None:
# Mask must have been applied.
return df
elif self.typ == "ndjson":
json_schema: list[tuple[str, str, list]] = [
(name, typ, []) for name, typ in self.schema.items()
265 changes: 265 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/to_ast.py
@@ -0,0 +1,265 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Conversion of expression nodes to libcudf AST nodes."""

from __future__ import annotations

from functools import partial, reduce, singledispatch
from typing import TYPE_CHECKING, TypeAlias

import pylibcudf as plc
from pylibcudf import expressions as plc_expr

from polars.polars import _expr_nodes as pl_expr

from cudf_polars.dsl import expr
from cudf_polars.dsl.traversal import CachingVisitor
from cudf_polars.typing import GenericTransformer

if TYPE_CHECKING:
from collections.abc import Mapping

# Can't merge these op-mapping dictionaries because scoped enum values
# are exposed by cython with equality/hash based on their underlying
# representation type. So in a dict they are just treated as integers.
BINOP_TO_ASTOP = {
plc.binaryop.BinaryOperator.EQUAL: plc_expr.ASTOperator.EQUAL,
plc.binaryop.BinaryOperator.NULL_EQUALS: plc_expr.ASTOperator.NULL_EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL: plc_expr.ASTOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS: plc_expr.ASTOperator.LESS,
plc.binaryop.BinaryOperator.LESS_EQUAL: plc_expr.ASTOperator.LESS_EQUAL,
plc.binaryop.BinaryOperator.GREATER: plc_expr.ASTOperator.GREATER,
plc.binaryop.BinaryOperator.GREATER_EQUAL: plc_expr.ASTOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.ADD: plc_expr.ASTOperator.ADD,
plc.binaryop.BinaryOperator.SUB: plc_expr.ASTOperator.SUB,
plc.binaryop.BinaryOperator.MUL: plc_expr.ASTOperator.MUL,
plc.binaryop.BinaryOperator.DIV: plc_expr.ASTOperator.DIV,
plc.binaryop.BinaryOperator.TRUE_DIV: plc_expr.ASTOperator.TRUE_DIV,
plc.binaryop.BinaryOperator.FLOOR_DIV: plc_expr.ASTOperator.FLOOR_DIV,
plc.binaryop.BinaryOperator.PYMOD: plc_expr.ASTOperator.PYMOD,
plc.binaryop.BinaryOperator.BITWISE_AND: plc_expr.ASTOperator.BITWISE_AND,
plc.binaryop.BinaryOperator.BITWISE_OR: plc_expr.ASTOperator.BITWISE_OR,
plc.binaryop.BinaryOperator.BITWISE_XOR: plc_expr.ASTOperator.BITWISE_XOR,
plc.binaryop.BinaryOperator.LOGICAL_AND: plc_expr.ASTOperator.LOGICAL_AND,
plc.binaryop.BinaryOperator.LOGICAL_OR: plc_expr.ASTOperator.LOGICAL_OR,
plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: plc_expr.ASTOperator.NULL_LOGICAL_AND,
plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: plc_expr.ASTOperator.NULL_LOGICAL_OR,
}

UOP_TO_ASTOP = {
plc.unary.UnaryOperator.SIN: plc_expr.ASTOperator.SIN,
plc.unary.UnaryOperator.COS: plc_expr.ASTOperator.COS,
plc.unary.UnaryOperator.TAN: plc_expr.ASTOperator.TAN,
plc.unary.UnaryOperator.ARCSIN: plc_expr.ASTOperator.ARCSIN,
plc.unary.UnaryOperator.ARCCOS: plc_expr.ASTOperator.ARCCOS,
plc.unary.UnaryOperator.ARCTAN: plc_expr.ASTOperator.ARCTAN,
plc.unary.UnaryOperator.SINH: plc_expr.ASTOperator.SINH,
plc.unary.UnaryOperator.COSH: plc_expr.ASTOperator.COSH,
plc.unary.UnaryOperator.TANH: plc_expr.ASTOperator.TANH,
plc.unary.UnaryOperator.ARCSINH: plc_expr.ASTOperator.ARCSINH,
plc.unary.UnaryOperator.ARCCOSH: plc_expr.ASTOperator.ARCCOSH,
plc.unary.UnaryOperator.ARCTANH: plc_expr.ASTOperator.ARCTANH,
plc.unary.UnaryOperator.EXP: plc_expr.ASTOperator.EXP,
plc.unary.UnaryOperator.LOG: plc_expr.ASTOperator.LOG,
plc.unary.UnaryOperator.SQRT: plc_expr.ASTOperator.SQRT,
plc.unary.UnaryOperator.CBRT: plc_expr.ASTOperator.CBRT,
plc.unary.UnaryOperator.CEIL: plc_expr.ASTOperator.CEIL,
plc.unary.UnaryOperator.FLOOR: plc_expr.ASTOperator.FLOOR,
plc.unary.UnaryOperator.ABS: plc_expr.ASTOperator.ABS,
plc.unary.UnaryOperator.RINT: plc_expr.ASTOperator.RINT,
plc.unary.UnaryOperator.BIT_INVERT: plc_expr.ASTOperator.BIT_INVERT,
plc.unary.UnaryOperator.NOT: plc_expr.ASTOperator.NOT,
}

SUPPORTED_STATISTICS_BINOPS = {
plc.binaryop.BinaryOperator.EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS,
plc.binaryop.BinaryOperator.LESS_EQUAL,
plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.GREATER_EQUAL,
}

REVERSED_COMPARISON = {
plc.binaryop.BinaryOperator.EQUAL: plc.binaryop.BinaryOperator.EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL: plc.binaryop.BinaryOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS: plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.LESS_EQUAL: plc.binaryop.BinaryOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.GREATER: plc.binaryop.BinaryOperator.LESS,
plc.binaryop.BinaryOperator.GREATER_EQUAL: plc.binaryop.BinaryOperator.LESS_EQUAL,
}


Transformer: TypeAlias = GenericTransformer[expr.Expr, plc_expr.Expression]


@singledispatch
def _to_ast(node: expr.Expr, self: Transformer) -> plc_expr.Expression:
"""
Translate an expression to a pylibcudf Expression.

Parameters
----------
node
Expression to translate.
self
Recursive transformer. The state dictionary should contain a
`for_parquet` key indicating if this transformation should
provide an expression suitable for use in parquet filters.
If `for_parquet` is `False`, the dictionary should contain a
`name_to_index` mapping that maps column names to their
integer index in the table that will be used for evaluation of
the expression.

Returns
-------
pylibcudf Expression.

Raises
------
NotImplementedError or KeyError if the expression cannot be translated.
"""
raise NotImplementedError(f"Unhandled expression type {type(node)}")


@_to_ast.register
def _(node: expr.Col, self: Transformer) -> plc_expr.Expression:
if self.state["for_parquet"]:
return plc_expr.ColumnNameReference(node.name)
return plc_expr.ColumnReference(self.state["name_to_index"][node.name])


@_to_ast.register
def _(node: expr.Literal, self: Transformer) -> plc_expr.Expression:
return plc_expr.Literal(plc.interop.from_arrow(node.value))


@_to_ast.register
def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression:
if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS:
return plc_expr.Operation(
plc_expr.ASTOperator.NOT,
self(
# Reconstruct and apply, rather than directly
# constructing the right expression so we get the
# handling of parquet special cases for free.
expr.BinOp(
node.dtype, plc.binaryop.BinaryOperator.NULL_EQUALS, *node.children
)
),
)
if self.state["for_parquet"]:
op1_col, op2_col = (isinstance(op, expr.Col) for op in node.children)
if op1_col ^ op2_col:
op = node.op
if op not in SUPPORTED_STATISTICS_BINOPS:
raise NotImplementedError(
f"Parquet filter binop with column doesn't support {node.op!r}"
)
op1, op2 = node.children
if op2_col:
(op1, op2) = (op2, op1)
op = REVERSED_COMPARISON[op]
if not isinstance(op2, expr.Literal):
raise NotImplementedError(
"Parquet filter binops must have form 'col binop literal'"
)
return plc_expr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2))
elif op1_col and op2_col:
raise NotImplementedError(
"Parquet filter binops must have one column reference not two"
)
return plc_expr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children))


@_to_ast.register
def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression:
if node.name == pl_expr.BooleanFunction.IsIn:
needles, haystack = node.children
if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16:
# 16 is an arbitrary limit
needle_ref = self(needles)
values = [
plc_expr.Literal(plc.interop.from_arrow(v)) for v in haystack.value
]
return reduce(
partial(plc_expr.Operation, plc_expr.ASTOperator.LOGICAL_OR),
(
plc_expr.Operation(plc_expr.ASTOperator.EQUAL, needle_ref, value)
for value in values
),
)
if self.state["for_parquet"] and isinstance(node.children[0], expr.Col):
raise NotImplementedError(
f"Parquet filters don't support {node.name} on columns"
)
if node.name == pl_expr.BooleanFunction.IsNull:
return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0]))
elif node.name == pl_expr.BooleanFunction.IsNotNull:
return plc_expr.Operation(
plc_expr.ASTOperator.NOT,
plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])),
)
elif node.name == pl_expr.BooleanFunction.Not:
return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0]))
raise NotImplementedError(f"AST conversion does not support {node.name}")


@_to_ast.register
def _(node: expr.UnaryFunction, self: Transformer) -> plc_expr.Expression:
if isinstance(node.children[0], expr.Col) and self.state["for_parquet"]:
raise NotImplementedError(
"Parquet filters don't support {node.name} on columns"
)
return plc_expr.Operation(
UOP_TO_ASTOP[node._OP_MAPPING[node.name]], self(node.children[0])
)


def to_parquet_filter(node: expr.Expr) -> plc_expr.Expression | None:
"""
Convert an expression to libcudf AST nodes suitable for parquet filtering.

Parameters
----------
node
Expression to convert.

Returns
-------
pylibcudf Expression if conversion is possible, otherwise None.
"""
mapper = CachingVisitor(_to_ast, state={"for_parquet": True})
try:
return mapper(node)
except (KeyError, NotImplementedError):
return None


def to_ast(
node: expr.Expr, *, name_to_index: Mapping[str, int]
) -> plc_expr.Expression | None:
"""
Convert an expression to libcudf AST nodes suitable for compute_column.

Parameters
----------
node
Expression to convert.
name_to_index
Mapping from column names to their index in the table that
will be used for expression evaluation.

Returns
-------
pylibcudf Expression if conversion is possible, otherwise None.
"""
mapper = CachingVisitor(
_to_ast, state={"for_parquet": False, "name_to_index": name_to_index}
)
try:
return mapper(node)
except (KeyError, NotImplementedError):
return None
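To make the difference between the two entry points concrete, here is a hedged usage sketch (the `Col`, `Literal`, and `BinOp` constructor signatures are assumptions about cudf_polars internals at this commit):

```python
import pyarrow as pa
import pylibcudf as plc

from cudf_polars.dsl import expr
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter

int64 = plc.DataType(plc.TypeId.INT64)
bool_ = plc.DataType(plc.TypeId.BOOL8)

# The predicate `a > 10` as a cudf-polars expression.
pred = expr.BinOp(
    bool_,
    plc.binaryop.BinaryOperator.GREATER,
    expr.Col(int64, "a"),
    expr.Literal(int64, pa.scalar(10, type=pa.int64())),
)

# Parquet mode: `a` becomes a ColumnNameReference that the reader can
# match against file metadata; returns None if unsupported.
parquet_filter = to_parquet_filter(pred)

# compute_column mode: `a` becomes a positional ColumnReference into
# the table the expression will be evaluated against.
ast = to_ast(pred, name_to_index={"a": 0})
```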
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/testing/asserts.py
@@ -151,7 +151,7 @@ def assert_collect_raises(
collect_kwargs: dict[OptimizationArgs, bool] | None = None,
polars_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
cudf_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
):
) -> None:
"""
Assert that collecting the result of a query raises the expected exceptions.
6 changes: 3 additions & 3 deletions python/cudf_polars/cudf_polars/testing/plugin.py
@@ -16,7 +16,7 @@
from collections.abc import Mapping


def pytest_addoption(parser: pytest.Parser):
def pytest_addoption(parser: pytest.Parser) -> None:
"""Add plugin-specific options."""
group = parser.getgroup(
"cudf-polars", "Plugin to set GPU as default engine for polars tests"
@@ -28,7 +28,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
)


def pytest_configure(config: pytest.Config):
def pytest_configure(config: pytest.Config) -> None:
"""Enable use of this module as a pytest plugin to enable GPU collection."""
no_fallback = config.getoption("--cudf-polars-no-fallback")
collect = polars.LazyFrame.collect
@@ -172,7 +172,7 @@ def pytest_configure(config: pytest.Config):

def pytest_collection_modifyitems(
session: pytest.Session, config: pytest.Config, items: list[pytest.Item]
):
) -> None:
"""Mark known failing tests."""
if config.getoption("--cudf-polars-no-fallback"):
# Don't xfail tests if running without fallback
