From 822f79fa4096f24717925dc60495f746b32cd1ec Mon Sep 17 00:00:00 2001 From: Pramod Satya Date: Thu, 2 May 2024 13:08:59 -0700 Subject: [PATCH] [native] Add proxygen endpoint for expression evaluation --- .../presto_cpp/main/CMakeLists.txt | 2 + .../presto_cpp/main/PrestoServer.cpp | 3 + .../presto_cpp/main/PrestoServer.h | 2 + .../presto_cpp/main/eval/CMakeLists.txt | 33 ++ .../presto_cpp/main/eval/PrestoExprEval.cpp | 359 ++++++++++++++++++ .../presto_cpp/main/eval/PrestoExprEval.h | 62 +++ .../presto_cpp/main/eval/tests/CMakeLists.txt | 26 ++ .../main/eval/tests/PrestoExprEvalTest.cpp | 178 +++++++++ .../tests/data/SimpleExpressionsExpected.json | 17 + .../tests/data/SimpleExpressionsInput.json | 278 ++++++++++++++ .../eval/tests/data/SpecialFormExpected.json | 17 + .../eval/tests/data/SpecialFormInput.json | 193 ++++++++++ .../main/types/PrestoToVeloxExpr.cpp | 29 +- .../presto_cpp/main/types/PrestoToVeloxExpr.h | 24 ++ 14 files changed, 1201 insertions(+), 22 deletions(-) create mode 100644 presto-native-execution/presto_cpp/main/eval/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/eval/PrestoExprEval.cpp create mode 100644 presto-native-execution/presto_cpp/main/eval/PrestoExprEval.h create mode 100644 presto-native-execution/presto_cpp/main/eval/tests/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/eval/tests/PrestoExprEvalTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsExpected.json create mode 100644 presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsInput.json create mode 100644 presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormExpected.json create mode 100644 presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormInput.json diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 30ba84dc5461d..0e2a5572e6667 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(operators) add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) +add_subdirectory(eval) add_subdirectory(thrift) add_library( @@ -49,6 +50,7 @@ target_link_libraries( presto_exception presto_http presto_operators + presto_expr_eval velox_aggregates velox_caching velox_common_base diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 5ee52935cc224..c4a63bb92e654 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -477,6 +477,9 @@ void PrestoServer::run() { taskManager_->getQueryContextManager()->getSessionProperties(); http::sendOkResponse(downstream, sessionProperties.serialize()); }); + + prestoExprEval_ = std::make_unique(pool_); + prestoExprEval_->registerUris(*httpServer_); } std::string taskUri; diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index bee2d8d43391a..3681564708c8b 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -25,6 +25,7 @@ #include "presto_cpp/main/PeriodicHeartbeatManager.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/PrestoServerOperations.h" +#include "presto_cpp/main/eval/PrestoExprEval.h" #include "presto_cpp/main/types/VeloxPlanValidator.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/memory/MemoryAllocator.h" @@ -277,6 +278,7 @@ class PrestoServer { std::string address_; std::string nodeLocation_; folly::SSLContextPtr sslContext_; + std::unique_ptr prestoExprEval_; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/eval/CMakeLists.txt b/presto-native-execution/presto_cpp/main/eval/CMakeLists.txt new file mode 100644 index 0000000000000..33a59ed2585d9 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_expr_eval PrestoExprEval.cpp) + +target_link_libraries( + presto_expr_eval + presto_type_converter + presto_types + presto_protocol + presto_http + velox_coverage_util + velox_parse_expression + velox_parse_parser + velox_presto_serializer + velox_serialization + velox_type_parser + ${FOLLY_WITH_DEPENDENCIES} + ${PROXYGEN_LIBRARIES}) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.cpp b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.cpp new file mode 100644 index 0000000000000..398052b2f95d0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.cpp @@ -0,0 +1,359 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/eval/PrestoExprEval.h" +#include +#include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/presto_protocol/presto_protocol.h" +#include "velox/common/encode/Base64.h" +#include "velox/core/Expressions.h" +#include "velox/exec/ExchangeQueue.h" +#include "velox/expression/ConstantExpr.h" +#include "velox/expression/EvalCtx.h" +#include "velox/expression/Expr.h" +#include "velox/expression/ExprCompiler.h" +#include "velox/expression/FieldReference.h" +#include "velox/expression/LambdaExpr.h" +#include "velox/parse/Expressions.h" +#include "velox/parse/ExpressionsParser.h" +#include "velox/serializers/PrestoSerializer.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace facebook::presto::eval { + +namespace { + +// ValueBlock in ConstantExpression requires only the column to be serialized +// in PrestoPage format, without the page header. +const std::string serializeValueBlock( + const VectorPtr& vector, + memory::MemoryPool* pool) { + auto numRows = vector->size(); + std::unique_ptr serde = + std::make_unique(); + const IndexRange allRows{0, numRows}; + auto ranges = folly::Range(&allRows, 1); + const auto arena = std::make_unique(pool); + serializer::presto::PrestoOptions paramOptions; + auto stream = std::make_unique( + vector->type(), + std::nullopt, + std::nullopt, + arena.get(), + numRows, + paramOptions); + Scratch scratch; + serde->serializeColumn(vector, ranges, stream.get(), scratch); + IOBufOutputStream ostream(*pool); + stream->flush(&ostream); + auto resultBuf = ostream.getIOBuf(); + auto bufLen = resultBuf->length(); + resultBuf->gather(bufLen); + return velox::encoding::Base64::encode( + reinterpret_cast(resultBuf->data()), bufLen); +} + +std::shared_ptr getConstantExpression( + std::shared_ptr constantExpr, + memory::MemoryPool* pool) { + protocol::ConstantExpression rexpr; + rexpr.type = constantExpr->type()->toString(); + rexpr.valueBlock.data = serializeValueBlock(constantExpr->value(), pool); + return std::make_shared(rexpr); +} + +json toVariableReferenceExpression(exec::FieldReference* fieldReference) { + protocol::VariableReferenceExpression vexpr; + vexpr.name = fieldReference->name(); + vexpr._type = "variable"; + vexpr.type = fieldReference->type()->toString(); + json res; + protocol::to_json(res, vexpr); + return res; +} +} // namespace + +// Compiles protocol::RowExpression in velox, with constant folding enabled, +// into a velox expression. +std::shared_ptr PrestoExprEval::compileExpression( + std::shared_ptr inputRowExpr) { + auto typedExpr = exprConverter_.toVeloxExpr(inputRowExpr); + exec::ExprSet exprSet{{typedExpr}, execCtx_.get()}; + // Constant folds the expression. + auto compiledExprs = + exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true); + return compiledExprs[0]; +} + +// Optimizes special form expressions. Optimization rules borrowed from +// Presto (function visitSpecialForm in +// presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java). +std::shared_ptr PrestoExprEval::optimizeSpecialForm( + std::shared_ptr specialFormExpr) { + switch (specialFormExpr->form) { + case protocol::Form::IF: { + auto condition = specialFormExpr->arguments[0]; + auto expr = compileExpression(condition); + + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0)) { + return specialFormExpr->arguments[1]; + } + return specialFormExpr->arguments[2]; + } + } + break; + } + case protocol::Form::NULL_IF: + VELOX_USER_FAIL("NULL_IF not supported in specialForm"); + break; + case protocol::Form::IS_NULL: { + auto value = specialFormExpr->arguments[0]; + auto expr = compileExpression(specialFormExpr); + + if (auto constantExpr = + std::dynamic_pointer_cast(expr)) { + if (constantExpr->value()->isNullAt(0)) { + return getConstantExpression(constantExpr, pool_.get()); + } + } + break; + } + case protocol::Form::AND: { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto lexpr = compileExpression(left); + bool isLeftNull; + + if (auto constantExpr = + std::dynamic_pointer_cast(lexpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0) == false) { + return getConstantExpression(constantExpr, pool_.get()); + } else { + return right; + } + } + } + + auto rexpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rexpr)) { + if (isLeftNull && constantExpr->value()->isNullAt(0)) { + return getConstantExpression(constantExpr, pool_.get()); + } + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0) == true) { + return left; + } + return right; + } + } + break; + } + case protocol::Form::OR: { + auto left = specialFormExpr->arguments[0]; + auto right = specialFormExpr->arguments[1]; + auto lexpr = compileExpression(left); + bool isLeftNull; + + if (auto constantExpr = + std::dynamic_pointer_cast(lexpr)) { + isLeftNull = constantExpr->value()->isNullAt(0); + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0) == true) { + return getConstantExpression(constantExpr, pool_.get()); + } + return right; + } + } + + auto rexpr = compileExpression(right); + if (auto constantExpr = + std::dynamic_pointer_cast(rexpr)) { + if (isLeftNull && constantExpr->value()->isNullAt(0)) { + return getConstantExpression(constantExpr, pool_.get()); + } + if (auto constVector = + constantExpr->value()->as>()) { + if (constVector->valueAt(0) == false) { + return left; + } + return right; + } + } + break; + } + case protocol::Form::IN: { + auto args = specialFormExpr->arguments; + VELOX_USER_CHECK(args.size() >= 2, "values must not be empty"); + auto target = args[0]; + if (target == nullptr) { + return nullptr; + } + break; + } + case protocol::Form::DEREFERENCE: { + auto args = specialFormExpr->arguments; + VELOX_USER_CHECK(args.size() == 2); + auto base = args[0]; + if (base == nullptr) { + return nullptr; + } + break; + } + default: + break; + } + return specialFormExpr; +} + +json PrestoExprEval::exprToRowExpression( + std::shared_ptr expr) { + json res; + if (expr->isConstant()) { + // constant + res["@type"] = "constant"; + auto constantExpr = std::dynamic_pointer_cast(expr); + VELOX_USER_CHECK_NOT_NULL(constantExpr); + auto type = constantExpr->type(); + VectorPtr constVector = constantExpr->value(); + res["valueBlock"] = serializeValueBlock(constVector, pool_.get()); + res["type"] = expr->type()->toString(); + } else if (auto fexpr = expr->as()) { + // variable + res = toVariableReferenceExpression(fexpr); + } else if (expr->isSpecialForm()) { + // special + res["@type"] = "special"; + auto inputs = expr->inputs(); + res["arguments"] = json::array(); + for (auto input : inputs) { + res["arguments"].push_back(exprToRowExpression(input)); + } + res["form"] = "BIND"; + res["returnType"] = expr->type()->toString(); + } else if (auto lambda = std::dynamic_pointer_cast(expr)) { + // lambda expressions are currently not optimized. + res["@type"] = "lambda"; + auto inputs = lambda->distinctFields(); + res["arguments"] = json::array(); + res["argumentTypes"] = json::array(); + auto numInputs = inputs.size(); + for (auto i = 0; i < numInputs; i++) { + res["arguments"].push_back(toVariableReferenceExpression(inputs[i])); + res["argumentTypes"].push_back(lambda->type()->childAt(i)->toString()); + } + res["body"] = lambda->toString(); + } else if (auto func = expr->vectorFunction()) { + // call + protocol::Signature signature; + signature.kind = protocol::FunctionKind::SCALAR; + signature.returnType = expr->type()->toString(); + auto signatureName = expr->name(); + // Maps operator name from velox to presto when input is a special form + // expression. + if (inverseScalarMap_.find(expr->name()) != inverseScalarMap_.end()) { + signatureName = inverseScalarMap_.at(expr->name()); + } + signature.name = signatureName; + signature.typeVariableConstraints = {}; + signature.longVariableConstraints = {}; + signature.variableArity = false; + + std::vector argumentTypes; + auto inputs = expr->inputs(); + auto numArgs = inputs.size(); + argumentTypes.reserve(numArgs); + for (auto i = 0; i < numArgs; i++) { + argumentTypes.emplace_back(inputs[i]->type()->toString()); + } + signature.argumentTypes = argumentTypes; + + protocol::BuiltInFunctionHandle builtInFunctionHandle; + builtInFunctionHandle._type = "$static"; + builtInFunctionHandle.signature = signature; + res["@type"] = "call"; + res["displayName"] = expr->name(); + res["functionHandle"] = builtInFunctionHandle; + for (auto input : inputs) { + res["arguments"].push_back(exprToRowExpression(input)); + } + res["returnType"] = expr->type()->toString(); + } else { + VELOX_NYI( + "Unable to convert Velox Expr to Presto RowExpression: {}", + expr->toString()); + } + + return res; +} + +void PrestoExprEval::evaluateExpression( + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + std::ostringstream oss; + for (auto& buf : body) { + oss << std::string((const char*)buf->data(), buf->length()); + } + auto input = json::parse(oss.str()); + auto numExpr = input.size(); + nlohmann::json output = json::array(); + + for (auto i = 0; i < numExpr; i++) { + std::shared_ptr inputRowExpr = input[i]; + // Check if special form expression can be optimized. + if (auto special = + std::dynamic_pointer_cast( + inputRowExpr)) { + inputRowExpr = optimizeSpecialForm(special); + } + + auto sourceLocation = inputRowExpr->sourceLocation; + auto compiledExpr = compileExpression(inputRowExpr); + json resultJson = exprToRowExpression(compiledExpr); + if (sourceLocation) { + json j; + protocol::to_json(j, *inputRowExpr->sourceLocation); + resultJson["sourceLocation"] = j; + } + output.push_back(resultJson); + } + + proxygen::ResponseBuilder(downstream) + .status(http::kHttpOk, "") + .header( + proxygen::HTTP_HEADER_CONTENT_TYPE, http::kMimeTypeApplicationJson) + .body(output.dump()) + .sendWithEOM(); +} + +void PrestoExprEval::registerUris(http::HttpServer& server) { + server.registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* /*message*/, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + return evaluateExpression(body, downstream); + }); +} +} // namespace facebook::presto::eval diff --git a/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.h b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.h new file mode 100644 index 0000000000000..6a40dd580c357 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.h @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/main/http/HttpServer.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "velox/core/QueryCtx.h" +#include "velox/expression/Expr.h" + +namespace facebook::presto::eval { + +class PrestoExprEval { + public: + PrestoExprEval(std::shared_ptr pool) + : pool_(pool), + queryCtx_(facebook::velox::core::QueryCtx::create()), + execCtx_{std::make_unique( + pool.get(), + queryCtx_.get())}, + exprConverter_(pool.get(), &typeParser_), + inverseScalarMap_(inverseScalarMap()){}; + + void registerUris(http::HttpServer& server); + + /// Evaluate expressions sent along /v1/expressions endpoint. + void evaluateExpression( + const std::vector>& body, + proxygen::ResponseHandler* downstream); + + std::shared_ptr optimizeSpecialForm( + std::shared_ptr specialFormExpr); + + protected: + json exprToRowExpression(std::shared_ptr expr); + + std::shared_ptr compileExpression( + std::shared_ptr inputRowExpr); + + const std::shared_ptr pool_; + const std::shared_ptr queryCtx_; + const std::unique_ptr execCtx_; + VeloxExprConverter exprConverter_; + TypeParser typeParser_; + bool isLambda_; + std::shared_ptr lambdaTypedExpr_; + const std::unordered_map inverseScalarMap_; + std::unordered_map variableToSourceLocation_; +}; +} // namespace facebook::presto::eval diff --git a/presto-native-execution/presto_cpp/main/eval/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/eval/tests/CMakeLists.txt new file mode 100644 index 0000000000000..b2572f861efe3 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_executable(presto_expr_eval_test PrestoExprEvalTest.cpp) + +add_test(presto_expr_eval_test presto_expr_eval_test) + +target_link_libraries( + presto_expr_eval_test + presto_expr_eval + presto_http + velox_exec_test_lib + velox_presto_serializer + gtest + gtest_main + ${PROXYGEN_LIBRARIES}) diff --git a/presto-native-execution/presto_cpp/main/eval/tests/PrestoExprEvalTest.cpp b/presto-native-execution/presto_cpp/main/eval/tests/PrestoExprEvalTest.cpp new file mode 100644 index 0000000000000..1495e101923f2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/PrestoExprEvalTest.cpp @@ -0,0 +1,178 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/eval/PrestoExprEval.h" +#include +#include +#include +#include +#include "presto_cpp/main/common/tests/test_json.h" +#include "presto_cpp/main/http/tests/HttpTestBase.h" +#include "velox/exec/OutputBufferManager.h" +#include "velox/expression/RegisterSpecialForm.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/VectorStream.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace { +std::string getDataPath(const std::string& fileName) { + std::string currentPath = fs::current_path().c_str(); + + if (boost::algorithm::ends_with(currentPath, "fbcode")) { + return currentPath + + "/github/presto-trunk/presto-native-execution/presto_cpp/main/types/tests/data/" + + fileName; + } + + if (boost::algorithm::ends_with(currentPath, "fbsource")) { + return currentPath + "/third-party/presto_cpp/main/types/tests/data/" + + fileName; + } + + // CLion runs the tests from cmake-build-release/ or cmake-build-debug/ + // directory. Hard-coded json files are not copied there and test fails with + // file not found. Fixing the path so that we can trigger these tests from + // CLion. + boost::algorithm::replace_all(currentPath, "cmake-build-release/", ""); + boost::algorithm::replace_all(currentPath, "cmake-build-debug/", ""); + + return currentPath + "/data/" + fileName; +} +} // namespace + +class PrestoExprEvalTest : public ::testing::Test, + public facebook::velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + parse::registerTypeResolver(); + functions::prestosql::registerAllScalarFunctions("presto.default."); + exec::registerFunctionCallToSpecialForms(); + + auto httpServer = std::make_unique( + httpSrvIOExecutor_, + std::make_unique( + folly::SocketAddress("127.0.0.1", 0))); + prestoExprEval_ = std::make_unique(pool_); + prestoExprEval_->registerUris(*httpServer.get()); + httpServerWrapper_ = + std::make_unique(std::move(httpServer)); + auto address = httpServerWrapper_->start().get(); + client_ = clientFactory_.newClient( + address, + std::chrono::milliseconds(100'000), + std::chrono::milliseconds(0), + false, + pool_); + } + + void TearDown() override { + if (httpServerWrapper_) { + httpServerWrapper_->stop(); + } + } + + std::string getHttpBody(const std::unique_ptr& response) { + std::ostringstream oss; + auto iobufs = response->consumeBody(); + for (auto& body : iobufs) { + oss << std::string((const char*)body->data(), body->length()); + } + return oss.str(); + } + + void validateHttpResponse( + const std::string& inputStr, + const std::string& expectedStr) { + auto driverExecutor = std::make_shared( + 4, std::make_shared("Driver")); + const auto url = "/v1/expressions"; + http::RequestBuilder() + .method(proxygen::HTTPMethod::POST) + .url(url) + .send(client_.get(), inputStr) + .via(driverExecutor.get()) + .thenValue( + [expectedStr, this](std::unique_ptr response) { + VELOX_USER_CHECK_EQ( + response->headers()->getStatusCode(), http::kHttpOk); + if (response->hasError()) { + VELOX_USER_FAIL( + "Expression evaluation failed: {}", response->error()); + } + + auto resStr = getHttpBody(response); + auto resJson = json::parse(resStr); + ASSERT_TRUE(resJson.is_array()); + auto expectedJson = json::parse(expectedStr); + ASSERT_TRUE(expectedJson.is_array()); + EXPECT_EQ(expectedJson.size(), resJson.size()); + auto sz = resJson.size(); + for (auto i = 0; i < sz; i++) { + json result = resJson[i]; + json expected = expectedJson[i]; + EXPECT_EQ(result, expected); + } + }) + .thenError( + folly::tag_t{}, [&](const std::exception& e) { + VLOG(1) << "Expression evaluation failed: " << e.what(); + }); + } + + void testFile(const std::string& prefix) { + std::string input = slurp(getDataPath(fmt::format("{}Input.json", prefix))); + auto inputExprs = json::parse(input); + std::string output = + slurp(getDataPath(fmt::format("{}Expected.json", prefix))); + auto expectedExprs = json::parse(output); + + validateHttpResponse(inputExprs.dump(), expectedExprs.dump()); + } + + std::unique_ptr prestoExprEval_; + std::unique_ptr httpServerWrapper_; + HttpClientFactory clientFactory_; + std::shared_ptr client_; + std::shared_ptr httpSrvIOExecutor_{ + std::make_shared(8)}; + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool("PrestoExprEvalTest")}; +}; + +TEST_F(PrestoExprEvalTest, simple) { + // File SimpleExpressions{Input|Expected}.json contain the input and expected + // JSON representing the RowExpressions resulting from the following queries: + // select 1 + 2; + // select abs(-11) + ceil(cast(3.4 as double)) + floor(cast(5.6 as double)); + // select 2 between 1 and 3; + // Simple expression evaluation with constant folding is verified here. + testFile("SimpleExpressions"); +} + +TEST_F(PrestoExprEvalTest, specialFormRewrites) { + // File SpecialExpressions{Input|Expected}.json contain the input and expected + // JSON representing the RowExpressions resulting from the following queries: + // select if(1 < 2, 2, 3); + // select (1 < 2) and (2 < 3); + // select (1 < 2) or (2 < 3); + // The special form expression rewrites is verified here. + testFile("SpecialForm"); +} diff --git a/presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsExpected.json b/presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsExpected.json new file mode 100644 index 0000000000000..83d2cfb896af9 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "INTEGER", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + }, + { + "@type": "constant", + "type": "DOUBLE", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAAAAAAAAAADRA" + }, + { + "@type": "constant", + "type": "BOOLEAN", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsInput.json b/presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsInput.json new file mode 100644 index 0000000000000..582483a45c9bd --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/data/SimpleExpressionsInput.json @@ -0,0 +1,278 @@ +[ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAACwAAAA==" + } + ], + "displayName": "NEGATION", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$negation", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "abs", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.abs", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAACIAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ceil", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.ceil", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "decimal(2,1)", + "valueBlock": "CgAAAExPTkdfQVJSQVkBAAAAADgAAAAAAAAA" + } + ], + "displayName": "CAST", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "decimal(2,1)" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$cast", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "floor", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.floor", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + } + ], + "displayName": "ADD", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "double", + "double" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "double", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "double" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "BETWEEN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$between", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } +] \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormExpected.json b/presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormExpected.json new file mode 100644 index 0000000000000..6ab2b2f854c77 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormExpected.json @@ -0,0 +1,17 @@ +[ + { + "@type": "constant", + "type": "INTEGER", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "BOOLEAN", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + }, + { + "@type": "constant", + "type": "BOOLEAN", + "valueBlock": "CgAAAEJZVEVfQVJSQVkBAAAAAAE=" + } +] \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormInput.json b/presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormInput.json new file mode 100644 index 0000000000000..74de038e0d08d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/data/SpecialFormInput.json @@ -0,0 +1,193 @@ +[ + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "form": "IF", + "returnType": "integer" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": "AND", + "returnType": "boolean" + }, + { + "@type": "special", + "arguments": [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAgAAAA==" + }, + { + "@type": "constant", + "type": "integer", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAwAAAA==" + } + ], + "displayName": "LESS_THAN", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$less_than", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + } + ], + "form": "OR", + "returnType": "boolean" + } +] \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 52a00a6461a56..34f560fec0149 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -33,28 +33,6 @@ std::string toJsonString(const T& value) { } std::string mapScalarFunction(const std::string& name) { - static const std::unordered_map kFunctionNames = { - // Operator overrides: com.facebook.presto.common.function.OperatorType - {"presto.default.$operator$add", "presto.default.plus"}, - {"presto.default.$operator$between", "presto.default.between"}, - {"presto.default.$operator$divide", "presto.default.divide"}, - {"presto.default.$operator$equal", "presto.default.eq"}, - {"presto.default.$operator$greater_than", "presto.default.gt"}, - {"presto.default.$operator$greater_than_or_equal", "presto.default.gte"}, - {"presto.default.$operator$is_distinct_from", - "presto.default.distinct_from"}, - {"presto.default.$operator$less_than", "presto.default.lt"}, - {"presto.default.$operator$less_than_or_equal", "presto.default.lte"}, - {"presto.default.$operator$modulus", "presto.default.mod"}, - {"presto.default.$operator$multiply", "presto.default.multiply"}, - {"presto.default.$operator$negation", "presto.default.negate"}, - {"presto.default.$operator$not_equal", "presto.default.neq"}, - {"presto.default.$operator$subtract", "presto.default.minus"}, - {"presto.default.$operator$subscript", "presto.default.subscript"}, - // Special form function overrides. - {"presto.default.in", "in"}, - }; - std::string lowerCaseName = boost::to_lower_copy(name); auto it = kFunctionNames.find(lowerCaseName); @@ -102,6 +80,13 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) { } // namespace +const std::unordered_map inverseScalarMap() { + std::unordered_map inverseMap; + for (const auto& entry : kFunctionNames) + inverseMap[entry.second] = entry.first; + return inverseMap; +} + velox::variant VeloxExprConverter::getConstantValue( const velox::TypePtr& type, const protocol::Block& block) const { diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index 6e93c675a55f5..1a8ab8dceaec4 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -20,6 +20,30 @@ namespace facebook::presto { +static const std::unordered_map kFunctionNames = { + // Operator overrides: com.facebook.presto.common.function.OperatorType + {"presto.default.$operator$add", "presto.default.plus"}, + {"presto.default.$operator$between", "presto.default.between"}, + {"presto.default.$operator$divide", "presto.default.divide"}, + {"presto.default.$operator$equal", "presto.default.eq"}, + {"presto.default.$operator$greater_than", "presto.default.gt"}, + {"presto.default.$operator$greater_than_or_equal", "presto.default.gte"}, + {"presto.default.$operator$is_distinct_from", + "presto.default.distinct_from"}, + {"presto.default.$operator$less_than", "presto.default.lt"}, + {"presto.default.$operator$less_than_or_equal", "presto.default.lte"}, + {"presto.default.$operator$modulus", "presto.default.mod"}, + {"presto.default.$operator$multiply", "presto.default.multiply"}, + {"presto.default.$operator$negation", "presto.default.negate"}, + {"presto.default.$operator$not_equal", "presto.default.neq"}, + {"presto.default.$operator$subtract", "presto.default.minus"}, + {"presto.default.$operator$subscript", "presto.default.subscript"}, + // Special form function overrides. + {"presto.default.in", "in"}, +}; + +const std::unordered_map inverseScalarMap(); + class VeloxExprConverter { public: VeloxExprConverter(velox::memory::MemoryPool* pool, TypeParser* typeParser)