Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[native] Add native plan checker and native endpoint for Velox plan conversion #23596

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public Set<Node> getAllNodes()
return ImmutableSet.<Node>builder()
.addAll(getWorkerNodes())
.addAll(nodeManager.getCoordinators())
.addAll(nodeManager.getCoordinatorSidecars())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,6 @@ public void installPlugin(Plugin plugin)
log.info("Registering node status notification provider %s", nodeStatusNotificationProviderFactory.getName());
nodeStatusNotificationManager.addNodeStatusNotificationProviderFactory(nodeStatusNotificationProviderFactory);
}

for (PlanCheckerProviderFactory planCheckerProviderFactory : plugin.getPlanCheckerProviderFactories()) {
log.info("Registering plan checker provider factory %s", planCheckerProviderFactory.getName());
planCheckerProviderManager.addPlanCheckerProviderFactory(planCheckerProviderFactory);
}
}

public void installCoordinatorPlugin(CoordinatorPlugin plugin)
Expand All @@ -372,6 +367,11 @@ public void installCoordinatorPlugin(CoordinatorPlugin plugin)
log.info("Registering system session property provider factory %s", providerFactory.getName());
metadata.getSessionPropertyManager().addSessionPropertyProviderFactory(providerFactory);
}

for (PlanCheckerProviderFactory planCheckerProviderFactory : plugin.getPlanCheckerProviderFactories()) {
log.info("Registering plan checker provider factory %s", planCheckerProviderFactory.getName());
planCheckerProviderManager.addPlanCheckerProviderFactory(planCheckerProviderFactory);
}
}

private URLClassLoader buildClassLoader(String plugin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.facebook.airlift.json.smile.SmileModule;
import com.facebook.airlift.log.LogJmxModule;
import com.facebook.airlift.log.Logger;
import com.facebook.airlift.node.NodeInfo;
import com.facebook.airlift.node.NodeModule;
import com.facebook.airlift.tracetoken.TraceTokenModule;
import com.facebook.drift.server.DriftServer;
Expand All @@ -39,9 +40,11 @@
import com.facebook.presto.execution.warnings.WarningCollectorModule;
import com.facebook.presto.metadata.Catalog;
import com.facebook.presto.metadata.CatalogManager;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.metadata.DiscoveryNodeManager;
import com.facebook.presto.metadata.InternalNodeManager;
import com.facebook.presto.metadata.StaticCatalogStore;
import com.facebook.presto.metadata.StaticFunctionNamespaceStore;
import com.facebook.presto.nodeManager.PluginNodeManager;
import com.facebook.presto.security.AccessControlManager;
import com.facebook.presto.security.AccessControlModule;
import com.facebook.presto.server.security.PasswordAuthenticatorManager;
Expand Down Expand Up @@ -179,8 +182,12 @@ public void run()
injector.getInstance(TracerProviderManager.class).loadTracerProvider();
injector.getInstance(NodeStatusNotificationManager.class).loadNodeStatusNotificationProvider();
injector.getInstance(GracefulShutdownHandler.class).loadNodeStatusNotification();
injector.getInstance(PlanCheckerProviderManager.class).loadPlanCheckerProviders();
injector.getInstance(SessionPropertyManager.class).loadSessionPropertyProviders();
PlanCheckerProviderManager planCheckerProviderManager = injector.getInstance(PlanCheckerProviderManager.class);
InternalNodeManager nodeManager = injector.getInstance(DiscoveryNodeManager.class);
NodeInfo nodeInfo = injector.getInstance(NodeInfo.class);
PluginNodeManager pluginNodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment());
planCheckerProviderManager.loadPlanCheckerProviders(pluginNodeManager);

startAssociatedProcesses(injector);

injector.getInstance(Announcer.class).start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,15 @@ public void validatePlanFragment(PlanFragment planFragment, Session session, Met
checkers.get(Stage.FRAGMENT).forEach(checker -> checker.validateFragment(planFragment, session, metadata, warningCollector));
for (PlanCheckerProvider provider : planCheckerProviderManager.getPlanCheckerProviders()) {
for (com.facebook.presto.spi.plan.PlanChecker checker : provider.getFragmentPlanCheckers()) {
checker.validateFragment(toSimplePlanFragment(planFragment), warningCollector);
checker.validateFragment(new SimplePlanFragment(
planFragment.getId(),
planFragment.getRoot(),
planFragment.getVariables(),
planFragment.getPartitioning(),
planFragment.getTableScanSchedulingOrder(),
planFragment.getPartitioningScheme(),
planFragment.getStageExecutionDescriptor(),
planFragment.isOutputTableWriterFragment()), warningCollector);
}
}
}
Expand All @@ -126,17 +134,4 @@ private enum Stage
{
INTERMEDIATE, FINAL, FRAGMENT
}

