Skip to content

Commit

Permalink
[native] Integrate SPI with sidecar and add e2e native function valid…
Browse files Browse the repository at this point in the history
…ation tests
  • Loading branch information
pdabre12 authored and Pratik Joseph Dabre committed Dec 10, 2024
1 parent 81eae50 commit f7b8633
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public static Map<String, String> getNativeSidecarProperties()
{
return ImmutableMap.<String, String>builder()
.put("coordinator-sidecar-enabled", "true")
.put("list-built-in-functions-only", "false")
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ public static Optional<BiFunction<Integer, URI, Process>> getExternalWorkerLaunc

if (isCoordinatorSidecarEnabled) {
configProperties = format("%s%n" +
"native-sidecar=true%n", configProperties);
"native-sidecar=true%n" +
"presto.default-namespace=native.default%n", configProperties);
}

if (remoteFunctionServerUds.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,71 @@
*/
package com.facebook.presto.sidecar.functionNamespace;

import com.facebook.airlift.http.client.HttpClient;
import com.facebook.airlift.http.client.HttpUriBuilder;
import com.facebook.airlift.http.client.Request;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata;
import com.facebook.presto.functionNamespace.UdfFunctionSignatureMap;
import com.facebook.presto.sidecar.ForSidecarInfo;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.PrestoException;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;

import java.net.URI;
import java.util.List;
import java.util.Map;

import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
import static com.facebook.airlift.http.client.Request.Builder.prepareGet;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static java.util.Objects.requireNonNull;

public class NativeFunctionDefinitionProvider
implements FunctionDefinitionProvider
{
private static final Logger log = Logger.get(NativeFunctionDefinitionProvider.class);

private final JsonCodec<Map<String, List<JsonBasedUdfFunctionMetadata>>> nativeFunctionSignatureMapJsonCodec;
private final NodeManager nodeManager;
private final HttpClient httpClient;
private static final String FUNCTION_SIGNATURES_ENDPOINT = "/v1/functions";

@Inject
public NativeFunctionDefinitionProvider(JsonCodec<Map<String, List<JsonBasedUdfFunctionMetadata>>> nativeFunctionSignatureMapJsonCodec)
public NativeFunctionDefinitionProvider(
@ForSidecarInfo HttpClient httpClient,
JsonCodec<Map<String, List<JsonBasedUdfFunctionMetadata>>> nativeFunctionSignatureMapJsonCodec,
NodeManager nodeManager)
{
this.nativeFunctionSignatureMapJsonCodec = requireNonNull(nativeFunctionSignatureMapJsonCodec, "nativeFunctionSignatureMapJsonCodec is null");
this.nativeFunctionSignatureMapJsonCodec =
requireNonNull(nativeFunctionSignatureMapJsonCodec, "nativeFunctionSignatureMapJsonCodec is null");
this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
this.httpClient = requireNonNull(httpClient, "typeManager is null");
}

@Override
public UdfFunctionSignatureMap getUdfDefinition(NodeManager nodeManager)
{
try {
throw new UnsupportedOperationException();
Request request = prepareGet().setUri(getSidecarLocation()).build();
Map<String, List<JsonBasedUdfFunctionMetadata>> nativeFunctionSignatureMap = httpClient.execute(request, createJsonResponseHandler(nativeFunctionSignatureMapJsonCodec));
return new UdfFunctionSignatureMap(ImmutableMap.copyOf(nativeFunctionSignatureMap));
}
catch (Exception e) {
throw new PrestoException(INVALID_ARGUMENTS, "Failed to get functions from sidecar.", e);
}
}

private URI getSidecarLocation()
{
Node sidecarNode = nodeManager.getSidecarNode();
return HttpUriBuilder.uriBuilder()
.scheme("http")
.host(sidecarNode.getHost())
.port(sidecarNode.getHostAndPort().getPort())
.appendPath(FUNCTION_SIGNATURES_ENDPOINT)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.airlift.bootstrap.Bootstrap;
import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutorsModule;
import com.facebook.presto.sidecar.NativeSidecarCommunicationModule;
import com.facebook.presto.sidecar.NativeSidecarPlugin;
import com.facebook.presto.spi.function.FunctionHandleResolver;
import com.facebook.presto.spi.function.FunctionNamespaceManager;
Expand Down Expand Up @@ -56,7 +57,8 @@ public FunctionNamespaceManager<?> create(String catalogName, Map<String, String
try {
Bootstrap app = new Bootstrap(
new NativeFunctionNamespaceManagerModule(catalogName, context.getNodeManager(), context.getFunctionMetadataManager()),
new NoopSqlFunctionExecutorsModule());
new NoopSqlFunctionExecutorsModule(),
new NativeSidecarCommunicationModule());

Injector injector = app
.doNotInitializeLogging()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
* limitations under the License.
*/
package com.facebook.presto.sidecar;

import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory;
import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory;
import com.facebook.presto.testing.QueryRunner;
import com.google.common.collect.ImmutableMap;

public class NativeSidecarPluginQueryRunnerUtils
{
Expand All @@ -23,5 +26,12 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner)
{
queryRunner.installCoordinatorPlugin(new NativeSidecarPlugin());
queryRunner.loadSessionPropertyProvider(NativeSystemSessionPropertyProviderFactory.NAME);
queryRunner.loadFunctionNamespaceManager(
NativeFunctionNamespaceManagerFactory.NAME,
"native",
ImmutableMap.of(
"default.namespace", "true",
"supported-function-languages", "CPP",
"function-implementation-type", "CPP"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sidecar.functionNamespace;

import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.MaterializedRow;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.DistributedQueryRunner;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.Test;

import java.util.List;
import java.util.regex.Pattern;

import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrdersEx;
import static com.facebook.presto.sidecar.NativeSidecarPluginQueryRunnerUtils.setupNativeSidecarPlugin;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;

@Test(singleThreaded = true)
public class TestNativeSidecarPluginFunctionNamespaceManager
extends AbstractTestQueryFramework
{
private static final String REGEX_FUNCTION_NAMESPACE = "native.default.*";
private static final int FUNCTION_COUNT = 1199;

@Override
protected void createTables()
{
QueryRunner queryRunner = (QueryRunner) getExpectedQueryRunner();
createLineitem(queryRunner);
createNation(queryRunner);
createOrders(queryRunner);
createOrdersEx(queryRunner);
}

@Override
protected QueryRunner createQueryRunner()
throws Exception
{
DistributedQueryRunner queryRunner = (DistributedQueryRunner) PrestoNativeQueryRunnerUtils.createQueryRunner(true, true);
setupNativeSidecarPlugin(queryRunner);
return queryRunner;
}

@Override
protected QueryRunner createExpectedQueryRunner()
throws Exception
{
return PrestoNativeQueryRunnerUtils.createJavaQueryRunner();
}

@Test
public void testShowFunctions()
{
@Language("SQL") String sql = "SHOW FUNCTIONS";
MaterializedResult actualResult = computeActual(sql);
List<MaterializedRow> actualRows = actualResult.getMaterializedRows();
assertEquals(actualRows.size(), FUNCTION_COUNT);
for (MaterializedRow actualRow : actualRows) {
List<Object> row = actualRow.getFields();
// No namespace should be present on the functionNames
String functionName = row.get(0).toString();
if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, functionName)) {
fail(format("Namespace match found for row: %s", row));
}

// function namespace should be present.
String fullFunctionName = row.get(5).toString();
if (Pattern.matches(REGEX_FUNCTION_NAMESPACE, fullFunctionName)) {
continue;
}
fail(format("No namespace match found for row: %s", row));
}
}

@Test
public void testGeneralQueries()
{
assertQuery("SELECT substr(comment, 1, 10), length(comment), trim(comment) FROM orders");
assertQuery("SELECT substr(comment, 1, 10), length(comment), ltrim(comment) FROM orders");
assertQuery("SELECT substr(comment, 1, 10), length(comment), rtrim(comment) FROM orders");
assertQuery("select lower(comment) from nation");
assertQuery("SELECT trim(comment, ' ns'), ltrim(comment, 'a b c'), rtrim(comment, 'l y') FROM orders");
assertQuery("select array[nationkey], array_constructor(comment) from nation");
assertQuery("SELECT nationkey, bit_count(nationkey, 10) FROM nation ORDER BY 1");
assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK%'");
assertQuery("SELECT * FROM lineitem WHERE shipinstruct like 'TAKE BACK#%' escape '#'");
assertQuery("SELECT orderkey, date_trunc('year', from_unixtime(orderkey, '-03:00')), date_trunc('quarter', from_unixtime(orderkey, '+14:00')), " +
"date_trunc('month', from_unixtime(orderkey, '+03:00')), date_trunc('day', from_unixtime(orderkey, '-07:00')), " +
"date_trunc('hour', from_unixtime(orderkey, '-09:30')), date_trunc('minute', from_unixtime(orderkey, '+05:30')), " +
"date_trunc('second', from_unixtime(orderkey, '+00:00')) FROM orders");
}

@Test
public void testAggregateFunctions()
{
assertQuery("select corr(nationkey, nationkey) from nation");
assertQuery("select count(comment) from orders");
assertQuery("select count(*) from nation");
assertQuery("select count(abs(orderkey) between 1 and 60000) from orders group by orderkey");
assertQuery("SELECT count(orderkey) FROM orders WHERE orderkey < 0 GROUP BY GROUPING SETS (())");
// tinyint
assertQuery("SELECT sum(cast(linenumber as tinyint)), sum(cast(linenumber as tinyint)) FROM lineitem");
// smallint
assertQuery("SELECT sum(cast(linenumber as smallint)), sum(cast(linenumber as smallint)) FROM lineitem");
// integer
assertQuery("SELECT sum(linenumber), sum(linenumber) FROM lineitem");
// bigint
assertQuery("SELECT sum(orderkey), sum(orderkey) FROM lineitem");
// real
assertQuery("SELECT sum(tax_as_real), sum(tax_as_real) FROM lineitem");
// double
assertQuery("SELECT sum(quantity), sum(quantity) FROM lineitem");
// date
assertQuery("SELECT approx_distinct(orderdate, 0.023) FROM orders");
// timestamp
assertQuery("SELECT approx_distinct(CAST(orderdate AS TIMESTAMP)) FROM orders");
assertQuery("SELECT approx_distinct(CAST(orderdate AS TIMESTAMP), 0.023) FROM orders");
assertQuery("SELECT checksum(from_unixtime(orderkey, '+01:00')) FROM lineitem WHERE orderkey < 20");
assertQuerySucceeds("SELECT shuffle(array_sort(quantities)) FROM orders_ex");
assertQuery("SELECT array_sort(shuffle(quantities)) FROM orders_ex");
assertQuery("SELECT orderkey, array_sort(reduce_agg(linenumber, CAST(array[] as ARRAY(INTEGER)), (s, x) -> s || x, (s, s2) -> s || s2)) FROM lineitem group by orderkey");
}

@Test
public void testWindowFunctions()
{
assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey, orderstatus) rn, * from orders) WHERE rn = 1");
assertQuery("WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) SELECT * FROM t WHERE rn = 1");
assertQuery("SELECT row_number() OVER (PARTITION BY orderdate ORDER BY orderdate) FROM orders");
assertQuery("SELECT min(orderkey) OVER (PARTITION BY orderdate ORDER BY orderdate, totalprice) FROM orders");
assertQuery("SELECT sum(rn) FROM (SELECT row_number() over() rn, * from orders) WHERE rn = 10");
assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey) rn, * from orders) WHERE rn = 1");
}

@Test
public void testArraySort()
{
assertQueryFails("SELECT array_sort(quantities, (x, y) -> if (x < y, 1, if (x > y, -1, 0))) FROM orders_ex",
"line 1:31: Expected a lambda that takes 1 argument\\(s\\) but got 2");
}

@Test
public void testInformationSchemaTables()
{
assertQueryFails("select lower(table_name) from information_schema.tables "
+ "where table_name = 'lineitem' or table_name = 'LINEITEM' ",
"Compiler failed");
}
}

0 comments on commit f7b8633

Please sign in to comment.