Skip to content

Commit

Permalink
Merge pull request #341 from EvgSkv/ti2023
Browse files Browse the repository at this point in the history
Making functions compile in psql.
  • Loading branch information
EvgSkv authored Jun 21, 2024
2 parents bd52519 + e0e415f commit 248c9a1
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 1 deletion.
17 changes: 16 additions & 1 deletion compiler/universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,22 @@ def FunctionSql(self, name, allocator=None, internal_mode=False):
dialect=self.execution.dialect)
value_sql = ql.ConvertToSql(s.select['logica_value'])

sql = 'CREATE TEMP FUNCTION {name}({signature}) AS ({value})'.format(
# TODO: Move this to dialects.py.
if self.execution.annotations.Engine() == 'psql':
vartype = lambda varname: (
self.typing_engine.collector.psql_type_cache[
s.select[varname]['type']['rendered_type']])
sql = ('DROP FUNCTION IF EXISTS {name}; '
'CREATE OR REPLACE FUNCTION {name}({signature}) '
'RETURNS {value_type} AS $$ select ({value}) '
'$$ language sql'.format(
name=name,
signature=', '.join('%s %s' % (v, vartype(v))
for v in variables),
value_type=vartype('logica_value'),
value=value_sql))
else:
sql = 'CREATE TEMP FUNCTION {name}({signature}) AS ({value})'.format(
name=name,
signature=', '.join('%s ANY TYPE' % v for v in variables),
value=value_sql)
Expand Down
22 changes: 22 additions & 0 deletions integration_tests/psql_udf_test.l
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@Engine("psql");

F(x) --> 2 * x;
G(x, y) --> List{{a: x * y, b: i} :- i in Range(y + 1)};

@OrderBy(Test, "col0", "col1");
Test(i, j, F(i), G(i, j)) :- i in Range(3), j in Range(4);
16 changes: 16 additions & 0 deletions integration_tests/psql_udf_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
+------+------+------+-----------------------------------+
| col0 | col1 | col2 | col3 |
+------+------+------+-----------------------------------+
| 0 | 0 | 0 | {"(0,0)"} |
| 0 | 1 | 0 | {"(0,0)","(0,1)"} |
| 0 | 2 | 0 | {"(0,0)","(0,1)","(0,2)"} |
| 0 | 3 | 0 | {"(0,0)","(0,1)","(0,2)","(0,3)"} |
| 1 | 0 | 2 | {"(0,0)"} |
| 1 | 1 | 2 | {"(1,0)","(1,1)"} |
| 1 | 2 | 2 | {"(2,0)","(2,1)","(2,2)"} |
| 1 | 3 | 2 | {"(3,0)","(3,1)","(3,2)","(3,3)"} |
| 2 | 0 | 4 | {"(0,0)"} |
| 2 | 1 | 4 | {"(2,0)","(2,1)"} |
| 2 | 2 | 4 | {"(4,0)","(4,1)","(4,2)"} |
| 2 | 3 | 4 | {"(6,0)","(6,1)","(6,2)","(6,3)"} |
+------+------+------+-----------------------------------+
1 change: 1 addition & 0 deletions integration_tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def RunAll(test_presto=False, test_trino=False):
RunTest("sqlite_reachability")
RunTest("sqlite_element_test")

RunTest("psql_udf_test")
RunTest("psql_flow_test")
RunTest("psql_graph_coloring_test")
RunTest("psql_win_move_test")
Expand Down
3 changes: 3 additions & 0 deletions type_inference/research/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def __init__(self, parsed_rules):
self.psql_type_definition = {}
self.definitions = []
self.typing_preamble = ''
self.psql_type_cache = {}

def ActPopulatingTypeMap(self, node):
if 'type' in node:
Expand All @@ -671,6 +672,8 @@ def ActPopulatingTypeMap(self, node):
node['type']['rendered_type'] = t_rendering
if 'combine' in node and reference_algebra.IsFullyDefined(t):
node['type']['combine_psql_type'] = self.PsqlType(t)
if reference_algebra.IsFullyDefined(t):
self.psql_type_cache[t_rendering] = self.PsqlType(t)
if isinstance(t, dict) and reference_algebra.IsFullyDefined(t):
node['type']['type_name'] = RecordTypeName(t_rendering)
if isinstance(t, list) and reference_algebra.IsFullyDefined(t):
Expand Down

0 comments on commit 248c9a1

Please sign in to comment.