Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a native function namespace manager #23358

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.spi.function.AggregationFunctionMetadata;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
Expand All @@ -30,9 +31,6 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
* The function metadata provided by the Json file to the {@link JsonFileBasedFunctionNamespaceManager}.
pdabre12 marked this conversation as resolved.
Show resolved Hide resolved
*/
public class JsonBasedUdfFunctionMetadata
{
/**
Expand Down Expand Up @@ -64,6 +62,14 @@ public class JsonBasedUdfFunctionMetadata
* Optional Aggregate-specific metadata (required for aggregation functions)
*/
private final Optional<AggregationFunctionMetadata> aggregateMetadata;
/**
* Indicates whether the function accepts a variable number of arguments.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think variableArity means that the function can accept a variable number of arguments.

*/
private final boolean variableArity;
/**
* Optional list of the typeVariableConstraints.
*/
private final Optional<List<TypeVariableConstraint>> typeVariableConstraints;

@JsonCreator
public JsonBasedUdfFunctionMetadata(
Expand All @@ -72,19 +78,23 @@ public JsonBasedUdfFunctionMetadata(
@JsonProperty("outputType") TypeSignature outputType,
@JsonProperty("paramTypes") List<TypeSignature> paramTypes,
@JsonProperty("schema") String schema,
@JsonProperty("variableArity") boolean variableArity,
@JsonProperty("routineCharacteristics") RoutineCharacteristics routineCharacteristics,
@JsonProperty("aggregateMetadata") Optional<AggregationFunctionMetadata> aggregateMetadata)
@JsonProperty("aggregateMetadata") Optional<AggregationFunctionMetadata> aggregateMetadata,
@JsonProperty("typeVariableConstraints") Optional<List<TypeVariableConstraint>> typeVariableConstraints)
{
this.docString = requireNonNull(docString, "docString is null");
this.functionKind = requireNonNull(functionKind, "functionKind is null");
this.outputType = requireNonNull(outputType, "outputType is null");
this.paramTypes = ImmutableList.copyOf(requireNonNull(paramTypes, "paramTypes is null"));
this.schema = requireNonNull(schema, "schema is null");
this.variableArity = variableArity;
this.routineCharacteristics = requireNonNull(routineCharacteristics, "routineCharacteristics is null");
this.aggregateMetadata = requireNonNull(aggregateMetadata, "aggregateMetadata is null");
checkArgument(
(functionKind == AGGREGATE && aggregateMetadata.isPresent()) || (functionKind != AGGREGATE && !aggregateMetadata.isPresent()),
"aggregateMetadata must be present for aggregation functions and absent otherwise");
this.typeVariableConstraints = requireNonNull(typeVariableConstraints, "typeVariableConstraints is null");
}

public String getDocString()
Expand Down Expand Up @@ -117,6 +127,11 @@ public String getSchema()
return schema;
}

/**
* Returns the variableArity flag that was supplied to the constructor.
* NOTE(review): presumably true when the function accepts a variable number of arguments; confirm.
*/
public boolean getVariableArity()
{
return variableArity;
}