private static SimplePlanFragment toSimplePlanFragment(PlanFragment planFragment)
{
return new SimplePlanFragment(
planFragment.getId(),
planFragment.getRoot(),
planFragment.getVariables(),
planFragment.getPartitioning(),
planFragment.getTableScanSchedulingOrder(),
planFragment.getPartitioningScheme(),
planFragment.getStageExecutionDescriptor(),
planFragment.isOutputTableWriterFragment());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.sql.planner.sanity;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.plan.PlanCheckerProvider;
import com.facebook.presto.spi.plan.PlanCheckerProviderContext;
import com.facebook.presto.spi.plan.PlanCheckerProviderFactory;
Expand All @@ -40,15 +41,15 @@ public class PlanCheckerProviderManager
private static final Logger log = Logger.get(PlanCheckerProviderManager.class);
private static final String PLAN_CHECKER_PROVIDER_NAME = "plan-checker-provider.name";

private final PlanCheckerProviderContext planCheckerProviderContext;
private final SimplePlanFragmentSerde simplePlanFragmentSerde;
private final Map<String, PlanCheckerProviderFactory> providerFactories = new ConcurrentHashMap<>();
private final CopyOnWriteArrayList<PlanCheckerProvider> providers = new CopyOnWriteArrayList<>();
private final File configDirectory;

@Inject
public PlanCheckerProviderManager(SimplePlanFragmentSerde simplePlanFragmentSerde, PlanCheckerProviderManagerConfig config)
{
this.planCheckerProviderContext = new PlanCheckerProviderContext(requireNonNull(simplePlanFragmentSerde, "planNodeSerde is null"));
this.simplePlanFragmentSerde = requireNonNull(simplePlanFragmentSerde, "planNodeSerde is null");
requireNonNull(config, "config is null");
this.configDirectory = requireNonNull(config.getPlanCheckerConfigurationDir(), "configDirectory is null");
}
Expand All @@ -61,9 +62,11 @@ public void addPlanCheckerProviderFactory(PlanCheckerProviderFactory planChecker
}
}

public void loadPlanCheckerProviders()
public void loadPlanCheckerProviders(NodeManager nodeManager)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can NodeManager be injected instead along with SimplePlanFragmentSerde?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would need to bind the PluginNodeManager and @tdcmeehan had requested to not bind, but just create when needed. I agree since we already bind an internal node manager, this is just created from that to pass to plugins.

