diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index f691866cbfa4d..8409f9af71b16 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(operators) add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) +add_subdirectory(eval) add_subdirectory(thrift) add_library( @@ -48,6 +49,7 @@ target_link_libraries( presto_exception presto_http presto_operators + presto_expr_eval velox_aggregates velox_caching velox_common_base diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 2c58d7c86cdbe..53091ae85e752 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -370,6 +370,12 @@ void PrestoServer::run() { registerVectorSerdes(); registerPrestoPlanNodeSerDe(); + // Initialize prestoExprEval_ after the functions are registered. + if (1 /*systemConfig->sidecar()*/) { + prestoExprEval_ = std::make_unique(pool_); + prestoExprEval_->registerUris(*httpServer_); + } + const auto numExchangeHttpClientIoThreads = std::max( systemConfig->exchangeHttpClientNumIoThreadsHwMultiplier() * std::thread::hardware_concurrency(), diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index 414452b0f38f8..75dbc51eef893 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -25,6 +25,7 @@ #include "presto_cpp/main/PeriodicHeartbeatManager.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/PrestoServerOperations.h" +#include "presto_cpp/main/eval/PrestoExprEval.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/memory/MemoryAllocator.h" #if __has_include("filesystem") @@ -258,6 +259,7 @@ class PrestoServer { std::string address_; std::string nodeLocation_; folly::SSLContextPtr sslContext_; + std::unique_ptr prestoExprEval_; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/eval/CMakeLists.txt b/presto-native-execution/presto_cpp/main/eval/CMakeLists.txt new file mode 100644 index 0000000000000..33a59ed2585d9 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_expr_eval PrestoExprEval.cpp) + +target_link_libraries( + presto_expr_eval + presto_type_converter + presto_types + presto_protocol + presto_http + velox_coverage_util + velox_parse_expression + velox_parse_parser + velox_presto_serializer + velox_serialization + velox_type_parser + ${FOLLY_WITH_DEPENDENCIES} + ${PROXYGEN_LIBRARIES}) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.cpp b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.cpp new file mode 100644 index 0000000000000..b22f3eedabfb6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.cpp @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/eval/PrestoExprEval.h" +#include +#include "presto_cpp/presto_protocol/presto_protocol.h" +#include "velox/common/encode/Base64.h" +#include "velox/core/Expressions.h" +#include "velox/expression/ConstantExpr.h" +#include "velox/expression/EvalCtx.h" +#include "velox/expression/Expr.h" +#include "velox/expression/ExprCompiler.h" +#include "velox/expression/LambdaExpr.h" +#include "velox/expression/FieldReference.h" +#include "velox/parse/Expressions.h" +#include "velox/parse/ExpressionsParser.h" +#include "velox/vector/BaseVector.h" +#include "velox/vector/ComplexVector.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +namespace facebook::presto::eval { +void PrestoExprEval::registerUris(http::HttpServer& server) { + server.registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* /*message*/, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + return evaluateExpression(body, downstream); + }); +} + +namespace { +json fieldReferenceToVariableRefExpr(exec::FieldReference* fieldReference) { + json res; + res["@type"] = "variable"; + res["sourceLocation"] = "sampleSource"; + res["name"] = fieldReference->name(); + res["type"] = fieldReference->type()->toString(); + return res; +} +} + +json PrestoExprEval::exprToRowExpression(std::shared_ptr expr) { + json res; + if (expr->isConstant()) { + // constant + res["@type"] = "constant"; + auto constantExpr = std::dynamic_pointer_cast(expr); + auto valStr = constantExpr->value()->toString(); + auto encStr = encoding::Base64::encode(valStr); + res["valueBlock"] = encStr; + res["type"] = expr->type()->toString(); + } else if (expr->isSpecialForm()) { + // special + res["@type"] = "special"; + res["sourceLocation"] = "sampleSource"; + auto inputs = expr->inputs(); + res["arguments"] = json::array(); + for (auto input: inputs) { + res["arguments"].push_back(exprToRowExpression(input)); + } + res["form"] = "BIND"; + res["type"] = expr->type()->toString(); + } else if (auto lambda = std::dynamic_pointer_cast(expr)) { + // lambda + res["@type"] = "lambda"; + auto inputs = lambda->distinctFields(); + res["arguments"] = json::array(); + res["argumentTypes"] = json::array(); + auto numInputs = inputs.size(); + for (auto i = 0; i < numInputs; i++) { + res["arguments"].push_back(fieldReferenceToVariableRefExpr(inputs[i])); + // TODO: Recheck type conversion. + res["argumentTypes"].push_back(lambda->type()->childAt(i)->toString()); + } + VELOX_USER_CHECK(isLambda_, "Not a lambda expression"); + res["body"] = lambdaTypedExpr_->body()->toString(); + } else if (auto func = expr->vectorFunction()) { + // call + res["@type"] = "call"; + res["sourceLocation"] = "sampleSource"; + res["displayName"] = expr->name(); + res["functionHandle"] = expr->toString(); + res["returnType"] = expr->type()->toString(); + auto fields = expr->distinctFields(); + for (auto field: fields) { + // TODO: Check why static cast and dynamic cast are not working. + res["arguments"].push_back(fieldReferenceToVariableRefExpr(field)); + } + auto inputs = expr->inputs(); + for (auto input: inputs) { + res["arguments"].push_back(exprToRowExpression(input)); + } + } else { + VELOX_NYI("Unable to convert velox expr to rowexpr"); + } + + return res; +} + +void PrestoExprEval::evaluateExpression( + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + std::ostringstream oss; + for (auto& buf : body) { + oss << std::string((const char*)buf->data(), buf->length()); + } + auto input = json::parse(oss.str()); + + auto inputJsonArray = input.at("inputs"); + auto len = inputJsonArray.size(); + nlohmann::json output; + output["outputs"] = json::array(); + + for (auto i = 0; i < len; i++) { + std::shared_ptr inputRowExpr = inputJsonArray[i]; + auto typedExpr = exprConverter_.toVeloxExpr(inputRowExpr); +// parse::ParseOptions options; +// auto untyped = parse::parseExpr("a = gte(cos(sin(1.2)),0.1)", options); +// typedExpr = +// core::Expressions::inferTypes(untyped, ROW({"a"}, {BOOLEAN()}), pool_.get()); + + if (auto lambdaExpr = core::TypedExprs::asLambda(typedExpr)) { + lambdaTypedExpr_ = lambdaExpr; + isLambda_ = true; + } else { + isLambda_ = false; + } + + exec::ExprSet exprSet{{typedExpr}, execCtx_.get()}; + auto compiledExprs = + exec::compileExpressions({typedExpr}, execCtx_.get(), &exprSet, true); + auto compiledExpr = compiledExprs[0]; + auto res = exprToRowExpression(compiledExpr); + output["outputs"].push_back(res); + } + + proxygen::ResponseBuilder(downstream) + .status(http::kHttpOk, "") + .header( + proxygen::HTTP_HEADER_CONTENT_TYPE, http::kMimeTypeApplicationJson) + .body(output.dump()) + .sendWithEOM(); +} +} // namespace facebook::presto::eval diff --git a/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.h b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.h new file mode 100644 index 0000000000000..f7977b5caeeb2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/PrestoExprEval.h @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/main/http/HttpServer.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "velox/core/QueryCtx.h" +#include "velox/expression/Expr.h" + +namespace facebook::presto::eval { + +class PrestoExprEval { + public: + PrestoExprEval(std::shared_ptr pool) + : pool_(pool), + queryCtx_(std::make_shared()), + execCtx_{std::make_unique( + pool.get(), + queryCtx_.get())}, + exprConverter_(pool.get(), &typeParser_) {}; + + void registerUris(http::HttpServer& server); + + /// Evaluate expressions sent along /v1/expressions endpoint. + void evaluateExpression( + const std::vector>& body, + proxygen::ResponseHandler* downstream); + + protected: + std::string getConstantValue(const velox::VectorPtr& input, const velox::TypePtr& type); + + json exprToRowExpression(std::shared_ptr expr); + + const std::shared_ptr pool_; + const std::shared_ptr queryCtx_; + const std::unique_ptr execCtx_; + VeloxExprConverter exprConverter_; + TypeParser typeParser_; + bool isLambda_; + std::shared_ptr lambdaTypedExpr_; +}; +} // namespace facebook::presto::eval diff --git a/presto-native-execution/presto_cpp/main/eval/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/eval/tests/CMakeLists.txt new file mode 100644 index 0000000000000..0af3ec67338f0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_executable(presto_expr_eval_test PrestoExprEvalTest.cpp) + +add_test(presto_expr_eval_test presto_expr_eval_test) + +target_link_libraries( + presto_expr_eval_test + presto_expr_eval + presto_http + velox_exec_test_lib + velox_presto_serializer + gtest + gtest_main + ${PROXYGEN_LIBRARIES}) + +set_property(TARGET presto_expr_eval_test PROPERTY JOB_POOL_LINK + presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/main/eval/tests/PrestoExprEvalTest.cpp b/presto-native-execution/presto_cpp/main/eval/tests/PrestoExprEvalTest.cpp new file mode 100644 index 0000000000000..1db3b9cd54fe0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/eval/tests/PrestoExprEvalTest.cpp @@ -0,0 +1,267 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/eval/PrestoExprEval.h" +#include +#include +#include +#include +#include "presto_cpp/main/http/tests/HttpTestBase.h" +#include "velox/exec/OutputBufferManager.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/expression/RegisterSpecialForm.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/VectorStream.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +class PrestoExprEvalTest : public ::testing::Test, + public facebook::velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + functions::prestosql::registerAllScalarFunctions(); + exec::registerFunctionCallToSpecialForms(); + parse::registerTypeResolver(); + + auto httpServer = std::make_unique( + httpSrvIOExecutor_, + std::make_unique( + folly::SocketAddress("127.0.0.1", 0))); + prestoExprEval_ = std::make_unique(pool_); + prestoExprEval_->registerUris(*httpServer.get()); + httpServerWrapper_ = + std::make_unique(std::move(httpServer)); + auto address = httpServerWrapper_->start().get(); + client_ = clientFactory_.newClient( + address, + std::chrono::milliseconds(100'000), + std::chrono::milliseconds(0), + false, + pool_); + } + + void TearDown() override { + if (httpServerWrapper_) { + httpServerWrapper_->stop(); + } + } + + std::string getHttpBody(const std::unique_ptr& response) { + std::ostringstream oss; + auto iobufs = response->consumeBody(); + for (auto& body : iobufs) { + oss << std::string((const char*)body->data(), body->length()); + } + return oss.str(); + } + + std::string getHttpBody(const std::vector& rowExprArray) { + nlohmann::json output; + for (auto i = 0; i < rowExprArray.size(); i++) { + output.push_back(rowExprArray[i]); + } + return output.dump(); + } + + void validateHttpResponse( + const std::string& input, + const std::string& expected) { + auto driverExecutor = std::make_shared( + 4, std::make_shared("Driver")); + const auto url = "/v1/expressions"; + http::RequestBuilder() + .method(proxygen::HTTPMethod::POST) + .url(url) + .send(client_.get(), input) + .via(driverExecutor.get()) + .thenValue( + [expected, this](std::unique_ptr response) { + VELOX_USER_CHECK_EQ( + response->headers()->getStatusCode(), http::kHttpOk); + if (response->hasError()) { + VELOX_USER_FAIL( + "Expression evaluation failed: {}", response->error()); + } + + auto expectedJson = json::parse(expected); + auto expectedArray = expectedJson.at("outputs"); + auto resStr = getHttpBody(response); + auto resJson = json::parse(resStr); + auto resArray = resJson.at("outputs"); + ASSERT_TRUE(resArray.is_array()); + VELOX_USER_CHECK_EQ(expectedArray.size(), resArray.size()); + auto sz = resArray.size(); + for (auto i = 0; i < sz; i++) { + json result = resArray[i]; + json expectedVal = expectedArray[i]; + auto t = (result == expectedVal); +// VELOX_USER_CHECK_EQ(expectedVal, result); + } + }) + .thenError( + folly::tag_t{}, [&](const std::exception& e) { + VLOG(1) << "Expression evaluation failed: " << e.what(); + }); + } + + std::unique_ptr prestoExprEval_; + std::unique_ptr httpServerWrapper_; + HttpClientFactory clientFactory_; + std::shared_ptr client_; + std::shared_ptr httpSrvIOExecutor_{ + std::make_shared(8)}; + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool("PrestoExprEvalTest")}; +}; + +TEST_F(PrestoExprEvalTest, constant) { + const std::string input = R"##( + { + "@type": "PrestoExprEvalTest", + "inputs": [ + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==", + "type": "integer" + }, + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAGEcAAA==", + "type": "date" + }, + { + "@type": "constant", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAIAAAAAAgAAADIz", + "type": "varchar(25)" + } + ] + } + )##"; + const std::string expected = R"##( +{ + "@type": "PrestoExprEvalTest", + "outputs": [ + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==", + "type": "integer" + }, + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAGEcAAA==", + "type": "date" + }, + { + "@type": "constant", + "valueBlock": "DgAAAFZBUklBQkxFX1dJRFRIAQAAAAIAAAAAAgAAADIz", + "type": "varchar(25)" + } + ] +} +)##"; + + validateHttpResponse(input, expected); +} + +TEST_F(PrestoExprEvalTest, simpleCallExpr) { + const std::string input = R"##( + { + "@type": "PrestoExprEvalTest", + "inputs" : [ + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==", + "type": "integer" + }, + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==", + "type": "integer" + } + ], + "displayName": "EQUAL", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$equal", + "returnType": "boolean", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "boolean" + }, + { + "@type": "call", + "arguments": [ + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==", + "type": "integer" + }, + { + "@type": "constant", + "valueBlock": "CQAAAElOVF9BUlJBWQEAAAAAAQAAAA==", + "type": "integer" + } + ], + "displayName": "PLUS", + "functionHandle": { + "@type": "$static", + "signature": { + "argumentTypes": [ + "integer", + "integer" + ], + "kind": "SCALAR", + "longVariableConstraints": [], + "name": "presto.default.$operator$add", + "returnType": "integer", + "typeVariableConstraints": [], + "variableArity": false + } + }, + "returnType": "integer" + } + ] + } + )##"; + const std::string expected = R"##( + { + "@type": "PrestoExprEvalTest", + "outputs" : [ + "true", + "2", + + ] + } +)##"; + + validateHttpResponse(input, expected); +}