public RoutineCharacteristics getRoutineCharacteristics()
{
return routineCharacteristics;
Expand All @@ -126,4 +141,9 @@ public Optional<AggregationFunctionMetadata> getAggregateMetadata()
{
return aggregateMetadata;
}

/**
* Returns the optional list of type variable constraints that was supplied to the constructor.
*/
public Optional<List<TypeVariableConstraint>> getTypeVariableConstraints()
{
return typeVariableConstraints;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ const protocol::AggregationFunctionMetadata getAggregationFunctionMetadata(
const std::string& name,
const AggregateFunctionSignature& signature) {
protocol::AggregationFunctionMetadata metadata;
metadata.intermediateType = signature.intermediateType().toString();
metadata.intermediateType =
boost::algorithm::to_lower_copy(signature.intermediateType().toString());
metadata.isOrderSensitive =
getAggregateFunctionEntry(name)->metadata.orderSensitive;
return metadata;
Expand Down Expand Up @@ -140,6 +141,24 @@ const protocol::RoutineCharacteristics getRoutineCharacteristics(
return routineCharacteristics;
}

/// Builds the protocol-level type variable constraints for a function
/// signature.
///
/// Only variables for which isTypeParameter() returns true are emitted; all
/// other variables are skipped. For each emitted constraint the variable name
/// is lower-cased and the orderable/comparable flags are copied from the
/// variable.
///
/// @param functionSignature Signature whose variables() are inspected.
/// @return One TypeVariableConstraint per type-parameter variable, in the
/// iteration order of functionSignature.variables().
const std::vector<protocol::TypeVariableConstraint> getTypeVariableConstraints(
    const FunctionSignature& functionSignature) {
  std::vector<protocol::TypeVariableConstraint> typeVariableConstraints;
  // Bind by reference: iterating does not need a private copy of the
  // variables container.
  const auto& functionVariables = functionSignature.variables();
  for (const auto& [name, signature] : functionVariables) {
    if (!signature.isTypeParameter()) {
      continue;
    }
    // Construct the element in place and fill it, rather than building a
    // local object and copying it into the vector.
    auto& typeVariableConstraint = typeVariableConstraints.emplace_back();
    typeVariableConstraint.name =
        boost::algorithm::to_lower_copy(signature.name());
    typeVariableConstraint.orderableRequired = signature.orderableTypesOnly();
    typeVariableConstraint.comparableRequired =
        signature.comparableTypesOnly();
  }
  return typeVariableConstraints;
}

std::optional<protocol::JsonBasedUdfFunctionMetadata> buildFunctionMetadata(
const std::string& name,
const std::string& schema,
Expand All @@ -152,19 +171,25 @@ std::optional<protocol::JsonBasedUdfFunctionMetadata> buildFunctionMetadata(
if (!isValidPrestoType(signature.returnType())) {
return std::nullopt;
}
metadata.outputType = signature.returnType().toString();
metadata.outputType =
boost::algorithm::to_lower_copy(signature.returnType().toString());

const auto& argumentTypes = signature.argumentTypes();
std::vector<std::string> paramTypes(argumentTypes.size());
for (auto i = 0; i < argumentTypes.size(); i++) {
if (!isValidPrestoType(argumentTypes.at(i))) {
return std::nullopt;
}
paramTypes[i] = argumentTypes.at(i).toString();
paramTypes[i] =
boost::algorithm::to_lower_copy(argumentTypes.at(i).toString());
}
metadata.paramTypes = paramTypes;
metadata.schema = schema;
metadata.variableArity = signature.variableArity();
metadata.routineCharacteristics = getRoutineCharacteristics(name, kind);
metadata.typeVariableConstraints =
std::make_shared<std::vector<protocol::TypeVariableConstraint>>(
getTypeVariableConstraints(signature));

if (aggregateSignature) {
metadata.aggregateMetadata =
Expand Down Expand Up @@ -199,8 +224,22 @@ json buildAggregateMetadata(
getWindowFunctionSignatures(name).has_value(),
"Aggregate function {} not registered as a window function",
name);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: understand this better. Is there any side effect for the worker? Will it be able to process an aggregate as a window function?

Copy link
Contributor Author

@pdabre12 pdabre12 Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about the side effect for the worker.
It doesn't seem to have a problem processing aggregates as window functions.
Locally, I extended all the test cases from AbstractTestNativeWindowQueries, and all of them are passing.

CC : @aditi-pandit

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unsure of the context here... but all aggregate functions can be used as window functions in SQL queries. With this change, it seems like you are not building function metadata for aggregate functions as window functions from the native sidecar. So only pure window functions like rank, nth_value, etc. will appear as window functions from the sidecar. Velox expects registration of each aggregate function as a window function in its registry. Does that not hold on the Presto Java side?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for Presto, I think any aggregate function can be used as a window function, and we don't need to register it separately. Why doesn't this change break things for Prestissimo window functions if we aren't registering them anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we are still registering prestissimo window functions, one such example is cume_dist which shows up in the SHOW FUNCTIONS call.

presto:tpch> show functions like 'cume_dist';
 Function  | Return Type | Argument Types | Function Type | Deterministic |       Description        | Variable Arity | Built In | Temporary | Language 
-----------+-------------+----------------+---------------+---------------+--------------------------+----------------+----------+-----------+----------
 cume_dist | double      |                | window        | true          | native.default.cume_dist | false          | false    | false     | cpp 

We aren't registering aggregate functions as window functions, so we are not building the function metadata for those. I think this works similar to how aggregate functions used as window functions work in Presto.

Whenever a SQL query uses an aggregate function as a window function, we still resolve the window function metadata (which is that of an aggregate function) the same way we would for the aggregate function itself; that's why the queries don't fail.

// The functions returned by this endpoint are stored as SqlInvokedFunction
// objects, with SqlFunctionId serving as the primary key. SqlFunctionId is
// derived from both the functionName and argumentTypes parameters. Returning
// the same function twice—once as an aggregate function and once as a window
// function—introduces ambiguity, as functionKind is not a component of
// SqlFunctionId. For any aggregate function utilized as a window function,
// the function’s metadata can be obtained from the associated aggregate
// function implementation for further processing. For additional information,
// refer to the following:
// - https://github.com/prestodb/presto/blob/master/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionId.java
// - https://github.com/prestodb/presto/blob/master/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java

const std::vector<protocol::FunctionKind> kinds = {
protocol::FunctionKind::AGGREGATE, protocol::FunctionKind::WINDOW};
protocol::FunctionKind::AGGREGATE};
json j = json::array();
json tj;
for (const auto& kind : kinds) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class FunctionMetadataTest : public ::testing::Test {
};

TEST_F(FunctionMetadataTest, approxMostFrequent) {
testFunction("approx_most_frequent", "ApproxMostFrequent.json", 12);
testFunction("approx_most_frequent", "ApproxMostFrequent.json", 6);
}

TEST_F(FunctionMetadataTest, arrayFrequency) {
Expand All @@ -73,13 +73,17 @@ TEST_F(FunctionMetadataTest, combinations) {
}

TEST_F(FunctionMetadataTest, covarSamp) {
testFunction("covar_samp", "CovarSamp.json", 4);
testFunction("covar_samp", "CovarSamp.json", 2);
}

// Runs testFunction for "element_at" with ElementAt.json and 3.
// NOTE(review): presumably compares the generated metadata with the JSON
// file and expects 3 signatures; confirm against testFunction.
TEST_F(FunctionMetadataTest, elementAt) {
testFunction("element_at", "ElementAt.json", 3);
}

// Runs testFunction for "greatest" with Greatest.json and 13.
// NOTE(review): presumably compares the generated metadata with the JSON
// file and expects 13 signatures; confirm against testFunction.
TEST_F(FunctionMetadataTest, greatest) {
testFunction("greatest", "Greatest.json", 13);
}

// Runs testFunction for "lead" with Lead.json and 3.
// NOTE(review): presumably compares the generated metadata with the JSON
// file and expects 3 signatures; confirm against testFunction.
TEST_F(FunctionMetadataTest, lead) {
testFunction("lead", "Lead.json", 3);
}
Expand All @@ -89,17 +93,17 @@ TEST_F(FunctionMetadataTest, ntile) {
}

TEST_F(FunctionMetadataTest, setAgg) {
testFunction("set_agg", "SetAgg.json", 2);
testFunction("set_agg", "SetAgg.json", 1);
}

TEST_F(FunctionMetadataTest, stddevSamp) {
testFunction("stddev_samp", "StddevSamp.json", 10);
testFunction("stddev_samp", "StddevSamp.json", 5);
}

// Runs testFunction for "transform_keys" with TransformKeys.json and 1.
// NOTE(review): presumably compares the generated metadata with the JSON
// file and expects 1 signature; confirm against testFunction.
TEST_F(FunctionMetadataTest, transformKeys) {
testFunction("transform_keys", "TransformKeys.json", 1);
}

TEST_F(FunctionMetadataTest, variance) {
testFunction("variance", "Variance.json", 10);
testFunction("variance", "Variance.json", 5);
}
Loading