throws IOException
{
PlanCheckerProviderContext planCheckerProviderContext = new PlanCheckerProviderContext(simplePlanFragmentSerde, nodeManager);

for (File file : listFiles(configDirectory)) {
if (file.isFile() && file.getName().endsWith(".properties")) {
// unlike function namespaces and connectors, we don't have a concept of catalog
Expand All @@ -75,7 +78,7 @@ public void loadPlanCheckerProviders()
file.getAbsoluteFile(),
PLAN_CHECKER_PROVIDER_NAME);
String planCheckerProviderName = properties.remove(PLAN_CHECKER_PROVIDER_NAME);
log.info("-- Loading plan checker provider %s--", planCheckerProviderName);
log.info("-- Loading plan checker provider [%s] --", planCheckerProviderName);
PlanCheckerProviderFactory providerFactory = providerFactories.get(planCheckerProviderName);
checkState(providerFactory != null,
"No planCheckerProviderFactory found for '%s'. Available factories were %s", planCheckerProviderName, providerFactories.keySet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.spi.plan.PlanCheckerProviderFactory;
import com.facebook.presto.spi.plan.SimplePlanFragment;
import com.facebook.presto.sql.planner.plan.JsonCodecSimplePlanFragmentSerde;
import com.facebook.presto.testing.TestingNodeManager;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

Expand All @@ -40,7 +41,7 @@ public void testLoadPlanCheckerProviders()
.setPlanCheckerConfigurationDir(new File("src/test/resources/plan-checkers"));
PlanCheckerProviderManager planCheckerProviderManager = new PlanCheckerProviderManager(new JsonCodecSimplePlanFragmentSerde(JsonCodec.jsonCodec(SimplePlanFragment.class)), planCheckerProviderManagerConfig);
planCheckerProviderManager.addPlanCheckerProviderFactory(new TestingPlanCheckerProviderFactory());
planCheckerProviderManager.loadPlanCheckerProviders();
planCheckerProviderManager.loadPlanCheckerProviders(new TestingNodeManager());
assertEquals(planCheckerProviderManager.getPlanCheckerProviders(), ImmutableList.of(TESTING_PLAN_CHECKER_PROVIDER));
}

Expand All @@ -51,7 +52,7 @@ public void testLoadUnregisteredPlanCheckerProvider()
PlanCheckerProviderManagerConfig planCheckerProviderManagerConfig = new PlanCheckerProviderManagerConfig()
.setPlanCheckerConfigurationDir(new File("src/test/resources/plan-checkers"));
PlanCheckerProviderManager planCheckerProviderManager = new PlanCheckerProviderManager(new JsonCodecSimplePlanFragmentSerde(JsonCodec.jsonCodec(SimplePlanFragment.class)), planCheckerProviderManagerConfig);
planCheckerProviderManager.loadPlanCheckerProviders();
planCheckerProviderManager.loadPlanCheckerProviders(new TestingNodeManager());
}

public static class TestingPlanCheckerProviderFactory
Expand Down
1 change: 1 addition & 0 deletions presto-native-execution/presto_cpp/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ target_link_libraries(
presto_function_metadata
presto_http
presto_operators
presto_velox_conversion
velox_aggregates
velox_caching
velox_common_base
Expand Down
22 changes: 22 additions & 0 deletions presto-native-execution/presto_cpp/main/PrestoServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "presto_cpp/main/operators/UnsafeRowExchangeSource.h"
#include "presto_cpp/main/types/FunctionMetadata.h"
#include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h"
#include "presto_cpp/main/types/VeloxPlanConversion.h"
#include "velox/common/base/Counters.h"
#include "velox/common/base/StatsReporter.h"
#include "velox/common/caching/CacheTTLController.h"
Expand Down Expand Up @@ -478,6 +479,9 @@ void PrestoServer::run() {

pool_ =
velox::memory::MemoryManager::getInstance()->addLeafPool("PrestoServer");
nativeWorkerPool_ = velox::memory::MemoryManager::getInstance()->addLeafPool(
"PrestoNativeWorker");

taskManager_ = std::make_unique<TaskManager>(
driverExecutor_.get(), httpSrvCpuExecutor_.get(), spillerExecutor_.get());

Expand Down Expand Up @@ -1475,6 +1479,24 @@ void PrestoServer::registerSidecarEndpoints() {
proxygen::ResponseHandler* downstream) {
http::sendOkResponse(downstream, getFunctionsMetadata());
});
httpServer_->registerPost(
"/v1/velox/plan",
[server = this](
proxygen::HTTPMessage* message,
const std::vector<std::unique_ptr<folly::IOBuf>>& body,
proxygen::ResponseHandler* downstream) {
std::string planFragmentJson = util::extractMessageBody(body);
protocol::PlanConversionResponse response = prestoToVeloxPlanConversion(
planFragmentJson,
server->nativeWorkerPool_.get(),
server->getVeloxPlanValidator());
if (response.failures.empty()) {
http::sendOkResponse(downstream, json(response));
} else {
http::sendResponse(
downstream, json(response), http::kHttpUnprocessableContent);
}
});
}

protocol::NodeStatus PrestoServer::fetchNodeStatus() {
Expand Down
1 change: 1 addition & 0 deletions presto-native-execution/presto_cpp/main/PrestoServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class PrestoServer {
std::unique_ptr<Announcer> announcer_;
std::unique_ptr<PeriodicHeartbeatManager> heartbeatManager_;
std::shared_ptr<velox::memory::MemoryPool> pool_;
std::shared_ptr<velox::memory::MemoryPool> nativeWorkerPool_;
std::unique_ptr<TaskManager> taskManager_;
std::unique_ptr<TaskResource> taskResource_;
std::atomic<NodeState> nodeState_{NodeState::kActive};
Expand Down
8 changes: 1 addition & 7 deletions presto-native-execution/presto_cpp/main/TaskResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,7 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTaskImpl(
httpSrvCpuExecutor_,
[this, &body, taskId, createOrUpdateFunc]() {
const auto startProcessCpuTimeNs = util::getProcessCpuTimeNs();

// TODO Avoid copy
std::ostringstream oss;
for (auto& buf : body) {
oss << std::string((const char*)buf->data(), buf->length());
}
std::string updateJson = oss.str();
std::string updateJson = util::extractMessageBody(body);

std::unique_ptr<protocol::TaskInfo> taskInfo;
try {
Expand Down
9 changes: 9 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,13 @@ void installSignalHandler() {
#endif // __APPLE__
}

std::string extractMessageBody(
const std::vector<std::unique_ptr<folly::IOBuf>>& body) {
// TODO Avoid copy
std::ostringstream oss;
for (auto& buf : body) {
oss << std::string((const char*)buf->data(), buf->length());
}
return oss.str();
}
} // namespace facebook::presto::util
3 changes: 3 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,7 @@ long getProcessCpuTimeNs();
/// context such as the queryId.
void installSignalHandler();

std::string extractMessageBody(
const std::vector<std::unique_ptr<folly::IOBuf>>& body);

} // namespace facebook::presto::util
7 changes: 2 additions & 5 deletions presto-native-execution/presto_cpp/main/http/HttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <folly/synchronization/Latch.h>
#include <velox/common/base/Exceptions.h>
#include "presto_cpp/main/common/Configs.h"
#include "presto_cpp/main/common/Utils.h"
#include "presto_cpp/main/http/HttpClient.h"

namespace facebook::presto::http {
Expand Down Expand Up @@ -169,11 +170,7 @@ HttpResponse::nextAllocationSize(uint64_t dataLength) const {
std::string HttpResponse::dumpBodyChain() const {
std::string responseBody;
if (!bodyChain_.empty()) {
std::ostringstream oss;
for (const auto& buf : bodyChain_) {
oss << std::string((const char*)buf->data(), buf->length());
}
responseBody = oss.str();
responseBody = util::extractMessageBody(bodyChain_);
}
return responseBody;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const uint16_t kHttpNoContent = 204;
const uint16_t kHttpBadRequest = 400;
const uint16_t kHttpUnauthorized = 401;
const uint16_t kHttpNotFound = 404;
const uint16_t kHttpUnprocessableContent = 422;
const uint16_t kHttpInternalServerError = 500;

const char kMimeTypeApplicationJson[] = "application/json";
Expand Down
44 changes: 28 additions & 16 deletions presto-native-execution/presto_cpp/main/http/HttpServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,7 @@ void sendOkResponse(proxygen::ResponseHandler* downstream) {
}

void sendOkResponse(proxygen::ResponseHandler* downstream, const json& body) {
// nlohmann::json throws when it finds invalid UTF-8 characters. In that case
// the server will crash. We handle such situation here and generate body
// replacing the faulty UTF-8 sequences.
std::string messageBody;
try {
messageBody = body.dump();
} catch (const std::exception& e) {
messageBody =
body.dump(-1, ' ', false, nlohmann::detail::error_handler_t::replace);
LOG(WARNING) << "Failed to serialize json to string. "
"Will retry with 'replace' option. "
"Json Dump:\n"
<< messageBody;
}

sendOkResponse(downstream, messageBody);
sendResponse(downstream, body, http::kHttpOk);
}

void sendOkResponse(
Expand Down Expand Up @@ -75,6 +60,33 @@ void sendErrorResponse(
.sendWithEOM();
}

void sendResponse(
proxygen::ResponseHandler* downstream,
const json& body,
uint16_t status) {
// nlohmann::json throws when it finds invalid UTF-8 characters. In that case
// the server will crash. We handle such situation here and generate body
// replacing the faulty UTF-8 sequences.
std::string messageBody;
try {
messageBody = body.dump();
} catch (const std::exception& e) {
messageBody =
body.dump(-1, ' ', false, nlohmann::detail::error_handler_t::replace);
LOG(WARNING) << "Failed to serialize json to string. "
"Will retry with 'replace' option. "
"Json Dump:\n"
<< messageBody;
}

proxygen::ResponseBuilder(downstream)
.status(status, "")
.header(
proxygen::HTTP_HEADER_CONTENT_TYPE, http::kMimeTypeApplicationJson)
.body(messageBody)
.sendWithEOM();
}

HttpConfig::HttpConfig(const folly::SocketAddress& address, bool reusePort)
: address_(address), reusePort_(reusePort) {}

Expand Down
5 changes: 5 additions & 0 deletions presto-native-execution/presto_cpp/main/http/HttpServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ void sendErrorResponse(
const std::string& error = "",
uint16_t status = http::kHttpInternalServerError);

void sendResponse(
proxygen::ResponseHandler* downstream,
const json& body,
uint16_t status);

class AbstractRequestHandler : public proxygen::RequestHandler {
public:
void onRequest(
Expand Down
Loading
Loading