From f7b8633a8536fe481a8a63af3d4a1da2d600b23a Mon Sep 17 00:00:00 2001 From: Pratik Joseph Dabre Date: Tue, 10 Dec 2024 13:29:12 -0800 Subject: [PATCH] [native] Integrate SPI with sidecar and add e2e native function validation tests --- .../nativeworker/NativeQueryRunnerUtils.java | 1 + .../PrestoNativeQueryRunnerUtils.java | 3 +- .../NativeFunctionDefinitionProvider.java | 38 +++- ...NativeFunctionNamespaceManagerFactory.java | 4 +- .../NativeSidecarPluginQueryRunnerUtils.java | 10 ++ ...SidecarPluginFunctionNamespaceManager.java | 168 ++++++++++++++++++ 6 files changed, 218 insertions(+), 6 deletions(-) create mode 100644 presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/functionNamespace/TestNativeSidecarPluginFunctionNamespaceManager.java diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java index 3afa67b275c2..abeae669f87e 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/NativeQueryRunnerUtils.java @@ -59,6 +59,7 @@ public static Map getNativeSidecarProperties() { return ImmutableMap.builder() .put("coordinator-sidecar-enabled", "true") + .put("list-built-in-functions-only", "false") .build(); } diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java index 589e8de462e1..f84af4b64fb5 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java @@ -502,7 +502,8 @@ public static Optional> getExternalWorkerLaunc if 
(isCoordinatorSidecarEnabled) { configProperties = format("%s%n" + - "native-sidecar=true%n", configProperties); + "native-sidecar=true%n" + + "presto.default-namespace=native.default%n", configProperties); } if (remoteFunctionServerUds.isPresent()) { diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionDefinitionProvider.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionDefinitionProvider.java index 2f1636a71b52..310cd9ebda2e 100644 --- a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionDefinitionProvider.java +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionDefinitionProvider.java @@ -13,17 +13,26 @@ */ package com.facebook.presto.sidecar.functionNamespace; +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpUriBuilder; +import com.facebook.airlift.http.client.Request; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata; import com.facebook.presto.functionNamespace.UdfFunctionSignatureMap; +import com.facebook.presto.sidecar.ForSidecarInfo; +import com.facebook.presto.spi.Node; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PrestoException; +import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; +import java.net.URI; import java.util.List; import java.util.Map; +import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static com.facebook.airlift.http.client.Request.Builder.prepareGet; import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; import static java.util.Objects.requireNonNull; @@ -31,23 +40,44 @@ public class NativeFunctionDefinitionProvider implements 
FunctionDefinitionProvider { private static final Logger log = Logger.get(NativeFunctionDefinitionProvider.class); - private final JsonCodec>> nativeFunctionSignatureMapJsonCodec; + private final NodeManager nodeManager; + private final HttpClient httpClient; + private static final String FUNCTION_SIGNATURES_ENDPOINT = "/v1/functions"; @Inject - public NativeFunctionDefinitionProvider(JsonCodec>> nativeFunctionSignatureMapJsonCodec) + public NativeFunctionDefinitionProvider( + @ForSidecarInfo HttpClient httpClient, + JsonCodec>> nativeFunctionSignatureMapJsonCodec, + NodeManager nodeManager) { - this.nativeFunctionSignatureMapJsonCodec = requireNonNull(nativeFunctionSignatureMapJsonCodec, "nativeFunctionSignatureMapJsonCodec is null"); + this.nativeFunctionSignatureMapJsonCodec = + requireNonNull(nativeFunctionSignatureMapJsonCodec, "nativeFunctionSignatureMapJsonCodec is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); } @Override public UdfFunctionSignatureMap getUdfDefinition(NodeManager nodeManager) { try { - throw new UnsupportedOperationException(); + Request request = prepareGet().setUri(getSidecarLocation()).build(); + Map> nativeFunctionSignatureMap = httpClient.execute(request, createJsonResponseHandler(nativeFunctionSignatureMapJsonCodec)); + return new UdfFunctionSignatureMap(ImmutableMap.copyOf(nativeFunctionSignatureMap)); } catch (Exception e) { throw new PrestoException(INVALID_ARGUMENTS, "Failed to get functions from sidecar.", e); } } + + private URI getSidecarLocation() + { + Node sidecarNode = nodeManager.getSidecarNode(); + return HttpUriBuilder.uriBuilder() + .scheme("http") + .host(sidecarNode.getHost()) + .port(sidecarNode.getHostAndPort().getPort()) + .appendPath(FUNCTION_SIGNATURES_ENDPOINT) + .build(); + } } diff --git
a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionNamespaceManagerFactory.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionNamespaceManagerFactory.java index 0a0e9bc21ac4..307a358a99f0 100644 --- a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionNamespaceManagerFactory.java +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionNamespaceManagerFactory.java @@ -15,6 +15,7 @@ import com.facebook.airlift.bootstrap.Bootstrap; import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutorsModule; +import com.facebook.presto.sidecar.NativeSidecarCommunicationModule; import com.facebook.presto.sidecar.NativeSidecarPlugin; import com.facebook.presto.spi.function.FunctionHandleResolver; import com.facebook.presto.spi.function.FunctionNamespaceManager; @@ -56,7 +57,8 @@ public FunctionNamespaceManager create(String catalogName, Map actualRows = actualResult.getMaterializedRows(); + assertEquals(actualRows.size(), FUNCTION_COUNT); + for (MaterializedRow actualRow : actualRows) { + List row = actualRow.getFields(); + // No namespace should be present on the functionNames + String functionName = row.get(0).toString(); + if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, functionName)) { + fail(format("Namespace match found for row: %s", row)); + } + + // function namespace should be present. 
+ String fullFunctionName = row.get(5).toString(); + if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) { + continue; + } + fail(format("No namespace match found for row: %s", row)); + } + } + + @Test + public void testGeneralQueries() + { + assertQuery("SELECT substr(comment, 1, 10), length(comment), trim(comment) FROM orders"); + assertQuery("SELECT substr(comment, 1, 10), length(comment), ltrim(comment) FROM orders"); + assertQuery("SELECT substr(comment, 1, 10), length(comment), rtrim(comment) FROM orders"); + assertQuery("select lower(comment) from nation"); + assertQuery("SELECT trim(comment, ' ns'), ltrim(comment, 'a b c'), rtrim(comment, 'l y') FROM orders"); + assertQuery("select array[nationkey], array_constructor(comment) from nation"); + assertQuery("SELECT nationkey, bit_count(nationkey, 10) FROM nation ORDER BY 1"); + assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK%'"); + assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK#%' escape '#'"); + assertQuery("SELECT orderkey, date_trunc('year', from_unixtime(orderkey, '-03:00')), date_trunc('quarter', from_unixtime(orderkey, '+14:00')), " + + "date_trunc('month', from_unixtime(orderkey, '+03:00')), date_trunc('day', from_unixtime(orderkey, '-07:00')), " + + "date_trunc('hour', from_unixtime(orderkey, '-09:30')), date_trunc('minute', from_unixtime(orderkey, '+05:30')), " + + "date_trunc('second', from_unixtime(orderkey, '+00:00')) FROM orders"); + } + + @Test + public void testAggregateFunctions() + { + assertQuery("select corr(nationkey, nationkey) from nation"); + assertQuery("select count(comment) from orders"); + assertQuery("select count(*) from nation"); + assertQuery("select count(abs(orderkey) between 1 and 60000) from orders group by orderkey"); + assertQuery("SELECT count(orderkey) FROM orders WHERE orderkey < 0 GROUP BY GROUPING SETS (())"); + // tinyint + assertQuery("SELECT sum(cast(linenumber as tinyint)), sum(cast(linenumber as tinyint)) 
FROM lineitem"); + // smallint + assertQuery("SELECT sum(cast(linenumber as smallint)), sum(cast(linenumber as smallint)) FROM lineitem"); + // integer + assertQuery("SELECT sum(linenumber), sum(linenumber) FROM lineitem"); + // bigint + assertQuery("SELECT sum(orderkey), sum(orderkey) FROM lineitem"); + // real + assertQuery("SELECT sum(tax_as_real), sum(tax_as_real) FROM lineitem"); + // double + assertQuery("SELECT sum(quantity), sum(quantity) FROM lineitem"); + // date + assertQuery("SELECT approx_distinct(orderdate, 0.023) FROM orders"); + // timestamp + assertQuery("SELECT approx_distinct(CAST(orderdate AS TIMESTAMP)) FROM orders"); + assertQuery("SELECT approx_distinct(CAST(orderdate AS TIMESTAMP), 0.023) FROM orders"); + assertQuery("SELECT checksum(from_unixtime(orderkey, '+01:00')) FROM lineitem WHERE orderkey < 20"); + assertQuerySucceeds("SELECT shuffle(array_sort(quantities)) FROM orders_ex"); + assertQuery("SELECT array_sort(shuffle(quantities)) FROM orders_ex"); + assertQuery("SELECT orderkey, array_sort(reduce_agg(linenumber, CAST(array[] as ARRAY(INTEGER)), (s, x) -> s || x, (s, s2) -> s || s2)) FROM lineitem group by orderkey"); + } + + @Test + public void testWindowFunctions() + { + assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey, orderstatus) rn, * from orders) WHERE rn = 1"); + assertQuery("WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) SELECT * FROM t WHERE rn = 1"); + assertQuery("SELECT row_number() OVER (PARTITION BY orderdate ORDER BY orderdate) FROM orders"); + assertQuery("SELECT min(orderkey) OVER (PARTITION BY orderdate ORDER BY orderdate, totalprice) FROM orders"); + assertQuery("SELECT sum(rn) FROM (SELECT row_number() over() rn, * from orders) WHERE rn = 10"); + assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey) rn, * from orders) WHERE rn = 1"); + } + + @Test + public void 
testArraySort() + { + assertQueryFails("SELECT array_sort(quantities, (x, y) -> if (x < y, 1, if (x > y, -1, 0))) FROM orders_ex", + "line 1:31: Expected a lambda that takes 1 argument\\(s\\) but got 2"); + } + + @Test + public void testInformationSchemaTables() + { + assertQueryFails("select lower(table_name) from information_schema.tables " + + "where table_name = 'lineitem' or table_name = 'LINEITEM' ", + "Compiler failed"); + } +}