From 9c6435ae99b54bc4fb536cb7220b31cefebea8ee Mon Sep 17 00:00:00 2001 From: Pratik Joseph Dabre Date: Wed, 8 May 2024 13:18:23 -0700 Subject: [PATCH] Added a new nativeFunctionHandle to fix the failing variadic functions --- ...actSqlInvokedFunctionNamespaceManager.java | 4 +- .../JsonBasedUdfFunctionMetadata.java | 25 +- .../prestissimo/NativeFunctionHandle.java | 96 +++++++ .../NativeFunctionNamespaceManager.java | 93 ++++++- ...NativeFunctionNamespaceManagerFactory.java | 2 +- .../NativeFunctionNamespaceManagerModule.java | 6 +- ...uiltInTypeAndFunctionNamespaceManager.java | 7 +- .../metadata/FunctionAndTypeManager.java | 26 +- .../metadata/SpecializedFunctionKey.java | 2 + ...ativeFunctionNamespaceManagerProvider.java | 18 +- .../presto/server/ServerMainModule.java | 5 +- .../server/testing/TestingPrestoServer.java | 1 - .../main/types/FunctionMetadata.cpp | 253 ++++++++++++++++++ .../main/types/tests/FunctionMetadataTest.cpp | 235 ++++++++++++++++ .../spi/function/FunctionMetadataManager.java | 2 + .../FunctionNamespaceManagerContext.java | 10 +- .../spi/function/SqlFunctionSupplier.java | 19 ++ .../spi/function/SqlInvokedFunction.java | 1 + 18 files changed, 771 insertions(+), 34 deletions(-) create mode 100644 presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionHandle.java create mode 100644 presto-native-execution/presto_cpp/main/types/FunctionMetadata.cpp create mode 100644 presto-native-execution/presto_cpp/main/types/tests/FunctionMetadataTest.cpp create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionSupplier.java diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java index f8701da6284b9..a33fc4c5ec24a 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java @@ -192,7 +192,7 @@ public Optional getUserDefinedType(QualifiedObjectName typeName } @Override - public final FunctionHandle getFunctionHandle(Optional transactionHandle, Signature signature) + public FunctionHandle getFunctionHandle(Optional transactionHandle, Signature signature) { checkCatalog(signature.getName()); // This is the only assumption in this class that we're dealing with sql-invoked regular function. @@ -245,7 +245,7 @@ public CompletableFuture executeFunction(String source, Funct typeManager.getType(functionMetadata.getReturnType())); } - private static PrestoException convertToPrestoException(UncheckedExecutionException exception, String failureMessage) + protected static PrestoException convertToPrestoException(UncheckedExecutionException exception, String failureMessage) { Throwable cause = exception.getCause(); if (cause instanceof PrestoException) { diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java index 1e7b54134834d..814370731b3d1 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.function.AggregationFunctionMetadata; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.TypeVariableConstraint; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -61,6 +62,14 @@ public class JsonBasedUdfFunctionMetadata * Optional Aggregate-specific metadata (required for aggregation functions) */ private final Optional aggregateMetadata; + /** + * Optional field: Marked to indicate the arity of the function. + */ + private final Optional variableArity; + /** + * Optional field: List of the typeVariableConstraints. + */ + private final Optional> typeVariableConstraints; @JsonCreator public JsonBasedUdfFunctionMetadata( @@ -70,7 +79,9 @@ public JsonBasedUdfFunctionMetadata( @JsonProperty("paramTypes") List paramTypes, @JsonProperty("schema") String schema, @JsonProperty("routineCharacteristics") RoutineCharacteristics routineCharacteristics, - @JsonProperty("aggregateMetadata") Optional aggregateMetadata) + @JsonProperty("aggregateMetadata") Optional aggregateMetadata, + @JsonProperty("variableArity") Optional variableArity, + @JsonProperty("typeVariableConstraints") Optional> typeVariableConstraints) { this.docString = requireNonNull(docString, "docString is null"); this.functionKind = requireNonNull(functionKind, "functionKind is null"); @@ -82,6 +93,8 @@ public JsonBasedUdfFunctionMetadata( checkArgument( (functionKind == AGGREGATE && aggregateMetadata.isPresent()) || (functionKind != AGGREGATE && !aggregateMetadata.isPresent()), "aggregateMetadata must be present for aggregation functions and absent otherwise"); + this.variableArity = requireNonNull(variableArity, "variableArity is null"); + this.typeVariableConstraints = requireNonNull(typeVariableConstraints, "typeVariableConstraints is null"); } public String getDocString() @@ -123,4 +136,14 @@ public Optional getAggregateMetadata() { return aggregateMetadata; } + + public Optional getVariableArity() + { + return variableArity; + } + + public Optional> getTypeVariableConstraints() + { + return typeVariableConstraints; + } } diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionHandle.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionHandle.java new file mode 100644 index 0000000000000..f0787ef699182 --- /dev/null +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionHandle.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.functionNamespace.prestissimo; + +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlFunctionHandle; +import com.facebook.presto.spi.function.SqlFunctionId; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class NativeFunctionHandle + extends SqlFunctionHandle +{ + private final Signature signature; + + @JsonCreator + public NativeFunctionHandle(@JsonProperty("signature") Signature signature) + { + super(new SqlFunctionId(signature.getName(), signature.getArgumentTypes()), "1"); + this.signature = requireNonNull(signature, "signature is null"); + checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + } + + @JsonProperty + public Signature getSignature() + { + return signature; + } + + @Override + public String getName() + { + return signature.getName().toString(); + } + + @Override + public FunctionKind getKind() + { + return signature.getKind(); + } + + @Override + public CatalogSchemaName getCatalogSchemaName() + { + return signature.getName().getCatalogSchemaName(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NativeFunctionHandle that = (NativeFunctionHandle) o; + return Objects.equals(signature, that.signature); + } + + @Override + public int hashCode() + { + return Objects.hash(signature); + } + + @Override + public String toString() + { + return signature.toString(); + } + + private static void checkArgument(boolean condition, String message, Object... args) + { + if (!condition) { + throw new IllegalArgumentException(String.format(message, args)); + } + } +} diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java index 8ec120ce4467d..4d9bde7c5ea1d 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java @@ -31,17 +31,28 @@ import com.facebook.presto.spi.function.AlterRoutineCharacteristics; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.FunctionNamespaceTransactionHandle; import com.facebook.presto.spi.function.Parameter; import com.facebook.presto.spi.function.ScalarFunctionImplementation; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; +import com.facebook.presto.spi.function.SqlFunctionSupplier; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.TypeVariableConstraint; import com.google.common.base.Suppliers; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.UncheckedExecutionException; import javax.inject.Inject; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -60,6 +71,7 @@ import static java.lang.Long.parseLong; import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.HOURS; public class NativeFunctionNamespaceManager extends AbstractSqlInvokedFunctionNamespaceManager @@ -71,6 +83,8 @@ public class NativeFunctionNamespaceManager private final NodeManager nodeManager; private final Map latestFunctions = new ConcurrentHashMap<>(); private final Supplier> memoizedFunctionsSupplier; + private final FunctionMetadataManager functionMetadataManager; + private final LoadingCache specializedFunctionKeyCache; @Inject public NativeFunctionNamespaceManager( @@ -78,13 +92,24 @@ public NativeFunctionNamespaceManager( SqlFunctionExecutors sqlFunctionExecutors, SqlInvokedFunctionNamespaceManagerConfig config, FunctionDefinitionProvider functionDefinitionProvider, - NodeManager nodeManager) + NodeManager nodeManager, + FunctionMetadataManager functionMetadataManager) { super(catalogName, sqlFunctionExecutors, config); this.functionDefinitionProvider = requireNonNull(functionDefinitionProvider, "functionDefinitionProvider is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.memoizedFunctionsSupplier = Suppliers.memoizeWithExpiration(this::bootstrapNamespace, config.getFunctionCacheExpiration().toMillis(), TimeUnit.MILLISECONDS); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.specializedFunctionKeyCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS) + .build(CacheLoader.from(this::doGetSpecializedFunctionKey)); + } + + private SqlFunctionSupplier doGetSpecializedFunctionKey(Signature signature) + { + return functionMetadataManager.getSpecializedFunctionKey(signature); } private Map bootstrapNamespace() @@ -137,6 +162,7 @@ private static SqlInvokedFunction copyFunction(SqlInvokedFunction function) return new SqlInvokedFunction( function.getSignature().getName(), function.getParameters(), + function.getSignature().getTypeVariableConstraints(), function.getSignature().getReturnType(), function.getDescription(), function.getRoutineCharacteristics(), @@ -152,15 +178,28 @@ protected SqlInvokedFunction createSqlInvokedFunction(String functionName, JsonB QualifiedObjectName qualifiedFunctionName = QualifiedObjectName.valueOf(new CatalogSchemaName(getCatalogName(), jsonBasedUdfFunctionMetaData.getSchema()), functionName); List parameterNameList = jsonBasedUdfFunctionMetaData.getParamNames(); List parameterTypeList = jsonBasedUdfFunctionMetaData.getParamTypes(); - + List typeVariableConstraintsList = jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().isPresent() ? + jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().get() : Collections.emptyList(); ImmutableList.Builder parameterBuilder = ImmutableList.builder(); for (int i = 0; i < parameterNameList.size(); i++) { parameterBuilder.add(new Parameter(parameterNameList.get(i), parameterTypeList.get(i))); } + // Todo: Clean this method up + ImmutableList.Builder typeVariableConstraintsBuilder = ImmutableList.builder(); + for (TypeVariableConstraint typeVariableConstraint : typeVariableConstraintsList) { + typeVariableConstraintsBuilder.add(new TypeVariableConstraint( + typeVariableConstraint.getName(), + typeVariableConstraint.isComparableRequired(), + typeVariableConstraint.isOrderableRequired(), + null, + typeVariableConstraint.isNonDecimalNumericRequired())); + } + return new SqlInvokedFunction( qualifiedFunctionName, parameterBuilder.build(), + typeVariableConstraintsBuilder.build(), jsonBasedUdfFunctionMetaData.getOutputType(), jsonBasedUdfFunctionMetaData.getDocString(), jsonBasedUdfFunctionMetaData.getRoutineCharacteristics(), @@ -188,10 +227,13 @@ protected UserDefinedType fetchUserDefinedTypeDirect(QualifiedObjectName typeNam @Override protected FunctionMetadata fetchFunctionMetadataDirect(SqlFunctionHandle functionHandle) { + if (functionHandle instanceof NativeFunctionHandle) { + return getMetadataFromNativeFunctionHandle(functionHandle); + } + return fetchFunctionsDirect(functionHandle.getFunctionId().getFunctionName()).stream() .filter(function -> function.getRequiredFunctionHandle().equals(functionHandle)) - .map(this::sqlInvokedFunctionToMetadata) - .collect(onlyElement()); + .map(this::sqlInvokedFunctionToMetadata).collect(onlyElement()); } @Override @@ -248,4 +290,47 @@ public void addUserDefinedType(UserDefinedType userDefinedType) name); userDefinedTypes.put(name, userDefinedType); } + + @Override + public final FunctionHandle getFunctionHandle(Optional transactionHandle, Signature signature) + { + FunctionHandle functionHandle = super.getFunctionHandle(transactionHandle, signature); + + // only handle variadic signatures here , for normal signature we use the AbstractSqlInvokedFunctionNamespaceManager function handle. + if (functionHandle == null) { + return new NativeFunctionHandle(signature); + } + return functionHandle; + } + + private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle) + { + NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle; + Signature signature = nativeFunctionHandle.getSignature(); + SqlFunctionSupplier functionKey; + try { + functionKey = specializedFunctionKeyCache.getUnchecked(signature); + } + catch (UncheckedExecutionException e) { + throw convertToPrestoException(e, format("Error getting FunctionMetadata for handle: %s", functionHandle)); + } + SqlFunction function = functionKey.getFunction(); + + // todo: verify this metadata return + SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function; + return new FunctionMetadata( + signature.getName(), + signature.getArgumentTypes(), + sqlFunction.getParameters().stream() + .map(Parameter::getName) + .collect(toImmutableList()), + signature.getReturnType(), + function.getSignature().getKind(), + sqlFunction.getRoutineCharacteristics().getLanguage(), + getFunctionImplementationType(sqlFunction), + function.isDeterministic(), + function.isCalledOnNullInput(), + sqlFunction.getVersion(), + function.getComplexTypeFunctionDescriptor()); + } } diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManagerFactory.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManagerFactory.java index be13e48f43a1f..1f8c5fd52416e 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManagerFactory.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManagerFactory.java @@ -55,7 +55,7 @@ public FunctionNamespaceManager create(String catalogName, Map candidates = getFunctions(null, signature.getName()); + Collection candidates = getFunctions(null, signature.getName()); + return doGetSpecializedFunctionKey(signature, candidates); + } + + public SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature, Collection candidates) + { // search for exact match Type returnType = functionAndTypeManager.getType(signature.getReturnType()); List argumentTypeSignatureProviders = fromTypeSignatures(signature.getArgumentTypes()); diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java index a94fc56d6773d..0a4f3743198d7 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java @@ -50,6 +50,7 @@ import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.function.SqlFunctionId; +import com.facebook.presto.spi.function.SqlFunctionSupplier; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver; @@ -215,6 +216,12 @@ public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) return FunctionAndTypeManager.this.getFunctionMetadata(functionHandle); } + @Override + public SqlFunctionSupplier getSpecializedFunctionKey(Signature signature) + { + return FunctionAndTypeManager.this.getSpecializedFunctionKey(signature); + } + @Override public Collection listBuiltInFunctions() { @@ -257,7 +264,7 @@ public void loadFunctionNamespaceManager( requireNonNull(functionNamespaceManagerName, "functionNamespaceManagerName is null"); FunctionNamespaceManagerFactory factory = functionNamespaceManagerFactories.get(functionNamespaceManagerName); checkState(factory != null, "No factory for function namespace manager %s", functionNamespaceManagerName); - FunctionNamespaceManager functionNamespaceManager = factory.create(catalogName, properties, new FunctionNamespaceManagerContext(this, nodeManager)); + FunctionNamespaceManager functionNamespaceManager = factory.create(catalogName, properties, new FunctionNamespaceManagerContext(this, nodeManager, Optional.of(this))); functionNamespaceManager.setBlockEncodingSerde(blockEncodingSerde); transactionManager.registerFunctionNamespaceManager(catalogName, functionNamespaceManager); if (functionNamespaceManagers.putIfAbsent(catalogName, functionNamespaceManager) != null) { @@ -618,17 +625,19 @@ public boolean nullIfSpecialFormEnabled() public FunctionHandle lookupFunction(String name, List parameterTypes) { QualifiedObjectName functionName = qualifyObjectName(QualifiedName.of(name)); + Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionName.getCatalogSchemaName()); + checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for function '%s'", functionName); if (parameterTypes.stream().noneMatch(TypeSignatureProvider::hasDependency)) { return lookupCachedFunction(functionName, parameterTypes); } - Collection candidates = builtInTypeAndFunctionNamespaceManager.getFunctions(Optional.empty(), functionName); + Collection candidates = functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); Optional match = functionSignatureMatcher.match(candidates, parameterTypes, false); if (!match.isPresent()) { throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage(functionName, parameterTypes, candidates)); } - return builtInTypeAndFunctionNamespaceManager.getFunctionHandle(Optional.empty(), match.get()); + return functionNamespaceManager.get().getFunctionHandle(Optional.empty(), match.get()); } public FunctionHandle lookupCast(CastType castType, Type fromType, Type toType) @@ -746,6 +755,17 @@ private Optional> getServingFunc return Optional.ofNullable(functionNamespaceManagers.get(typeSignatureBase.getTypeName().getCatalogName())); } + @Override + @SuppressWarnings("unchecked") + public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) + { + QualifiedObjectName functionName = signature.getName(); + Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionName.getCatalogSchemaName()); + checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for signature '%s'", signature.getName()); + Collection candidates = (Collection) functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); + return builtInTypeAndFunctionNamespaceManager.doGetSpecializedFunctionKey(signature, candidates); + } + private static class FunctionResolutionCacheKey { private final QualifiedObjectName functionName; diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/SpecializedFunctionKey.java b/presto-main/src/main/java/com/facebook/presto/metadata/SpecializedFunctionKey.java index 62e052af08af4..f81018a760670 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/SpecializedFunctionKey.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/SpecializedFunctionKey.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.SqlFunctionSupplier; import java.util.Objects; @@ -22,6 +23,7 @@ import static java.util.Objects.requireNonNull; public class SpecializedFunctionKey + implements SqlFunctionSupplier { private final SqlFunction function; private final BoundVariables boundVariables; diff --git a/presto-main/src/main/java/com/facebook/presto/server/NativeFunctionNamespaceManagerProvider.java b/presto-main/src/main/java/com/facebook/presto/server/NativeFunctionNamespaceManagerProvider.java index 0dbf4c58d2beb..43558ba79b049 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/NativeFunctionNamespaceManagerProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/server/NativeFunctionNamespaceManagerProvider.java @@ -13,40 +13,24 @@ */ package com.facebook.presto.server; -import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.NodeManager; import com.google.inject.Inject; -import java.util.Map; -import java.util.Optional; - import static java.util.Objects.requireNonNull; public class NativeFunctionNamespaceManagerProvider { private final NodeManager nodeManager; - private final Metadata metadata; - @Inject public NativeFunctionNamespaceManagerProvider( - NodeManager nodeManager, - Metadata metadata) + NodeManager nodeManager) { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); - this.metadata = requireNonNull(metadata, "metadata is null"); } public NodeManager getNodeManager() { return nodeManager; } - - public void loadNativeFunctionNamespaceManager( - String nativeFunctionNamespaceManagerName, - String catalogName, - Map properties) - { - metadata.getFunctionAndTypeManager().loadFunctionNamespaceManager(nativeFunctionNamespaceManagerName, catalogName, properties, Optional.ofNullable(getNodeManager())); - } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index e47eda7f38433..4a2394a8d7a10 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -106,6 +106,7 @@ import com.facebook.presto.metadata.StaticFunctionNamespaceStore; import com.facebook.presto.metadata.StaticFunctionNamespaceStoreConfig; import com.facebook.presto.metadata.TablePropertyManager; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.ExchangeClientConfig; import com.facebook.presto.operator.ExchangeClientFactory; import com.facebook.presto.operator.ExchangeClientSupplier; @@ -782,8 +783,8 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon newOptionalBinder(binder, NodeStatusService.class); binder.bind(NodeStatusNotificationManager.class).in(Scopes.SINGLETON); //native-function-namespace-manager binder - binder.bind(NativeFunctionNamespaceManagerProvider.class); - binder.bind(NodeManager.class).to(ConnectorAwareNodeManager.class).in(Scopes.SINGLETON); + binder.bind(NodeManager.class).to(PluginNodeManager.class).in(Scopes.SINGLETON); + binder.bind(NativeFunctionNamespaceManagerProvider.class).in(Scopes.SINGLETON); } @Provides diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 3df1c3a68f2ef..f04b286e8edd7 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -173,7 +173,6 @@ public class TestingPrestoServer private final boolean coordinator; private final boolean nodeSchedulerIncludeCoordinator; private final ServerInfoResource serverInfoResource; - private final NativeFunctionNamespaceManagerProvider nativeFunctionNamespaceManagerProvider; private final ResourceManagerClusterStateProvider clusterStateProvider; diff --git a/presto-native-execution/presto_cpp/main/types/FunctionMetadata.cpp b/presto-native-execution/presto_cpp/main/types/FunctionMetadata.cpp new file mode 100644 index 0000000000000..da6036e4cfccf --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/FunctionMetadata.cpp @@ -0,0 +1,253 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/types/FunctionMetadata.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; + +namespace facebook::presto { + +namespace { + +constexpr char const* kDefaultSchema = "default"; + +// The keys in velox function maps are of the format +// catalog.schema.function_name. This utility function extracts the +// function_name and the function visibility from this string. + +// Note: A function is considered hidden if it is not found within a +// function namespace. +std::pair +getFunctionNameAndVisibility(const std::string& registeredFunctionName) { + std::vector pieces; + folly::split('.', registeredFunctionName, pieces, true); + protocol::SqlFunctionVisibility functionVisibility = pieces.size() == 1 + ? protocol::SqlFunctionVisibility::HIDDEN + : protocol::SqlFunctionVisibility::PUBLIC; + return std::make_pair(pieces.back(), functionVisibility); +} + +const protocol::AggregationFunctionMetadata getAggregationFunctionMetadata( + const AggregateFunctionSignature& aggregateFunctionSignature) { + protocol::AggregationFunctionMetadata aggregationFunctionMetadata; + aggregationFunctionMetadata.intermediateType = + aggregateFunctionSignature.intermediateType().toString(); + /// TODO: Set to true for now. To be read from an existing mapping of + /// aggregate to order sensitivity which needs to be added. + aggregationFunctionMetadata.isOrderSensitive = true; + return aggregationFunctionMetadata; +} + +const protocol::RoutineCharacteristics getRoutineCharacteristics( + const FunctionSignature& functionSignature, + const std::string& functionName, + const protocol::FunctionKind& functionKind) { + protocol::Determinism determinism; + protocol::NullCallClause nullCallClause; + if (functionKind == protocol::FunctionKind::SCALAR) { + auto functionMetadata = + getFunctionMetadata(functionName, functionSignature); + determinism = functionMetadata.isDeterministic + ? protocol::Determinism::DETERMINISTIC + : protocol::Determinism::NOT_DETERMINISTIC; + nullCallClause = functionMetadata.isDefaultNullBehavior + ? protocol::NullCallClause::RETURNS_NULL_ON_NULL_INPUT + : protocol::NullCallClause::CALLED_ON_NULL_INPUT; + } else { + // Default metadata values of DETERMINISTIC and CALLED_ON_NULL_INPUT for + // non-scalar functions. + determinism = protocol::Determinism::DETERMINISTIC; + nullCallClause = protocol::NullCallClause::CALLED_ON_NULL_INPUT; + } + + protocol::RoutineCharacteristics routineCharacteristics; + routineCharacteristics.language = + std::make_shared(protocol::Language({"CPP"})); + routineCharacteristics.determinism = + std::make_shared(determinism); + routineCharacteristics.nullCallClause = + std::make_shared(nullCallClause); + return routineCharacteristics; +} + +const std::vector getTypeVariableConstraints( + const FunctionSignature& functionSignature) { + std::vector typeVariableConstraints; + const auto functionVariables = functionSignature.variables(); + for (const auto& [name, signature] : functionVariables) { + if (signature.isTypeParameter()) { + protocol::TypeVariableConstraint typeVariableConstraint; + typeVariableConstraint.name = signature.name(); + typeVariableConstraint.orderableRequired = signature.orderableTypesOnly(); + typeVariableConstraint.comparableRequired = + signature.comparableTypesOnly(); + typeVariableConstraints.emplace_back(typeVariableConstraint); + } + } + return typeVariableConstraints; +} + +void updateFunctionMetadata( + const std::string& functionName, + const FunctionSignature& functionSignature, + protocol::JsonBasedUdfFunctionMetadata& jsonBasedUdfFunctionMetadata) { + jsonBasedUdfFunctionMetadata.docString = functionName; + jsonBasedUdfFunctionMetadata.schema = kDefaultSchema; + jsonBasedUdfFunctionMetadata.outputType = + functionSignature.returnType().toString(); + jsonBasedUdfFunctionMetadata.variableArity = + std::make_shared(functionSignature.variableArity()); + const std::vector argumentTypes = + functionSignature.argumentTypes(); + std::vector paramTypes; + for (const auto& argumentType : argumentTypes) { + paramTypes.emplace_back(argumentType.toString()); + } + jsonBasedUdfFunctionMetadata.paramTypes = paramTypes; + jsonBasedUdfFunctionMetadata.typeVariableConstraints = + std::make_shared>( + getTypeVariableConstraints(functionSignature)); +} + +const std::vector +getAggregateFunctionMetadata( + const std::string& functionName, + const std::vector& + aggregateFunctionSignatures) { + std::vector + jsonBasedUdfFunctionMetadataList; + jsonBasedUdfFunctionMetadataList.reserve(aggregateFunctionSignatures.size()); + const protocol::FunctionKind functionKind = protocol::FunctionKind::AGGREGATE; + + for (const auto& aggregateFunctionSignature : aggregateFunctionSignatures) { + protocol::JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetadata; + jsonBasedUdfFunctionMetadata.functionKind = functionKind; + jsonBasedUdfFunctionMetadata.routineCharacteristics = + getRoutineCharacteristics( + *aggregateFunctionSignature, functionName, functionKind); + jsonBasedUdfFunctionMetadata.aggregateMetadata = + std::make_shared( + getAggregationFunctionMetadata(*aggregateFunctionSignature)); + + updateFunctionMetadata( + functionName, + *aggregateFunctionSignature, + jsonBasedUdfFunctionMetadata); + jsonBasedUdfFunctionMetadataList.emplace_back(jsonBasedUdfFunctionMetadata); + } + return jsonBasedUdfFunctionMetadataList; +} + +const std::vector +getScalarFunctionMetadata( + const std::string& functionName, + const std::vector& functionSignatures) { + std::vector + jsonBasedUdfFunctionMetadataList; + jsonBasedUdfFunctionMetadataList.reserve(functionSignatures.size()); + const protocol::FunctionKind functionKind = protocol::FunctionKind::SCALAR; + + for (const auto& functionSignature : functionSignatures) { + protocol::JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetadata; + jsonBasedUdfFunctionMetadata.functionKind = functionKind; + jsonBasedUdfFunctionMetadata.routineCharacteristics = + getRoutineCharacteristics( + *functionSignature, functionName, functionKind); + + updateFunctionMetadata( + functionName, *functionSignature, jsonBasedUdfFunctionMetadata); + jsonBasedUdfFunctionMetadataList.emplace_back(jsonBasedUdfFunctionMetadata); + } + return jsonBasedUdfFunctionMetadataList; +} + +const std::vector +getWindowFunctionMetadata( + const std::string& functionName, + const std::vector& windowFunctionSignatures) { + std::vector + jsonBasedUdfFunctionMetadataList; + jsonBasedUdfFunctionMetadataList.reserve(windowFunctionSignatures.size()); + const protocol::FunctionKind functionKind = protocol::FunctionKind::WINDOW; + + for (const auto& windowFunctionSignature : windowFunctionSignatures) { + protocol::JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetadata; + jsonBasedUdfFunctionMetadata.functionKind = functionKind; + jsonBasedUdfFunctionMetadata.routineCharacteristics = + getRoutineCharacteristics( + *windowFunctionSignature, functionName, functionKind); + + updateFunctionMetadata( + functionName, *windowFunctionSignature, jsonBasedUdfFunctionMetadata); + jsonBasedUdfFunctionMetadataList.emplace_back(jsonBasedUdfFunctionMetadata); + } + return jsonBasedUdfFunctionMetadataList; +} + +const std::vector getFunctionMetadata( + const std::string& functionName) { + if (auto aggregateFunctionSignatures = + getAggregateFunctionSignatures(functionName)) { + return getAggregateFunctionMetadata( + functionName, aggregateFunctionSignatures.value()); + } else if ( + auto windowFunctionSignatures = + getWindowFunctionSignatures(functionName)) { + return getWindowFunctionMetadata( + functionName, windowFunctionSignatures.value()); + } else { + auto functionSignatures = getFunctionSignatures(); + if (functionSignatures.find(functionName) != functionSignatures.end()) { + return getScalarFunctionMetadata( + functionName, functionSignatures.at(functionName)); + } + } + VELOX_UNREACHABLE("Function kind cannot be determined"); +} + +} // namespace + +void getJsonMetadataForFunction( + const std::string& registeredFunctionName, + nlohmann::json& jsonMetadataList) { + auto functionMetadataList = getFunctionMetadata(registeredFunctionName); + auto [functionName, functionVisibility] = + getFunctionNameAndVisibility(registeredFunctionName); + for (auto& functionMetadata : functionMetadataList) { + functionMetadata.functionVisibility = + std::make_shared(functionVisibility); + jsonMetadataList[functionName].emplace_back(functionMetadata); + } +} + +json getJsonFunctionMetadata() { + auto registeredFunctionNames = functions::getSortedAggregateNames(); + auto scalarFunctionNames = functions::getSortedScalarNames(); + for (const auto& scalarFunction : scalarFunctionNames) { + registeredFunctionNames.emplace_back(scalarFunction); + } + auto windowFunctionNames = functions::getSortedWindowNames(); + for (const auto& windowFunction : windowFunctionNames) { + registeredFunctionNames.emplace_back(windowFunction); + } + + nlohmann::json j; + for (const auto& registeredFunctionName : registeredFunctionNames) { + getJsonMetadataForFunction(registeredFunctionName, j); + } + return j; +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/types/tests/FunctionMetadataTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/FunctionMetadataTest.cpp new file mode 100644 index 0000000000000..e2cecd3e51319 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/FunctionMetadataTest.cpp @@ -0,0 +1,235 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" + +#include +#include "presto_cpp/main/types/FunctionMetadata.h" + +using namespace facebook::velox; +using namespace facebook::presto; + +class FunctionMetadataTest : public ::testing::Test {}; + +TEST(FunctionMetadataTest, testCount) { + aggregate::prestosql::registerAllAggregateFunctions(""); + std::string functionName = "count"; + json jsonMetadata; + getJsonMetadataForFunction(functionName, jsonMetadata); + json j = jsonMetadata.at(functionName); + const auto numSignatures = 2; + EXPECT_EQ(j.size(), numSignatures); + + for (auto i = 0; i < numSignatures; i++) { + json jsonAtIdx = j.at(i); + EXPECT_EQ(jsonAtIdx.at("functionKind"), protocol::FunctionKind::AGGREGATE); + EXPECT_EQ(jsonAtIdx.at("docString"), functionName); + EXPECT_EQ(jsonAtIdx.at("schema"), "default"); + EXPECT_EQ(jsonAtIdx.at("outputType"), "bigint"); + auto paramTypes = jsonAtIdx.at("paramTypes"); + VELOX_CHECK( + ((paramTypes == std::vector({})) || + (paramTypes == std::vector({"T"})))); + + auto routineCharacteristics = jsonAtIdx.at("routineCharacteristics"); + EXPECT_EQ( + routineCharacteristics.at("language"), protocol::Language({"CPP"})); + EXPECT_EQ( + routineCharacteristics.at("determinism"), + protocol::Determinism::DETERMINISTIC); + EXPECT_EQ( + routineCharacteristics.at("nullCallClause"), + protocol::NullCallClause::CALLED_ON_NULL_INPUT); + + auto aggregateMetadata = jsonAtIdx.at("aggregateMetadata"); + EXPECT_EQ(aggregateMetadata.at("intermediateType"), "bigint"); + EXPECT_EQ(aggregateMetadata.at("isOrderSensitive"), true); + } +} + +TEST(FunctionMetadataTest, testSum) { + aggregate::prestosql::registerAllAggregateFunctions(""); + std::string functionName = "sum"; + json jsonMetadata; + getJsonMetadataForFunction(functionName, jsonMetadata); + json j = jsonMetadata.at(functionName); + const auto numSignatures = 7; + EXPECT_EQ(j.size(), numSignatures); + + for (auto i = 0; i < numSignatures; i++) { + json jsonAtIdx = j.at(i); + EXPECT_EQ(jsonAtIdx.at("functionKind"), protocol::FunctionKind::AGGREGATE); + EXPECT_EQ(jsonAtIdx.at("docString"), functionName); + EXPECT_EQ(jsonAtIdx.at("schema"), "default"); + auto outputType = jsonAtIdx.at("outputType"); + VELOX_CHECK( + ((outputType == "real") || (outputType == "double") || + (outputType == "DECIMAL(38,a_scale)") || (outputType == "bigint"))); + auto paramTypes = jsonAtIdx.at("paramTypes"); + VELOX_CHECK( + ((paramTypes == std::vector({"real"})) || + (paramTypes == std::vector({"double"})) || + (paramTypes == + std::vector({"DECIMAL(a_precision,a_scale)"})) || + (paramTypes == std::vector({"tinyint"})) || + (paramTypes == std::vector({"smallint"})) || + (paramTypes == std::vector({"integer"})) || + (paramTypes == std::vector({"bigint"})))); + + auto routineCharacteristics = jsonAtIdx.at("routineCharacteristics"); + EXPECT_EQ( + routineCharacteristics.at("language"), protocol::Language({"CPP"})); + EXPECT_EQ( + routineCharacteristics.at("determinism"), + protocol::Determinism::DETERMINISTIC); + EXPECT_EQ( + routineCharacteristics.at("nullCallClause"), + protocol::NullCallClause::CALLED_ON_NULL_INPUT); + + auto aggregateMetadata = jsonAtIdx.at("aggregateMetadata"); + auto intermediateType = aggregateMetadata.at("intermediateType"); + VELOX_CHECK( + (intermediateType == "double") || (intermediateType == "VARBINARY") || + (intermediateType == "bigint")); + EXPECT_EQ(aggregateMetadata.at("isOrderSensitive"), true); + } +} + +TEST(FunctionMetadataTest, testRank) { + window::prestosql::registerAllWindowFunctions(""); + std::string functionName = "rank"; + json jsonMetadata; + getJsonMetadataForFunction(functionName, jsonMetadata); + json j = jsonMetadata.at(functionName); + const auto numSignatures = 1; + EXPECT_EQ(j.size(), numSignatures); + + for (auto i = 0; i < numSignatures; i++) { + json jsonAtIdx = j.at(i); + EXPECT_EQ(jsonAtIdx.at("functionKind"), protocol::FunctionKind::WINDOW); + EXPECT_EQ(jsonAtIdx.at("docString"), functionName); + EXPECT_EQ(jsonAtIdx.at("schema"), "default"); + auto outputType = jsonAtIdx.at("outputType"); + VELOX_CHECK(((outputType == "integer") || (outputType == "bigint"))); + VELOX_CHECK(jsonAtIdx.at("paramTypes") == std::vector({})); + + auto routineCharacteristics = jsonAtIdx.at("routineCharacteristics"); + EXPECT_EQ( + routineCharacteristics.at("language"), protocol::Language({"CPP"})); + EXPECT_EQ( + routineCharacteristics.at("determinism"), + protocol::Determinism::DETERMINISTIC); + EXPECT_EQ( + routineCharacteristics.at("nullCallClause"), + protocol::NullCallClause::CALLED_ON_NULL_INPUT); + } +} + +TEST(FunctionMetadataTest, testRadians) { + functions::prestosql::registerArithmeticFunctions(""); + std::string functionName = "radians"; + json jsonMetadata; + getJsonMetadataForFunction(functionName, jsonMetadata); + json j = jsonMetadata.at(functionName); + const auto numSignatures = 1; + EXPECT_EQ(j.size(), numSignatures); + + for (auto i = 0; i < numSignatures; i++) { + json jsonAtIdx = j.at(i); + EXPECT_EQ(jsonAtIdx.at("functionKind"), protocol::FunctionKind::SCALAR); + EXPECT_EQ(jsonAtIdx.at("docString"), functionName); + EXPECT_EQ(jsonAtIdx.at("schema"), "default"); + VELOX_CHECK(jsonAtIdx.at("outputType") == "double"); + VELOX_CHECK( + jsonAtIdx.at("paramTypes") == std::vector({"double"})); + + auto routineCharacteristics = jsonAtIdx.at("routineCharacteristics"); + EXPECT_EQ( + routineCharacteristics.at("language"), protocol::Language({"CPP"})); + EXPECT_EQ( + routineCharacteristics.at("determinism"), + protocol::Determinism::DETERMINISTIC); + EXPECT_EQ( + routineCharacteristics.at("nullCallClause"), + protocol::NullCallClause::RETURNS_NULL_ON_NULL_INPUT); + } +} + +TEST(FunctionMetadataTest, testFunctionVisibility) { + functions::prestosql::registerGeneralFunctions(""); + std::string functionName = "in"; + json jsonMetadata; + getJsonMetadataForFunction(functionName, jsonMetadata); + json j = jsonMetadata.at(functionName); + const auto numSignatures = 3; + EXPECT_EQ(j.size(), numSignatures); + + for (auto i = 0; i < numSignatures; i++) { + json jsonAtIdx = j.at(i); + EXPECT_EQ(jsonAtIdx.at("docString"), functionName); + EXPECT_EQ(jsonAtIdx.at("functionVisibility"), "HIDDEN"); + } +} + +TEST(FunctionMetadataTest, testVariableArity) { + functions::prestosql::registerAllScalarFunctions(""); + std::string functionName = "array_constructor"; + json jsonMetadata; + getJsonMetadataForFunction(functionName, jsonMetadata); + json j = jsonMetadata.at(functionName); + const auto numSignatures = 2; + EXPECT_EQ(j.size(), numSignatures); + + for (auto i = 0; i < numSignatures; i++) { + json jsonAtIdx = j.at(i); + EXPECT_EQ(jsonAtIdx.at("docString"), functionName); + // The `array_constructor` function has two signatures: `variableArity` is + // set to false for no parameters, and `variableArity` is set to `true` for + // accepting a variable number of arguments. + if (jsonAtIdx.at("paramTypes").size() > 0) { + EXPECT_EQ(jsonAtIdx.at("variableArity"), true); + } else { + EXPECT_EQ(jsonAtIdx.at("variableArity"), false); + } + } +} + +TEST(FunctionMetadataTest, testTypeVariableConstraints) { + functions::prestosql::registerAllScalarFunctions(""); + std::string functionName = "array_constructor"; + json jsonMetadata; + getJsonMetadataForFunction(functionName, jsonMetadata); + json j = jsonMetadata.at(functionName); + const auto numSignatures = 2; + EXPECT_EQ(j.size(), numSignatures); + + for (auto i = 0; i < numSignatures; i++) { + json jsonAtIdx = j.at(i); + EXPECT_EQ(jsonAtIdx.at("docString"), functionName); + auto typeVariableConstraints = jsonAtIdx.at("typeVariableConstraints"); + if (typeVariableConstraints.size() > 0) { + EXPECT_EQ(typeVariableConstraints[0].at("comparableRequired"), false); + EXPECT_EQ(typeVariableConstraints[0].at("name"), "T"); + EXPECT_EQ(typeVariableConstraints[0].at("orderableRequired"), false); + EXPECT_EQ( + typeVariableConstraints[0].at("nonDecimalNumericRequired"), false); + EXPECT_EQ(typeVariableConstraints[0].at("variadicBound"), ""); + } else { + EXPECT_EQ( + typeVariableConstraints, + std::vector({})); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadataManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadataManager.java index 04231b2c15557..d163ef952e64c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadataManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadataManager.java @@ -16,4 +16,6 @@ public interface FunctionMetadataManager { FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle); + + SqlFunctionSupplier getSpecializedFunctionKey(Signature signature); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionNamespaceManagerContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionNamespaceManagerContext.java index 9736a7bff6589..c3b1fc8f31ecd 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionNamespaceManagerContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionNamespaceManagerContext.java @@ -24,11 +24,14 @@ public class FunctionNamespaceManagerContext { private final TypeManager typeManager; private final Optional nodeManager; + private final Optional functionMetadataManager; - public FunctionNamespaceManagerContext(TypeManager typeManager, Optional nodeManager) + public FunctionNamespaceManagerContext(TypeManager typeManager, Optional nodeManager, + Optional functionMetadataManager) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); } public TypeManager getTypeManager() @@ -40,4 +43,9 @@ public NodeManager getNodeManager() { return nodeManager.orElseThrow(() -> new IllegalArgumentException("nodeManager is not present")); } + + public FunctionMetadataManager getFunctionMetadataManager() + { + return functionMetadataManager.orElseThrow(() -> new IllegalArgumentException("functionMetadataManager is not present")); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionSupplier.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionSupplier.java new file mode 100644 index 0000000000000..fe3f8988e0548 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionSupplier.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +public interface SqlFunctionSupplier +{ + SqlFunction getFunction(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java index 17e720a580277..bb54437482e1b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java @@ -145,6 +145,7 @@ public SqlInvokedFunction withVersion(String version) return new SqlInvokedFunction( signature.getName(), parameters, + signature.getTypeVariableConstraints(), signature.getReturnType(), description, routineCharacteristics,