Skip to content

Commit

Permalink
Support more cases in manifest falsifiable filter (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhou Fang authored Dec 30, 2023
1 parent b9d04dc commit 896f966
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 37 deletions.
139 changes: 105 additions & 34 deletions python/src/space/core/manifests/falsifiable_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
https://vldb.org/pvldb/vol14/p3083-edara.pdf.
"""

from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple
from functools import partial

from absl import logging # type: ignore[import-untyped]
Expand Down Expand Up @@ -92,76 +92,147 @@ def _falsifiable_filter(filter_: ExtendedExpression, primary_keys: Set[str],
filter_.base_schema,
primary_keys,
field_name_ids,
filter_.referred_expr[0].expression.scalar_function)
filter_.referred_expr[0].expression)


# pylint: disable=too-many-locals,too-many-return-statements
def _falsifiable_filter_internal(
extensions: List[SimpleExtensionDeclaration], base_schema: NamedStruct,
primary_keys: Set[str], field_name_ids: Dict[str, int],
root: Expression.ScalarFunction) -> pc.Expression:
if len(root.arguments) != 2:
raise _ExpressionException(
f"Invalid number of arguments: {root.arguments}")
# pylint: disable=too-many-locals,too-many-return-statements,too-many-branches,too-many-statements
def _falsifiable_filter_internal(extensions: List[SimpleExtensionDeclaration],
base_schema: NamedStruct,
primary_keys: Set[str],
field_name_ids: Dict[str, int],
expr: Expression) -> pc.Expression:
if not _has_scalar_function(expr):
if _has_literal(expr):
return ~_value(expr)

fn = extensions[root.function_reference].extension_function.name
lhs = root.arguments[0].value
rhs = root.arguments[1].value
if _has_selection(expr):
raise _ExpressionException(
f"Single arg expression is not supported: {expr}")

falsifiable_filter_fn = partial(_falsifiable_filter_internal, extensions,
base_schema, primary_keys, field_name_ids)
min_max_fn = partial(_min_max, base_schema, primary_keys, field_name_ids)

scalar_fn = expr.scalar_function
fn = extensions[scalar_fn.function_reference].extension_function.name

if len(scalar_fn.arguments) == 1 and fn == "not":
return ~falsifiable_filter_fn(scalar_fn.arguments[0].value)

if len(scalar_fn.arguments) != 2:
raise _ExpressionException(
f"Invalid number of arguments: {scalar_fn.arguments}")

# TODO: to support one side has scalar function, e.g., False | (a > 1).
if _has_scalar_function(lhs) and _has_scalar_function(rhs):
lhs_fn = lhs.scalar_function
rhs_fn = rhs.scalar_function
lhs = scalar_fn.arguments[0].value
rhs = scalar_fn.arguments[1].value

# Supported case: expression [and, or] expression, start recursion.
if _has_scalar_function(lhs) or _has_scalar_function(rhs):
# TODO: to support more functions.
if fn == "and":
return falsifiable_filter_fn(lhs_fn) | falsifiable_filter_fn(
rhs_fn) # type: ignore[operator]
return falsifiable_filter_fn(lhs) | falsifiable_filter_fn(
rhs) # type: ignore[operator]
elif fn == "or":
return falsifiable_filter_fn(lhs_fn) & falsifiable_filter_fn(
rhs_fn) # type: ignore[operator]
return falsifiable_filter_fn(lhs) & falsifiable_filter_fn(
rhs) # type: ignore[operator]
else:
raise _ExpressionException(f"Unsupported fn: {fn}")

# Supported case: field [op] field
if _has_selection(lhs) and _has_selection(rhs):
raise _ExpressionException(f"Both args are fields: {root.arguments}")
l_min, l_max, l_is_pk = min_max_fn(lhs)
r_min, r_max, r_is_pk = min_max_fn(rhs)

if not (l_is_pk and r_is_pk):
return pc.scalar(False)

# TODO: to support more functions.
if fn == "gt":
return l_max <= r_min
if fn == "gte":
return l_max < r_min
elif fn == "lt":
return l_min >= r_max
elif fn == "lte":
return l_min > r_max
elif fn == "equal":
return (l_max < r_min) | (r_max < l_min)
elif fn == "not_equal":
return (l_max >= r_min) & (r_max >= l_min)

raise _ExpressionException(f"Unsupported fn: {fn}")

# Supported case: value [op] value
if _has_literal(lhs) and _has_literal(rhs):
raise _ExpressionException(f"Both args are constants: {root.arguments}")
lv, rv = _value(lhs), _value(rhs)

# TODO: to support more functions.
if fn == "gt":
return lv <= rv
if fn == "gte":
return lv < rv
elif fn == "lt":
return lv >= rv
elif fn == "lte":
return lv > rv
elif fn == "equal":
return lv != rv
elif fn == "not_equal":
return lv == rv

raise _ExpressionException(f"Unsupported fn: {fn}")

# Supported case: field [op] value
if not ((_has_selection(lhs) and _has_literal(rhs)) or
(_has_literal(lhs) and _has_selection(rhs))):
raise _ExpressionException("Fail to evaluate args for falsifiable filter: "
f"{expr.scalar_function.arguments}")

# Move literal to rhs.
if _has_selection(rhs):
tmp, lhs = lhs, rhs
rhs = tmp

field_index = lhs.selection.direct_reference.struct_field.field
field_name = base_schema.names[field_index]
# Only primary key fields have column statistics for falsifiable filter
# pruning.
if field_name not in primary_keys:
field_min, field_max, is_pk = min_max_fn(lhs)
if not is_pk:
return pc.scalar(False)

field_id = field_name_ids[field_name]
field_min, field_max = _stats_field_min(field_id), _stats_field_max(field_id)
value = pc.scalar(
getattr(
rhs.literal,
rhs.literal.WhichOneof("literal_type"))) # type: ignore[arg-type]
value = _value(rhs)

# TODO: to support more functions.
if fn == "gt":
return field_max <= value
if fn == "gte":
return field_max < value
elif fn == "lt":
return field_min >= value
elif fn == "lte":
return field_min > value
elif fn == "equal":
return (field_min > value) | (field_max < value)
elif fn == "not_equal":
return (field_min == value) & (field_max == value)

raise _ExpressionException(f"Unsupported fn: {fn}")


def _value(v):
return pc.scalar(
getattr(v.literal,
v.literal.WhichOneof("literal_type"))) # type: ignore[arg-type]


def _min_max(base_schema, primary_keys, field_name_ids,
v) -> Tuple[pc.Expression, pc.Expression, bool]:
field_index = v.selection.direct_reference.struct_field.field
field_name = base_schema.names[field_index]
field_id = field_name_ids[field_name]

# Only primary key supports falsifiable filter because of column stats.
is_pk = field_name in primary_keys
return _stats_field_min(field_id), _stats_field_max(field_id), is_pk


def _stats_field_min(field_id: int) -> pc.Expression:
return pc.field(schema_utils.stats_field_name(field_id), constants.MIN_FIELD)

Expand Down
15 changes: 12 additions & 3 deletions python/tests/core/manifests/test_falsifiable_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,19 @@
(pc.field("_STATS_f0", "_MAX") <= 10) |
((pc.field("_STATS_f1", "_MIN") > 1) |
(pc.field("_STATS_f1", "_MAX") < 1))),
((pc.field("a") != 10), (pc.field("_STATS_f0", "_MIN") == 10) &
(pc.field("_STATS_f0", "_MAX") == 10)),
# Only primary keys are used.
((pc.field("a") < 10) | (pc.field("c") > "a"),
(pc.field("_STATS_f0", "_MIN") >= 10) & False)
(pc.field("_STATS_f0", "_MIN") >= 10) & False),
# Corner cases.
(pc.scalar(False), ~pc.scalar(False)),
((pc.scalar(False) | (pc.field("a") <= 10)),
(~pc.scalar(False) & (pc.field("_STATS_f0", "_MIN") > 10))),
(~(pc.field("a") >= 10), ~(pc.field("_STATS_f0", "_MAX") < 10)),
(pc.field("a") > pc.field("a"), pc.field("_STATS_f0", "_MAX")
<= pc.field("_STATS_f0", "_MIN")),
(pc.scalar(1) < pc.scalar(2), pc.scalar(1) >= pc.scalar(2))
])
def test_build_manifest_filter(filter_, expected_falsifiable_filter):
arrow_schema = pa.schema([("a", pa.int64()), ("b", pa.float64()),
Expand All @@ -43,8 +53,7 @@ def test_build_manifest_filter(filter_, expected_falsifiable_filter):
assert str(manifest_filter) == str(~expected_falsifiable_filter)


@pytest.mark.parametrize("filter_", [(pc.field("a") != 10),
(~(pc.field("a") > 10))])
@pytest.mark.parametrize("filter_", [pc.field("a")])
def test_build_manifest_filter_not_supported_return_none(filter_):
arrow_schema = pa.schema([("a", pa.int64()), ("b", pa.float64())])
field_name_ids = {"a": 0, "b": 1}
Expand Down

0 comments on commit 896f966

Please sign in to comment.