-
Notifications
You must be signed in to change notification settings - Fork 903
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add conversion from cudf-polars expressions to libcudf ast for parque…
…t filters (#17141) Previously, we always applied parquet filters by post-filtering. This negates much of the potential gain from having filters available at read time, namely discarding row groups. To fix this, implement, with the new visitor system of #17016, conversion to pylibcudf expressions. We must distinguish two types of expressions, ones that we can evaluate via `cudf::compute_column`, and the more restricted set of expressions that the parquet reader understands, this is handled by having a state that tracks the usage. The former style will be useful when we implement inequality joins. While here, extend the support in pylibcudf expressions to handle all supported literal types and expose `compute_column` so we can test the correctness of the broader (non-parquet) implementation. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: #17141
- Loading branch information
Showing
15 changed files
with
552 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,265 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Conversion of expression nodes to libcudf AST nodes.""" | ||
|
||
from __future__ import annotations | ||
|
||
from functools import partial, reduce, singledispatch | ||
from typing import TYPE_CHECKING, TypeAlias | ||
|
||
import pylibcudf as plc | ||
from pylibcudf import expressions as plc_expr | ||
|
||
from polars.polars import _expr_nodes as pl_expr | ||
|
||
from cudf_polars.dsl import expr | ||
from cudf_polars.dsl.traversal import CachingVisitor | ||
from cudf_polars.typing import GenericTransformer | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Mapping | ||
|
||
# Can't merge these op-mapping dictionaries because scoped enum values | ||
# are exposed by cython with equality/hash based one their underlying | ||
# representation type. So in a dict they are just treated as integers. | ||
BINOP_TO_ASTOP = { | ||
plc.binaryop.BinaryOperator.EQUAL: plc_expr.ASTOperator.EQUAL, | ||
plc.binaryop.BinaryOperator.NULL_EQUALS: plc_expr.ASTOperator.NULL_EQUAL, | ||
plc.binaryop.BinaryOperator.NOT_EQUAL: plc_expr.ASTOperator.NOT_EQUAL, | ||
plc.binaryop.BinaryOperator.LESS: plc_expr.ASTOperator.LESS, | ||
plc.binaryop.BinaryOperator.LESS_EQUAL: plc_expr.ASTOperator.LESS_EQUAL, | ||
plc.binaryop.BinaryOperator.GREATER: plc_expr.ASTOperator.GREATER, | ||
plc.binaryop.BinaryOperator.GREATER_EQUAL: plc_expr.ASTOperator.GREATER_EQUAL, | ||
plc.binaryop.BinaryOperator.ADD: plc_expr.ASTOperator.ADD, | ||
plc.binaryop.BinaryOperator.SUB: plc_expr.ASTOperator.SUB, | ||
plc.binaryop.BinaryOperator.MUL: plc_expr.ASTOperator.MUL, | ||
plc.binaryop.BinaryOperator.DIV: plc_expr.ASTOperator.DIV, | ||
plc.binaryop.BinaryOperator.TRUE_DIV: plc_expr.ASTOperator.TRUE_DIV, | ||
plc.binaryop.BinaryOperator.FLOOR_DIV: plc_expr.ASTOperator.FLOOR_DIV, | ||
plc.binaryop.BinaryOperator.PYMOD: plc_expr.ASTOperator.PYMOD, | ||
plc.binaryop.BinaryOperator.BITWISE_AND: plc_expr.ASTOperator.BITWISE_AND, | ||
plc.binaryop.BinaryOperator.BITWISE_OR: plc_expr.ASTOperator.BITWISE_OR, | ||
plc.binaryop.BinaryOperator.BITWISE_XOR: plc_expr.ASTOperator.BITWISE_XOR, | ||
plc.binaryop.BinaryOperator.LOGICAL_AND: plc_expr.ASTOperator.LOGICAL_AND, | ||
plc.binaryop.BinaryOperator.LOGICAL_OR: plc_expr.ASTOperator.LOGICAL_OR, | ||
plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: plc_expr.ASTOperator.NULL_LOGICAL_AND, | ||
plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: plc_expr.ASTOperator.NULL_LOGICAL_OR, | ||
} | ||
|
||
UOP_TO_ASTOP = { | ||
plc.unary.UnaryOperator.SIN: plc_expr.ASTOperator.SIN, | ||
plc.unary.UnaryOperator.COS: plc_expr.ASTOperator.COS, | ||
plc.unary.UnaryOperator.TAN: plc_expr.ASTOperator.TAN, | ||
plc.unary.UnaryOperator.ARCSIN: plc_expr.ASTOperator.ARCSIN, | ||
plc.unary.UnaryOperator.ARCCOS: plc_expr.ASTOperator.ARCCOS, | ||
plc.unary.UnaryOperator.ARCTAN: plc_expr.ASTOperator.ARCTAN, | ||
plc.unary.UnaryOperator.SINH: plc_expr.ASTOperator.SINH, | ||
plc.unary.UnaryOperator.COSH: plc_expr.ASTOperator.COSH, | ||
plc.unary.UnaryOperator.TANH: plc_expr.ASTOperator.TANH, | ||
plc.unary.UnaryOperator.ARCSINH: plc_expr.ASTOperator.ARCSINH, | ||
plc.unary.UnaryOperator.ARCCOSH: plc_expr.ASTOperator.ARCCOSH, | ||
plc.unary.UnaryOperator.ARCTANH: plc_expr.ASTOperator.ARCTANH, | ||
plc.unary.UnaryOperator.EXP: plc_expr.ASTOperator.EXP, | ||
plc.unary.UnaryOperator.LOG: plc_expr.ASTOperator.LOG, | ||
plc.unary.UnaryOperator.SQRT: plc_expr.ASTOperator.SQRT, | ||
plc.unary.UnaryOperator.CBRT: plc_expr.ASTOperator.CBRT, | ||
plc.unary.UnaryOperator.CEIL: plc_expr.ASTOperator.CEIL, | ||
plc.unary.UnaryOperator.FLOOR: plc_expr.ASTOperator.FLOOR, | ||
plc.unary.UnaryOperator.ABS: plc_expr.ASTOperator.ABS, | ||
plc.unary.UnaryOperator.RINT: plc_expr.ASTOperator.RINT, | ||
plc.unary.UnaryOperator.BIT_INVERT: plc_expr.ASTOperator.BIT_INVERT, | ||
plc.unary.UnaryOperator.NOT: plc_expr.ASTOperator.NOT, | ||
} | ||
|
||
SUPPORTED_STATISTICS_BINOPS = { | ||
plc.binaryop.BinaryOperator.EQUAL, | ||
plc.binaryop.BinaryOperator.NOT_EQUAL, | ||
plc.binaryop.BinaryOperator.LESS, | ||
plc.binaryop.BinaryOperator.LESS_EQUAL, | ||
plc.binaryop.BinaryOperator.GREATER, | ||
plc.binaryop.BinaryOperator.GREATER_EQUAL, | ||
} | ||
|
||
REVERSED_COMPARISON = { | ||
plc.binaryop.BinaryOperator.EQUAL: plc.binaryop.BinaryOperator.EQUAL, | ||
plc.binaryop.BinaryOperator.NOT_EQUAL: plc.binaryop.BinaryOperator.NOT_EQUAL, | ||
plc.binaryop.BinaryOperator.LESS: plc.binaryop.BinaryOperator.GREATER, | ||
plc.binaryop.BinaryOperator.LESS_EQUAL: plc.binaryop.BinaryOperator.GREATER_EQUAL, | ||
plc.binaryop.BinaryOperator.GREATER: plc.binaryop.BinaryOperator.LESS, | ||
plc.binaryop.BinaryOperator.GREATER_EQUAL: plc.binaryop.BinaryOperator.LESS_EQUAL, | ||
} | ||
|
||
|
||
Transformer: TypeAlias = GenericTransformer[expr.Expr, plc_expr.Expression] | ||
|
||
|
||
@singledispatch | ||
def _to_ast(node: expr.Expr, self: Transformer) -> plc_expr.Expression: | ||
""" | ||
Translate an expression to a pylibcudf Expression. | ||
Parameters | ||
---------- | ||
node | ||
Expression to translate. | ||
self | ||
Recursive transformer. The state dictionary should contain a | ||
`for_parquet` key indicating if this transformation should | ||
provide an expression suitable for use in parquet filters. | ||
If `for_parquet` is `False`, the dictionary should contain a | ||
`name_to_index` mapping that maps column names to their | ||
integer index in the table that will be used for evaluation of | ||
the expression. | ||
Returns | ||
------- | ||
pylibcudf Expression. | ||
Raises | ||
------ | ||
NotImplementedError or KeyError if the expression cannot be translated. | ||
""" | ||
raise NotImplementedError(f"Unhandled expression type {type(node)}") | ||
|
||
|
||
@_to_ast.register | ||
def _(node: expr.Col, self: Transformer) -> plc_expr.Expression: | ||
if self.state["for_parquet"]: | ||
return plc_expr.ColumnNameReference(node.name) | ||
return plc_expr.ColumnReference(self.state["name_to_index"][node.name]) | ||
|
||
|
||
@_to_ast.register | ||
def _(node: expr.Literal, self: Transformer) -> plc_expr.Expression: | ||
return plc_expr.Literal(plc.interop.from_arrow(node.value)) | ||
|
||
|
||
@_to_ast.register | ||
def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression: | ||
if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS: | ||
return plc_expr.Operation( | ||
plc_expr.ASTOperator.NOT, | ||
self( | ||
# Reconstruct and apply, rather than directly | ||
# constructing the right expression so we get the | ||
# handling of parquet special cases for free. | ||
expr.BinOp( | ||
node.dtype, plc.binaryop.BinaryOperator.NULL_EQUALS, *node.children | ||
) | ||
), | ||
) | ||
if self.state["for_parquet"]: | ||
op1_col, op2_col = (isinstance(op, expr.Col) for op in node.children) | ||
if op1_col ^ op2_col: | ||
op = node.op | ||
if op not in SUPPORTED_STATISTICS_BINOPS: | ||
raise NotImplementedError( | ||
f"Parquet filter binop with column doesn't support {node.op!r}" | ||
) | ||
op1, op2 = node.children | ||
if op2_col: | ||
(op1, op2) = (op2, op1) | ||
op = REVERSED_COMPARISON[op] | ||
if not isinstance(op2, expr.Literal): | ||
raise NotImplementedError( | ||
"Parquet filter binops must have form 'col binop literal'" | ||
) | ||
return plc_expr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2)) | ||
elif op1_col and op2_col: | ||
raise NotImplementedError( | ||
"Parquet filter binops must have one column reference not two" | ||
) | ||
return plc_expr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children)) | ||
|
||
|
||
@_to_ast.register | ||
def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: | ||
if node.name == pl_expr.BooleanFunction.IsIn: | ||
needles, haystack = node.children | ||
if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16: | ||
# 16 is an arbitrary limit | ||
needle_ref = self(needles) | ||
values = [ | ||
plc_expr.Literal(plc.interop.from_arrow(v)) for v in haystack.value | ||
] | ||
return reduce( | ||
partial(plc_expr.Operation, plc_expr.ASTOperator.LOGICAL_OR), | ||
( | ||
plc_expr.Operation(plc_expr.ASTOperator.EQUAL, needle_ref, value) | ||
for value in values | ||
), | ||
) | ||
if self.state["for_parquet"] and isinstance(node.children[0], expr.Col): | ||
raise NotImplementedError( | ||
f"Parquet filters don't support {node.name} on columns" | ||
) | ||
if node.name == pl_expr.BooleanFunction.IsNull: | ||
return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])) | ||
elif node.name == pl_expr.BooleanFunction.IsNotNull: | ||
return plc_expr.Operation( | ||
plc_expr.ASTOperator.NOT, | ||
plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])), | ||
) | ||
elif node.name == pl_expr.BooleanFunction.Not: | ||
return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0])) | ||
raise NotImplementedError(f"AST conversion does not support {node.name}") | ||
|
||
|
||
@_to_ast.register | ||
def _(node: expr.UnaryFunction, self: Transformer) -> plc_expr.Expression: | ||
if isinstance(node.children[0], expr.Col) and self.state["for_parquet"]: | ||
raise NotImplementedError( | ||
"Parquet filters don't support {node.name} on columns" | ||
) | ||
return plc_expr.Operation( | ||
UOP_TO_ASTOP[node._OP_MAPPING[node.name]], self(node.children[0]) | ||
) | ||
|
||
|
||
def to_parquet_filter(node: expr.Expr) -> plc_expr.Expression | None: | ||
""" | ||
Convert an expression to libcudf AST nodes suitable for parquet filtering. | ||
Parameters | ||
---------- | ||
node | ||
Expression to convert. | ||
Returns | ||
------- | ||
pylibcudf Expression if conversion is possible, otherwise None. | ||
""" | ||
mapper = CachingVisitor(_to_ast, state={"for_parquet": True}) | ||
try: | ||
return mapper(node) | ||
except (KeyError, NotImplementedError): | ||
return None | ||
|
||
|
||
def to_ast( | ||
node: expr.Expr, *, name_to_index: Mapping[str, int] | ||
) -> plc_expr.Expression | None: | ||
""" | ||
Convert an expression to libcudf AST nodes suitable for compute_column. | ||
Parameters | ||
---------- | ||
node | ||
Expression to convert. | ||
name_to_index | ||
Mapping from column names to their index in the table that | ||
will be used for expression evaluation. | ||
Returns | ||
------- | ||
pylibcudf Expressoin if conversion is possible, otherwise None. | ||
""" | ||
mapper = CachingVisitor( | ||
_to_ast, state={"for_parquet": False, "name_to_index": name_to_index} | ||
) | ||
try: | ||
return mapper(node) | ||
except (KeyError, NotImplementedError): | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.