Skip to content

Commit

Permalink
Register builtin functions and handle return types
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed Apr 18, 2024
1 parent 01d22b9 commit 3097962
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 19 deletions.
62 changes: 53 additions & 9 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,29 @@

from substrait import proto


SQL_UNARY_FUNCTIONS = {"not": "not"}
SQL_BINARY_FUNCTIONS = {
# Arithmetic
"add": "add",
"div": "div",
"mul": "mul",
"sub": "sub",
"mod": "modulus",
"bitwiseand": "bitwise_and",
"bitwiseor": "bitwise_or",
"bitwisexor": "bitwise_xor",
"bitwiseor": "bitwise_or",
# Comparisons
"eq": "equal",
"nullsafeeq": "is_not_distinct_from",
"new": "not_equal",
"gt": "gt",
"gte": "gte",
"lt": "lt",
"lte": "lte",
# logical
"and": "and",
"or": "or",
}


Expand Down Expand Up @@ -124,6 +138,17 @@ def _parse_expression(self, expr, invoked_functions):
expr.this, invoked_functions
)
return expr.output_name, aliased_type, aliased_expr
elif expr.key in SQL_UNARY_FUNCTIONS:
argument_name, argument_type, argument = self._parse_expression(
expr.this, invoked_functions
)
function_name = SQL_UNARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(function_name, argument_type, argument)
)
invoked_functions.add(signature)
result_name = f"{function_name}_{argument_name}_{next(self._counter)}"
return result_name, result_type, function_expression
elif expr.key in SQL_BINARY_FUNCTIONS:
left_name, left_type, left = self._parse_expression(
expr.left, invoked_functions
Expand All @@ -148,26 +173,45 @@ def _parse_expression(self, expr, invoked_functions):
)

def _parse_function_invokation(
self, function_name, left_type, left, right_type, right
self, function_name, left_type, left, right_type=None, right=None
):
signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}"
binary = False
argtypes = [left_type]
if right_type or right:
binary = True
argtypes.append(right_type)
signature = self._functions_catalog.signature(function_name, argtypes)

try:
function_anchor = self._functions_catalog.function_anchor(signature)
except KeyError:
# No function found with the exact types, try any1_any1 version
# TODO: What about cases like i32_any1? What about any instead of any1?
signature = f"{function_name}:any1_any1"
if binary:
signature = f"{function_name}:any1_any1"
else:
signature = f"{function_name}:any1"
function_anchor = self._functions_catalog.function_anchor(signature)

function_return_type = self._functions_catalog.function_return_type(signature)
if function_return_type is None:
print("No return type for", signature)
# TODO: Is this the right way to handle this?
function_return_type = left_type
return (
signature,
left_type, # TODO: Get the actually returned type from the functions catalog.
function_return_type,
proto.Expression(
scalar_function=proto.Expression.ScalarFunction(
function_reference=function_anchor,
arguments=[
proto.FunctionArgument(value=left),
proto.FunctionArgument(value=right),
],
arguments=(
[
proto.FunctionArgument(value=left),
proto.FunctionArgument(value=right),
]
if binary
else [proto.FunctionArgument(value=left)]
),
)
),
)
83 changes: 73 additions & 10 deletions src/substrait/sql/functions_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class FunctionsCatalog:
"/functions_arithmetic_decimal.yaml",
"/functions_boolean.yaml",
"/functions_comparison.yaml",
"/functions_datetime.yaml",
# "/functions_datetime.yaml", for now skip, it has duplicated functions
"/functions_geometry.yaml",
"/functions_logarithmic.yaml",
"/functions_rounding.yaml",
Expand All @@ -32,9 +32,10 @@ class FunctionsCatalog:
)

def __init__(self):
self._declarations = {}
self._registered_extensions = {}
self._functions = {}
self._functions_return_type = {}
self._register_builtins()

def load_standard_extensions(self, dirpath):
for ext in self.STANDARD_EXTENSIONS:
Expand All @@ -45,6 +46,7 @@ def load(self, dirpath, filename):
sections = yaml.safe_load(f)

loaded_functions = set()
functions_return_type = {}
for functions in sections.values():
for function in functions:
function_name = function["name"]
Expand All @@ -55,12 +57,16 @@ def load(self, dirpath, filename):
signature = function_name
else:
signature = f"{function_name}:{'_'.join(argtypes)}"
self._declarations[signature] = filename
loaded_functions.add(signature)
functions_return_type[signature] = self._type_from_name(
impl["return"]
)

self._register_extensions(filename, loaded_functions)
self._register_extensions(filename, loaded_functions, functions_return_type)

def _register_extensions(self, extension_uri, loaded_functions):
def _register_extensions(
self, extension_uri, loaded_functions, functions_return_type
):
if extension_uri not in self._registered_extensions:
ext_anchor_id = len(self._registered_extensions) + 1
self._registered_extensions[extension_uri] = proto.SimpleExtensionURI(
Expand All @@ -70,14 +76,12 @@ def _register_extensions(self, extension_uri, loaded_functions):
for function in loaded_functions:
if function in self._functions:
extensions_by_anchor = self.extension_uris_by_anchor
function = self._functions[function]
existing_function = self._functions[function]
function_extension = extensions_by_anchor[
function.extension_uri_reference
existing_function.extension_uri_reference
].uri
# TODO: Support overloading of functions from different extensionUris.
continue
raise ValueError(
f"Duplicate function definition: {function.name} from {extension_uri}, already loaded from {function_extension}"
f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}"
)
extension_anchor = self._registered_extensions[
extension_uri
Expand All @@ -90,6 +94,48 @@ def _register_extensions(self, extension_uri, loaded_functions):
function_anchor=function_anchor,
)
)
self._functions_return_type[function] = functions_return_type[function]

def _register_builtins(self):
self._functions["not:boolean"] = (
proto.SimpleExtensionDeclaration.ExtensionFunction(
name="not",
function_anchor=len(self._functions) + 1,
)
)
self._functions_return_type["not:boolean"] = proto.Type(
bool=proto.Type.Boolean()
)

def _type_from_name(self, typename):
nullable = False
if typename.endswith("?"):
nullable = True

typename = typename.strip("?")
if typename in ("any", "any1"):
return None

if typename == "boolean":
# For some reason boolean is an exception to the naming convention
typename = "bool"

try:
type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[
typename
].message_type
except KeyError:
# TODO: improve resolution of complext type like LIST?<any>
print("Unsupported type", typename)
return None

type_class = getattr(proto.Type, type_descriptor.name)
nullability = (
proto.Type.Nullability.NULLABILITY_REQUIRED
if not nullable
else proto.Type.Nullability.NULLABILITY_NULLABLE
)
return proto.Type(**{typename: type_class(nullability=nullability)})

@property
def extension_uris_by_anchor(self):
Expand All @@ -106,14 +152,31 @@ def extension_uris(self):
def extensions(self):
return list(self._functions.values())

def signature(self, function_name, proto_argtypes):
def _normalize_arg_types(argtypes):
for argtype in argtypes:
kind = argtype.WhichOneof("kind")
if kind == "bool":
yield "boolean"
else:
yield kind

return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}"

def function_anchor(self, function):
return self._functions[function].function_anchor

def function_return_type(self, function):
return self._functions_return_type[function]

def extensions_for_functions(self, functions):
uris_anchors = set()
extensions = []
for f in functions:
ext = self._functions[f]
if not ext.extension_uri_reference:
# Built-in function
continue
uris_anchors.add(ext.extension_uri_reference)
extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext))

Expand Down

0 comments on commit 3097962

Please sign in to comment.