From 10bfc086db66ad5bf4c22f437d1105accdef97f8 Mon Sep 17 00:00:00 2001 From: Pramod Date: Tue, 10 Sep 2024 21:43:47 +0530 Subject: [PATCH] [native] Add expression optimization support in sidecar --- .../presto_cpp/main/CMakeLists.txt | 2 + .../presto_cpp/main/PrestoServer.cpp | 10 + .../presto_cpp/main/PrestoServer.h | 2 + .../presto_cpp/main/expression/CMakeLists.txt | 32 + .../expression/RowExpressionOptimizer.cpp | 709 ++++++++++++++++++ .../main/expression/RowExpressionOptimizer.h | 131 ++++ .../main/expression/tests/CMakeLists.txt | 26 + .../tests/RowExpressionOptimizerTest.cpp | 187 +++++ .../tests/data/SimpleExpressionsExpected.json | 17 + .../tests/data/SimpleExpressionsInput.json | 278 +++++++ .../tests/data/SpecialFormExpected.json | 17 + .../tests/data/SpecialFormInput.json | 193 +++++ .../main/types/PrestoToVeloxExpr.cpp | 35 +- .../presto_cpp/main/types/PrestoToVeloxExpr.h | 24 + 14 files changed, 1639 insertions(+), 24 deletions(-) create mode 100644 presto-native-execution/presto_cpp/main/expression/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.cpp create mode 100644 presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.h create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/RowExpressionOptimizerTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json create mode 100644 presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 
8cac276cb185b..0d1941a9ad6bc 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(operators) add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) +add_subdirectory(expression) add_subdirectory(thrift) add_library( @@ -50,6 +51,7 @@ target_link_libraries( presto_function_metadata presto_http presto_operators + presto_expression_optimizer velox_aggregates velox_caching velox_common_base diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index f1c77215b1773..d4e361f4e33fd 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -1475,6 +1475,16 @@ void PrestoServer::registerSidecarEndpoints() { proxygen::ResponseHandler* downstream) { http::sendOkResponse(downstream, getFunctionsMetadata()); }); + + rowExpressionOptimizer_ = + std::make_unique(); + httpServer_->registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + return rowExpressionOptimizer_->optimize(message, body, downstream); + }); } protocol::NodeStatus PrestoServer::fetchNodeStatus() { diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index 65c5944d75cfb..7c2ae83cc425e 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -25,6 +25,7 @@ #include "presto_cpp/main/PeriodicHeartbeatManager.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/PrestoServerOperations.h" +#include "presto_cpp/main/expression/RowExpressionOptimizer.h" #include "presto_cpp/main/types/VeloxPlanValidator.h" #include "velox/common/caching/AsyncDataCache.h" #include 
"velox/common/memory/MemoryAllocator.h" @@ -287,6 +288,7 @@ class PrestoServer { std::string address_; std::string nodeLocation_; folly::SSLContextPtr sslContext_; + std::unique_ptr rowExpressionOptimizer_; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/expression/CMakeLists.txt b/presto-native-execution/presto_cpp/main/expression/CMakeLists.txt new file mode 100644 index 0000000000000..14d7960c7ad09 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+add_library(presto_expression_optimizer RowExpressionOptimizer.cpp) + +target_link_libraries( + presto_expression_optimizer + presto_type_converter + presto_types + presto_protocol + presto_http + velox_coverage_util + velox_parse_expression + velox_parse_parser + velox_presto_serializer + velox_serialization + velox_type_parser + ${FOLLY_WITH_DEPENDENCIES}) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.cpp b/presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.cpp new file mode 100644 index 0000000000000..d5d219685abf0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.cpp @@ -0,0 +1,709 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/expression/RowExpressionOptimizer.h" +#include +#include "presto_cpp/main/common/Utils.h" +#include "velox/common/encode/Base64.h" +#include "velox/exec/ExchangeQueue.h" +#include "velox/expression/EvalCtx.h" +#include "velox/expression/Expr.h" +#include "velox/expression/ExprCompiler.h" +#include "velox/expression/FieldReference.h" +#include "velox/expression/VectorFunction.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace facebook::presto::expression { + +namespace { + +protocol::TypeSignature getTypeSignature(const TypePtr& type) { + std::string typeSignature; + if (type->parameters().empty()) { + typeSignature = type->toString(); + boost::algorithm::to_lower(typeSignature); + } else if (type->isDecimal()) { + typeSignature = type->toString(); + } else { + std::vector childTypes; + if (type->isRow()) { + typeSignature = "row("; + childTypes = asRowType(type)->children(); + } else if (type->isArray()) { + typeSignature = "array("; + childTypes = type->asArray().children(); + } else if (type->isMap()) { + typeSignature = "map("; + const auto mapType = type->asMap(); + childTypes = {mapType.keyType(), mapType.valueType()}; + } else { + VELOX_USER_FAIL("Invalid type {}", type->toString()); + } + + if (!childTypes.empty()) { + auto numChildren = childTypes.size(); + for (auto i = 0; i < numChildren - 1; i++) { + typeSignature += fmt::format("{},", getTypeSignature(childTypes[i])); + } + typeSignature += getTypeSignature(childTypes[numChildren - 1]); + } + typeSignature += ")"; + } + + return typeSignature; +} + +json toVariableReferenceExpression( + const std::shared_ptr& fieldReference) { + protocol::VariableReferenceExpression vexpr; + vexpr.name = fieldReference->name(); + vexpr._type = "variable"; + vexpr.type = getTypeSignature(fieldReference->type()); + json res; + protocol::to_json(res, vexpr); + + return res; +} + +bool isPrestoSpecialForm(const std::string& name) { + static const 
std::unordered_set kPrestoSpecialForms = { + "and", + "coalesce", + "if", + "in", + "is_null", + "or", + "switch", + "when", + "null_if"}; + return kPrestoSpecialForms.count(name) != 0; +} + +json::array_t getInputExpressions( + const std::vector>& body) { + std::ostringstream oss; + for (auto& buf : body) { + oss << std::string((const char*)buf->data(), buf->length()); + } + return json::parse(oss.str()); +} + +template +std::shared_ptr getConstantExpr( + const TypePtr& type, + const DecodedVector& decoded, + memory::MemoryPool* pool) { + std::shared_ptr constExpr = nullptr; + if constexpr ( + KIND == TypeKind::ROW || KIND == TypeKind::UNKNOWN || + KIND == TypeKind::ARRAY || KIND == TypeKind::MAP) { + VELOX_USER_FAIL("Invalid result type {}", type->toString()); + } else { + using T = typename TypeTraits::NativeType; + auto constVector = std::make_shared>( + pool, decoded.size(), decoded.isNullAt(0), type, decoded.valueAt(0)); + constExpr = std::make_shared(constVector); + } + return constExpr; +} +} // namespace + +// ValueBlock in ConstantExpression requires only the column from the serialized +// PrestoPage without the page header. +std::string RowExpressionConverter::getValueBlock(const VectorPtr& vector) { + std::ostringstream output; + serde_->serializeSingleColumn(vector, nullptr, pool_.get(), &output); + const auto serialized = output.str(); + const auto serializedSize = serialized.size(); + return encoding::Base64::encode(serialized.c_str(), serializedSize); +} + +std::shared_ptr +RowExpressionConverter::getConstantRowExpression( + const std::shared_ptr& constantExpr) { + protocol::ConstantExpression cexpr; + cexpr.type = getTypeSignature(constantExpr->type()); + cexpr.valueBlock.data = getValueBlock(constantExpr->value()); + return std::make_shared(cexpr); +} + +// TODO: Remove this once native plugin supports evaluation of current_user. 
+std::shared_ptr +RowExpressionConverter::getCurrentUser(const std::string& currentUser) { + protocol::ConstantExpression cexpr; + cexpr.type = getTypeSignature(VARCHAR()); + cexpr.valueBlock.data = getValueBlock( + BaseVector::createConstant(VARCHAR(), currentUser, 1, pool_.get())); + return std::make_shared(cexpr); +} + +json RowExpressionConverter::getRowConstructorSpecialForm( + const exec::ExprPtr& expr, + const json& input) { + json res; + res["@type"] = "special"; + res["form"] = "ROW_CONSTRUCTOR"; + res["returnType"] = getTypeSignature(expr->type()); + + res["arguments"] = json::array(); + auto exprInputs = expr->inputs(); + if (!exprInputs.empty()) { + for (const auto& exprInput : exprInputs) { + res["arguments"].push_back(veloxExprToRowExpression(exprInput, input)); + } + } else if ( + auto constantExpr = + std::dynamic_pointer_cast(expr)) { + auto value = constantExpr->value(); + auto* constVector = value->as>(); + auto* rowVector = constVector->valueVector()->as(); + auto type = asRowType(constantExpr->type()); + auto children = rowVector->children(); + auto size = children.size(); + + json j; + protocol::ConstantExpression cexpr; + for (auto i = 0; i < size; i++) { + cexpr.type = getTypeSignature(type->childAt(i)); + cexpr.valueBlock.data = getValueBlock(rowVector->childAt(i)); + protocol::to_json(j, cexpr); + res["arguments"].push_back(j); + } + } + + return res; +} + +// When the second value in the returned pair is true, the arguments for switch +// special form are returned. Otherwise, the switch expression has been +// simplified and the first value corresponding to the switch case that always +// evaluates to true is returned. 
+std::pair RowExpressionConverter::getSwitchSpecialFormArgs( + const exec::ExprPtr& expr, + const json& input) { + json::array_t inputArgs = input["arguments"]; + auto numArgs = inputArgs.size(); + json::array_t result = json::array(); + const std::vector exprInputs = expr->inputs(); + const auto numInputs = exprInputs.size(); + + auto getWhenSpecialForm = [&](const json::array_t& whenArgs, + const vector_size_t idx) -> json { + json when; + when["@type"] = "special"; + when["form"] = "WHEN"; + when["arguments"] = whenArgs; + when["returnType"] = getTypeSignature(exprInputs[idx + 1]->type()); + return when; + }; + + // The searched form of the conditional expression needs to be handled + // differently from the simple form. The searched form can be detected by the + // presence of a boolean value in the first argument. This default boolean + // argument is not present in the Velox switch expression, so it is added to + // the arguments of output switch expression unchanged. + if (inputArgs[0].at("@type") == "constant" && + inputArgs[0].at("type") == "boolean") { + result.emplace_back(inputArgs[0]); + for (auto i = 0; i < numInputs - 1; i += 2) { + const vector_size_t argsIdx = i / 2 + 1; + json::array_t inputWhenArgs = inputArgs[argsIdx].at("arguments"); + json::array_t whenArgs; + whenArgs.emplace_back( + veloxExprToRowExpression(exprInputs[i], inputWhenArgs[0])); + whenArgs.emplace_back( + veloxExprToRowExpression(exprInputs[i + 1], inputWhenArgs[1])); + + result.emplace_back(getWhenSpecialForm(whenArgs, i)); + } + } else { + // The case 'expression' in simple form of conditional cannot be inferred + // from Velox since it could evaluate all when clauses to true or false, so + // we get it from the input json. 
+ result.emplace_back(inputArgs[0]); + for (auto i = 0; i < numInputs - 1; i += 2) { + json::array_t whenArgs; + const vector_size_t argsIdx = i / 2 + 1; + const auto& caseValue = exprInputs[i + 1]; + json::array_t inputWhenArgs = inputArgs[argsIdx].at("arguments"); + + if (exprInputs[i]->isConstant()) { + auto constantExpr = + std::dynamic_pointer_cast(exprInputs[i]); + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + if (result.size() == 1) { + // This is the first case statement that evaluates to true, so + // return the expression corresponding to this case. + return { + json::array( + {veloxExprToRowExpression(caseValue, inputWhenArgs[1])}), + false}; + } else { + // If the case has been constant folded to false in the Velox + // switch expression, we do not have access to the expression + // inputs in Velox anymore. So we return the corresponding + // argument from the input switch expression. + result.emplace_back(inputArgs[argsIdx]); + } + } else { + // Skip cases that evaluate to false from the output switch + // expression's arguments. + continue; + } + } else { + whenArgs.emplace_back(getConstantRowExpression(constantExpr)); + } + } else { + VELOX_USER_CHECK(!exprInputs[i]->inputs().empty()); + const auto& matchExpr = exprInputs[i]->inputs().back(); + whenArgs.emplace_back( + veloxExprToRowExpression(matchExpr, inputWhenArgs[0])); + } + + whenArgs.emplace_back( + veloxExprToRowExpression(caseValue, inputWhenArgs[1])); + result.emplace_back(getWhenSpecialForm(whenArgs, i)); + } + } + + // Else clause. 
+ if (numInputs % 2 != 0) { + result.push_back(veloxExprToRowExpression( + exprInputs[numInputs - 1], inputArgs[numArgs - 1])); + } + return {result, true}; +} + +json RowExpressionConverter::getSpecialForm( + const exec::ExprPtr& expr, + const json& input) { + json res; + res["@type"] = "special"; + std::string form; + if (input.contains("form")) { + form = input["form"]; + } else { + // If input json is a call expression instead of a special form, for cases + // like 'is_null', the key 'form' will not be present in the input json. + form = expr->name(); + } + // Presto requires the field form to be in upper case. + std::transform(form.begin(), form.end(), form.begin(), ::toupper); + res["form"] = form; + auto exprInputs = expr->inputs(); + res["arguments"] = json::array(); + + // Arguments for switch expression include special form expression 'when' + // so it is constructed separately. If the switch expression evaluation found + // a case that always evaluates to true, the second value in pair switchResult + // will be false and the first value in pair will contain the value + // corresponding to the simplified case. 
+ if (form == "SWITCH") { + auto switchResult = getSwitchSpecialFormArgs(expr, input); + if (switchResult.second) { + res["arguments"] = switchResult.first; + } else { + return switchResult.first.front(); + } + } else { + json::array_t inputArguments = input["arguments"]; + const auto numInputs = exprInputs.size(); + VELOX_USER_CHECK_LE(numInputs, inputArguments.size()); + for (auto i = 0; i < numInputs; i++) { + res["arguments"].push_back( + veloxExprToRowExpression(exprInputs[i], inputArguments[i])); + } + } + res["returnType"] = getTypeSignature(expr->type()); + + return res; +} + +json RowExpressionConverter::toConstantRowExpression( + const exec::ExprPtr& expr) { + json res; + auto constantExpr = std::dynamic_pointer_cast(expr); + VELOX_USER_CHECK_NOT_NULL(constantExpr); + auto constantRowExpr = getConstantRowExpression(constantExpr); + protocol::to_json(res, constantRowExpr); + return res; +} + +json RowExpressionConverter::toCallRowExpression( + const exec::ExprPtr& expr, + const json& input) { + json res; + res["@type"] = "call"; + protocol::Signature signature; + std::string exprName = expr->name(); + if (veloxToPrestoOperatorMap_.find(expr->name()) != + veloxToPrestoOperatorMap_.end()) { + exprName = veloxToPrestoOperatorMap_.at(expr->name()); + } + signature.name = exprName; + res["displayName"] = exprName; + signature.kind = protocol::FunctionKind::SCALAR; + signature.typeVariableConstraints = {}; + signature.longVariableConstraints = {}; + signature.returnType = getTypeSignature(expr->type()); + + std::vector argumentTypes; + auto exprInputs = expr->inputs(); + auto numArgs = exprInputs.size(); + argumentTypes.reserve(numArgs); + for (auto i = 0; i < numArgs; i++) { + argumentTypes.emplace_back(getTypeSignature(exprInputs[i]->type())); + } + signature.argumentTypes = argumentTypes; + signature.variableArity = false; + + protocol::BuiltInFunctionHandle builtInFunctionHandle; + builtInFunctionHandle._type = "$static"; + builtInFunctionHandle.signature = 
signature; + res["functionHandle"] = builtInFunctionHandle; + res["returnType"] = getTypeSignature(expr->type()); + res["arguments"] = json::array(); + for (const auto& exprInput : exprInputs) { + res["arguments"].push_back(veloxExprToRowExpression(exprInput, input)); + } + + return res; +} + +json RowExpressionConverter::veloxExprToRowExpression( + const exec::ExprPtr& expr, + const json& input) { + if (expr->type()->isRow()) { + // Velox constant expressions of ROW type map to special form expression + // row_constructor in Presto. + return getRowConstructorSpecialForm(expr, input); + } else if (expr->isConstant()) { + if (expr->inputs().empty()) { + return toConstantRowExpression(expr); + } else { + // Inputs to constant expressions are constant, eg: divide(1, 2). + return input; + } + } else if ( + auto field = + std::dynamic_pointer_cast(expr)) { + return toVariableReferenceExpression(field); + } else if (expr->isSpecialForm() || expr->vectorFunction()) { + // Check if special form expression or call expression. 
+ auto exprName = expr->name(); + boost::algorithm::to_lower(exprName); + if (isPrestoSpecialForm(exprName)) { + return getSpecialForm(expr, input); + } else { + return toCallRowExpression(expr, input); + } + } + + VELOX_NYI( + "Conversion of Velox Expr {} to Presto RowExpression is not supported", + expr->toString()); +} + +exec::ExprPtr RowExpressionOptimizer::compileExpression( + const std::shared_ptr& inputRowExpr) { + auto typedExpr = veloxExprConverter_.toVeloxExpr(inputRowExpr); + exec::ExprSet exprSet{{typedExpr}, execCtx_.get()}; + auto compiledExprs = + exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true); + return compiledExprs[0]; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeAndSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto leftExpr = compileExpression(left); + bool isLeftNull; + + if (auto constantExpr = + std::dynamic_pointer_cast(leftExpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (!isLeftNull) { + if (auto constVector = + constantExpr->value()->as>()) { + if (!constVector->valueAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } else { + return right; + } + } + } + } + + auto rightExpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rightExpr)) { + if (isLeftNull && constantExpr->value()->isNullAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + if (auto constVector = constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return left; + } + return right; + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeIfSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto condition = specialFormExpr->arguments[0]; + auto expr = compileExpression(condition); + + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (auto 
constVector = constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return specialFormExpr->arguments[1]; + } + return specialFormExpr->arguments[2]; + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeIsNullSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto expr = compileExpression(specialFormExpr); + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (constantExpr->value()->isNullAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeOrSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto leftExpr = compileExpression(left); + bool isLeftNull; + + if (auto constantExpr = + std::dynamic_pointer_cast(leftExpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (!isLeftNull) { + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + return right; + } + } + } + + auto rightExpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rightExpr)) { + if (isLeftNull && constantExpr->value()->isNullAt(0)) { + return rowExpressionConverter_.getConstantRowExpression(constantExpr); + } + if (auto constVector = constantExpr->value()->as>()) { + if (!constVector->valueAt(0)) { + return left; + } + return right; + } + } + + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeCoalesceSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr) { + auto argsNoNulls = specialFormExpr->arguments; + argsNoNulls.erase( + std::remove_if( + argsNoNulls.begin(), + argsNoNulls.end(), + [&](const auto& arg) { + auto compiledExpr = compileExpression(arg); + if (auto constantExpr = + 
std::dynamic_pointer_cast( + compiledExpr)) { + return constantExpr->value()->isNullAt(0); + } + return false; + }), + argsNoNulls.end()); + + if (argsNoNulls.empty()) { + return specialFormExpr->arguments[0]; + } + specialFormExpr->arguments = argsNoNulls; + return specialFormExpr; +} + +RowExpressionPtr RowExpressionOptimizer::optimizeSpecialForm( + const std::shared_ptr& specialFormExpr) { + switch (specialFormExpr->form) { + case protocol::Form::IF: + return optimizeIfSpecialForm(specialFormExpr); + case protocol::Form::NULL_IF: + VELOX_USER_FAIL("NULL_IF specialForm not supported"); + break; + case protocol::Form::IS_NULL: + return optimizeIsNullSpecialForm(specialFormExpr); + case protocol::Form::AND: + return optimizeAndSpecialForm(specialFormExpr); + case protocol::Form::OR: + return optimizeOrSpecialForm(specialFormExpr); + case protocol::Form::COALESCE: + return optimizeCoalesceSpecialForm(specialFormExpr); + case protocol::Form::IN: + case protocol::Form::DEREFERENCE: + case protocol::Form::SWITCH: + case protocol::Form::WHEN: + case protocol::Form::ROW_CONSTRUCTOR: + case protocol::Form::BIND: + default: + break; + } + + return specialFormExpr; +} + +json::array_t RowExpressionOptimizer::optimizeExpressions( + const json::array_t& input, + const std::string& optimizerLevel, + const std::string& currentUser) { + const auto numExpr = input.size(); + json::array_t output = json::array(); + for (auto i = 0; i < numExpr; i++) { + // TODO: current_user to be evaluated in the native plugin and will not be + // sent to the sidecar. 
+ if (input[i].contains("displayName") && + input[i].at("displayName") == "$current_user") { + output.emplace_back(rowExpressionConverter_.getCurrentUser(currentUser)); + continue; + } + + std::shared_ptr inputRowExpr = input[i]; + if (const auto special = + std::dynamic_pointer_cast( + inputRowExpr)) { + inputRowExpr = optimizeSpecialForm(special); + } + auto typedExpr = veloxExprConverter_.toVeloxExpr(inputRowExpr); + exec::ExprSet exprSet{{typedExpr}, execCtx_.get()}; + auto compiledExprs = + exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true); + auto compiledExpr = compiledExprs[0]; + json resultJson; + + if (optimizerLevel == "EVALUATED") { + if (compiledExpr->isConstant()) { + resultJson = rowExpressionConverter_.veloxExprToRowExpression( + compiledExpr, input[i]); + } else { + // Evaluate non-deterministic expressions with constant inputs. + VELOX_USER_CHECK(!compiledExpr->isDeterministic()); + std::vector compiledExprInputTypes; + std::vector compiledExprInputs; + for (const auto& exprInput : compiledExpr->inputs()) { + VELOX_USER_CHECK( + exprInput->isConstant(), + "Inputs to non-deterministic expression to be evaluated must be constant"); + const auto inputAsConstExpr = + std::dynamic_pointer_cast(exprInput); + compiledExprInputs.emplace_back(inputAsConstExpr->value()); + compiledExprInputTypes.emplace_back(exprInput->type()); + } + + const auto inputVector = std::make_shared( + pool_.get(), + ROW(std::move(compiledExprInputTypes)), + nullptr, + 1, + compiledExprInputs); + exec::EvalCtx evalCtx(execCtx_.get(), &exprSet, inputVector.get()); + std::vector results(1); + SelectivityVector rows(1); + exprSet.eval(rows, evalCtx, results); + auto res = results.front(); + DecodedVector decoded(*res, rows); + VELOX_USER_CHECK(decoded.size() == 1); + const auto constExpr = VELOX_DYNAMIC_TYPE_DISPATCH( + getConstantExpr, + res->typeKind(), + res->type(), + decoded, + pool_.get()); + resultJson = + 
rowExpressionConverter_.getConstantRowExpression(constExpr); + } + } else { + resultJson = rowExpressionConverter_.veloxExprToRowExpression( + compiledExpr, input[i]); + } + + output.push_back(resultJson); + } + return output; +} + +void RowExpressionOptimizer::optimize( + proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + try { + auto timezone = + message->getHeaders().getSingleOrEmpty("X-Presto-Time-Zone"); + auto currentUser = message->getHeaders().getSingleOrEmpty("X-Presto-User"); + auto optimizerLevel = message->getHeaders().getSingleOrEmpty( + "X-Presto-Expression-Optimizer-Level"); + + std::unordered_map config( + {{core::QueryConfig::kSessionTimezone, timezone}, + {core::QueryConfig::kAdjustTimestampToTimezone, "true"}}); + auto queryCtx = + core::QueryCtx::create(nullptr, core::QueryConfig{std::move(config)}); + execCtx_ = std::make_unique(pool_.get(), queryCtx.get()); + + json::array_t inputList = getInputExpressions(body); + json output = optimizeExpressions(inputList, optimizerLevel, currentUser); + proxygen::ResponseBuilder(downstream) + .status(http::kHttpOk, "OK") + .header( + proxygen::HTTP_HEADER_CONTENT_TYPE, http::kMimeTypeApplicationJson) + .body(output.dump()) + .sendWithEOM(); + } catch (const VeloxUserError& e) { + VLOG(1) << "VeloxUserError during expression evaluation: " << e.what(); + http::sendErrorResponse(downstream, e.what()); + } catch (const VeloxException& e) { + VLOG(1) << "VeloxException during expression evaluation: " << e.what(); + http::sendErrorResponse(downstream, e.what()); + } catch (const std::exception& e) { + VLOG(1) << "std::exception during expression evaluation: " << e.what(); + http::sendErrorResponse(downstream, e.what()); + } +} + +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.h b/presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.h new file mode 100644 
index 0000000000000..b30b18e1a4332 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/RowExpressionOptimizer.h @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/main/http/HttpServer.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "velox/core/QueryCtx.h" +#include "velox/expression/ConstantExpr.h" +#include "velox/expression/Expr.h" +#include "velox/serializers/PrestoSerializer.h" + +using namespace facebook::velox; + +namespace facebook::presto::expression { + +using RowExpressionPtr = std::shared_ptr; +using SpecialFormExpressionPtr = + std::shared_ptr; + +// Helper class to convert Velox Expr of different types to the respective kind +// of Presto RowExpression. 
+class RowExpressionConverter { + public: + explicit RowExpressionConverter( + const std::shared_ptr& pool) + : pool_(pool), veloxToPrestoOperatorMap_(veloxToPrestoOperatorMap()) {} + + std::shared_ptr getConstantRowExpression( + const std::shared_ptr& constantExpr); + + std::shared_ptr getCurrentUser( + const std::string& currentUser); + + json veloxExprToRowExpression( + const exec::ExprPtr& expr, + const json& inputRowExpr); + + protected: + std::string getValueBlock(const VectorPtr& vector); + + json getRowConstructorSpecialForm( + const exec::ExprPtr& expr, + const json& inputRowExpr); + + std::pair getSwitchSpecialFormArgs( + const exec::ExprPtr& expr, + const json& input); + + json getSpecialForm(const exec::ExprPtr& expr, const json& inputRowExpr); + + json toConstantRowExpression(const exec::ExprPtr& expr); + + json toCallRowExpression(const exec::ExprPtr& expr, const json& input); + + const std::shared_ptr pool_; + const std::unordered_map veloxToPrestoOperatorMap_; + const std::unique_ptr serde_ = + std::make_unique(); +}; + +class RowExpressionOptimizer { + public: + explicit RowExpressionOptimizer() + : pool_(memory::MemoryManager::getInstance()->addLeafPool( + "RowExpressionOptimizer")), + veloxExprConverter_(pool_.get(), &typeParser_), + rowExpressionConverter_(pool_) {} + + /// Optimizes expressions posted to the proxygen endpoint '/v1/expressions'. + void optimize( + proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream); + + protected: + /// Converts protocol::RowExpression into a velox expression with constant + /// folding enabled during velox expression compilation. 
+ exec::ExprPtr compileExpression(const RowExpressionPtr& inputRowExpr); + + RowExpressionPtr optimizeAndSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeIfSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeIsNullSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeOrSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + RowExpressionPtr optimizeCoalesceSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Optimizes special form expressions. Optimization rules borrowed from + /// Presto function visitSpecialForm() in RowExpressionInterpreter.java. + RowExpressionPtr optimizeSpecialForm( + const SpecialFormExpressionPtr& specialFormExpr); + + /// Optimizes and constant folds each expression from input json array and + /// returns an array of expressions that are optimized and constant folded. + /// Each expression in the input array is optimized with helper functions + /// optimizeSpecialForm (applicable only for special form expressions) and + /// optimizeExpression. The optimized expression is also evaluated if the + /// optimization level in the header of http request made to '/v1/expressions' + /// is 'EVALUATED'. optimizeExpression uses RowExpressionConverter to convert + /// Velox expression(s) to their corresponding Presto RowExpression(s). 
+ json::array_t optimizeExpressions( + const json::array_t& input, + const std::string& optimizationLevel, + const std::string& currentUser); + + const std::shared_ptr pool_; + std::unique_ptr execCtx_; + TypeParser typeParser_; + VeloxExprConverter veloxExprConverter_; + RowExpressionConverter rowExpressionConverter_; +}; +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt new file mode 100644 index 0000000000000..5879d03b9e006 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+add_executable(presto_expression_optimizer_test RowExpressionOptimizerTest.cpp) + +add_test(presto_expression_optimizer_test presto_expression_optimizer_test) + +target_link_libraries( + presto_expression_optimizer_test + presto_expression_optimizer + presto_http + velox_exec_test_lib + velox_presto_serializer + GTest::gtest + GTest::gtest_main + ${PROXYGEN_LIBRARIES}) diff --git a/presto-native-execution/presto_cpp/main/expression/tests/RowExpressionOptimizerTest.cpp b/presto-native-execution/presto_cpp/main/expression/tests/RowExpressionOptimizerTest.cpp new file mode 100644 index 0000000000000..9560ee5656b86 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/RowExpressionOptimizerTest.cpp @@ -0,0 +1,187 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/expression/RowExpressionOptimizer.h" +#include +#include +#include +#include +#include "presto_cpp/main/common/tests/test_json.h" +#include "presto_cpp/main/http/tests/HttpTestBase.h" +#include "velox/exec/OutputBufferManager.h" +#include "velox/expression/RegisterSpecialForm.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/VectorStream.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace { +std::string getDataPath(const std::string& fileName) { + std::string currentPath = fs::current_path().c_str(); + + if (boost::algorithm::ends_with(currentPath, "fbcode")) { + return currentPath + + "/github/presto-trunk/presto-native-execution/presto_cpp/main/expression/tests/data/" + + fileName; + } + + if (boost::algorithm::ends_with(currentPath, "fbsource")) { + return currentPath + "/third-party/presto_cpp/main/expression/tests/data/" + + fileName; + } + + // CLion runs the tests from the cmake-build-release/ or cmake-build-debug/ + // directory. The JSON fixture files are not copied there, so the tests fail + // with file not found. Strip that path component so that these tests can be + // triggered from CLion. + boost::algorithm::replace_all(currentPath, "cmake-build-release/", ""); + boost::algorithm::replace_all(currentPath, "cmake-build-debug/", ""); + + return currentPath + "/data/" + fileName; +} +} // namespace + +// RowExpressionOptimizerTest only tests basic expression optimization via the +// '/v1/expressions' endpoint. End-to-end tests for different expression types +// can be found in TestDelegatingExpressionOptimizer.java, in the module +// presto-native-sidecar-plugin. 
+class RowExpressionOptimizerTest + : public ::testing::Test, + public facebook::velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + parse::registerTypeResolver(); + functions::prestosql::registerAllScalarFunctions("presto.default."); + exec::registerFunctionCallToSpecialForms(); + + auto httpServer = std::make_unique( + httpSrvIOExecutor_, + std::make_unique( + folly::SocketAddress("127.0.0.1", 0))); + driverExecutor_ = std::make_unique(4); + rowExpressionOptimizer_ = + std::make_unique(); + httpServer->registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + return rowExpressionOptimizer_->optimize(message, body, downstream); + }); + httpServerWrapper_ = + std::make_unique(std::move(httpServer)); + auto address = httpServerWrapper_->start().get(); + client_ = clientFactory_.newClient( + address, + std::chrono::milliseconds(100'000), + std::chrono::milliseconds(0), + false, + pool_); + } + + void TearDown() override { + if (httpServerWrapper_) { + httpServerWrapper_->stop(); + } + } + + static std::string getHttpBody( + const std::unique_ptr& response) { + std::ostringstream oss; + auto iobufs = response->consumeBody(); + for (auto& body : iobufs) { + oss << std::string((const char*)body->data(), body->length()); + } + return oss.str(); + } + + void validateHttpResponse( + const std::string& inputStr, + const std::string& expectedStr) { // POSTs inputStr and compares the JSON reply with expectedStr. 
+ http::RequestBuilder() + .method(proxygen::HTTPMethod::POST) + .url("/v1/expressions") + .send(client_.get(), inputStr) + .via(driverExecutor_.get()) + .thenValue([expectedStr](std::unique_ptr response) { + VELOX_USER_CHECK_EQ( + response->headers()->getStatusCode(), http::kHttpOk); + if (response->hasError()) { + VELOX_USER_FAIL( + "Expression evaluation failed: {}", response->error()); + } + + auto resStr = getHttpBody(response); + auto 
resJson = json::parse(resStr); + ASSERT_TRUE(resJson.is_array()); + auto expectedJson = json::parse(expectedStr); + ASSERT_TRUE(expectedJson.is_array()); + ASSERT_EQ(expectedJson.size(), resJson.size()); + auto size = resJson.size(); + for (size_t i = 0; i < size; ++i) { + EXPECT_EQ(resJson[i], expectedJson[i]); + } + }) + .thenError( + folly::tag_t{}, [&](const std::exception& e) { + ADD_FAILURE() << "Expression evaluation failed: " << e.what(); + }).get(); // Block until the checks above have run; a dropped future may never run them. + } + + void testFile(const std::string& prefix) { + std::string input = slurp(getDataPath(fmt::format("{}Input.json", prefix))); + auto inputExpressions = json::parse(input); + std::string output = + slurp(getDataPath(fmt::format("{}Expected.json", prefix))); + auto expectedExpressions = json::parse(output); + + validateHttpResponse(inputExpressions.dump(), expectedExpressions.dump()); + } + + std::unique_ptr rowExpressionOptimizer_; + std::unique_ptr httpServerWrapper_; + HttpClientFactory clientFactory_; + std::shared_ptr client_; + std::shared_ptr httpSrvIOExecutor_{ + std::make_shared(8)}; + std::unique_ptr driverExecutor_; + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool("RowExpressionOptimizerTest")}; +}; + +TEST_F(RowExpressionOptimizerTest, simple) { + // File SimpleExpressions{Input|Expected}.json contain the input and expected + // JSON representing the RowExpressions resulting from the following queries: + // select 1 + 2; + // select abs(-11) + ceil(cast(3.4 as double)) + floor(cast(5.6 as double)); + // select 2 between 1 and 3; + // Simple expression evaluation with constant folding is verified here. 
+ testFile("SimpleExpressions"); +} + +TEST_F(RowExpressionOptimizerTest, specialFormRewrites) { + // Files SpecialForm{Input|Expected}.json contain the input and expected + // JSON representing the RowExpressions resulting from the following queries: + // select if(1 < 2, 2, 3); + // select (1 < 2) and (2 < 3); + // select (1 < 2) or (2 < 3); + // Special form expression rewrites are verified here. + testFile("SpecialForm"); +} diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json new file mode 100644 index 0000000000000..e9b014fecd3b2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "double", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAAAAAAAAADRA" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json new file mode 100644 index 0000000000000..dd2a0497703eb --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SimpleExpressionsInput.json @@ -0,0 +1,278 @@ +[ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": 
"presto.default.$operator$add", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAACwAAAA==" + } + ], + "displayName": "NEGATION", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$negation", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "abs", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.abs", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAACIAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + 
"variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ceil", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.ceil", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAADgAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "floor", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.floor", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + 
"@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "BETWEEN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$between", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json new file mode 100644 index 0000000000000..2ce6acb1ab46e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "boolean", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] diff --git a/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json new file mode 100644 index 0000000000000..77802722541b0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/expression/tests/data/SpecialFormInput.json @@ -0,0 +1,193 @@ +[ + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": 
"integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "form": "IF", + "returnType": "integer" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": 
"AND", + "returnType": "boolean" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": "OR", + "returnType": "boolean" + } +] diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 2c2b2a3c5ea00..791e215fb0f3c 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -33,32 +33,10 @@ std::string toJsonString(const T& value) { } std::string mapScalarFunction(const std::string& name) { - static const std::unordered_map kFunctionNames = { - // Operator overrides: com.facebook.presto.common.function.OperatorType - {"presto.default.$operator$add", 
"presto.default.plus"}, - {"presto.default.$operator$between", "presto.default.between"}, - {"presto.default.$operator$divide", "presto.default.divide"}, - {"presto.default.$operator$equal", "presto.default.eq"}, - {"presto.default.$operator$greater_than", "presto.default.gt"}, - {"presto.default.$operator$greater_than_or_equal", "presto.default.gte"}, - {"presto.default.$operator$is_distinct_from", - "presto.default.distinct_from"}, - {"presto.default.$operator$less_than", "presto.default.lt"}, - {"presto.default.$operator$less_than_or_equal", "presto.default.lte"}, - {"presto.default.$operator$modulus", "presto.default.mod"}, - {"presto.default.$operator$multiply", "presto.default.multiply"}, - {"presto.default.$operator$negation", "presto.default.negate"}, - {"presto.default.$operator$not_equal", "presto.default.neq"}, - {"presto.default.$operator$subtract", "presto.default.minus"}, - {"presto.default.$operator$subscript", "presto.default.subscript"}, - // Special form function overrides. 
- {"presto.default.in", "in"}, - }; - std::string lowerCaseName = boost::to_lower_copy(name); - auto it = kFunctionNames.find(lowerCaseName); - if (it != kFunctionNames.end()) { + auto it = kPrestoOperatorMap.find(lowerCaseName); + if (it != kPrestoOperatorMap.end()) { return it->second; } @@ -102,6 +80,15 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) { } // namespace +std::unordered_map veloxToPrestoOperatorMap() { + std::unordered_map inverseMap; + for (const auto& [prestoName, veloxName] : kPrestoOperatorMap) { + inverseMap[veloxName] = prestoName; + } + inverseMap.insert({"cast", "presto.default.$operator$cast"}); + return inverseMap; +} + velox::variant VeloxExprConverter::getConstantValue( const velox::TypePtr& type, const protocol::Block& block) const { diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index f63a84ec35ad2..06c326dfddc1d 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -20,6 +20,30 @@ namespace facebook::presto { +inline const std::unordered_map kPrestoOperatorMap = { // inline: one instance shared by all including TUs. 
+ // Operator overrides: com.facebook.presto.common.function.OperatorType + {"presto.default.$operator$add", "presto.default.plus"}, + {"presto.default.$operator$between", "presto.default.between"}, + {"presto.default.$operator$divide", "presto.default.divide"}, + {"presto.default.$operator$equal", "presto.default.eq"}, + {"presto.default.$operator$greater_than", "presto.default.gt"}, + {"presto.default.$operator$greater_than_or_equal", "presto.default.gte"}, + {"presto.default.$operator$is_distinct_from", + "presto.default.distinct_from"}, + {"presto.default.$operator$less_than", "presto.default.lt"}, + {"presto.default.$operator$less_than_or_equal", "presto.default.lte"}, + 
{"presto.default.$operator$modulus", "presto.default.mod"}, + {"presto.default.$operator$multiply", "presto.default.multiply"}, + {"presto.default.$operator$negation", "presto.default.negate"}, + {"presto.default.$operator$not_equal", "presto.default.neq"}, + {"presto.default.$operator$subtract", "presto.default.minus"}, + {"presto.default.$operator$subscript", "presto.default.subscript"}, + // Special form function overrides. + {"presto.default.in", "in"}, +}; + +std::unordered_map veloxToPrestoOperatorMap(); // Inverse of kPrestoOperatorMap, plus "cast". + class VeloxExprConverter { public: VeloxExprConverter(velox::memory::MemoryPool* pool, TypeParser* typeParser)