From ecfd5b83bf0b0830b004843a7baf4ff63d5ef728 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 25 Aug 2023 18:04:22 -0700 Subject: [PATCH] feat: add api for loading plans of all types (#80) features: * supports reading/writing all of the major formats caveats: * only read/writes to filenames so in-memory use should use other interfaces * does not support compression * does not have a zero copy interface --- include/substrait/common/Io.h | 60 +++++++++ src/substrait/common/CMakeLists.txt | 13 +- src/substrait/common/Io.cpp | 77 +++++++++++ src/substrait/common/tests/CMakeLists.txt | 10 ++ src/substrait/common/tests/IoTest.cpp | 120 ++++++++++++++++++ src/substrait/textplan/CMakeLists.txt | 4 +- src/substrait/textplan/StringManipulation.cpp | 7 +- src/substrait/textplan/StringManipulation.h | 2 + src/substrait/textplan/SymbolTablePrinter.cpp | 25 +++- .../textplan/converter/CMakeLists.txt | 8 +- .../textplan/converter/LoadBinary.cpp | 46 ++++--- src/substrait/textplan/converter/LoadBinary.h | 43 ++----- src/substrait/textplan/converter/README.md | 4 +- .../textplan/converter/SaveBinary.cpp | 113 +++++++++++++++++ src/substrait/textplan/converter/SaveBinary.h | 33 +++++ src/substrait/textplan/converter/Tool.cpp | 15 +-- .../tests/BinaryToTextPlanConversionTest.cpp | 15 ++- src/substrait/textplan/parser/CMakeLists.txt | 6 +- src/substrait/textplan/parser/LoadText.cpp | 25 ++++ src/substrait/textplan/parser/LoadText.h | 17 +++ src/substrait/textplan/parser/ParseText.cpp | 2 +- .../parser/SubstraitPlanRelationVisitor.cpp | 10 +- .../textplan/tests/RoundtripTest.cpp | 10 +- 23 files changed, 574 insertions(+), 91 deletions(-) create mode 100644 include/substrait/common/Io.h create mode 100644 src/substrait/common/Io.cpp create mode 100644 src/substrait/common/tests/IoTest.cpp create mode 100644 src/substrait/textplan/converter/SaveBinary.cpp create mode 100644 src/substrait/textplan/converter/SaveBinary.h create mode 100644 src/substrait/textplan/parser/LoadText.cpp create mode 100644 src/substrait/textplan/parser/LoadText.h diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h new file mode 100644 index 00000000..a53697d8 --- /dev/null +++ b/include/substrait/common/Io.h @@ -0,0 +1,60 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "absl/status/statusor.h" +#include "substrait/proto/plan.pb.h" + +namespace io::substrait { + +/* + * \brief The four different ways plans can be represented on disk. + */ +enum class PlanFileFormat { + kBinary = 0, + kJson = 1, + kProtoText = 2, + kText = 3, +}; + +/* + * \brief Loads a Substrait plan of any format from the given file. + * + * loadPlan determines which file type the specified file is and then calls + * the appropriate load/parse method to consume it preserving any error + * messages. + * + * This will load the plan into memory and then convert it consuming twice the + * amount of memory that it consumed on disk. + * + * \param input_filename The filename containing the plan to convert. + * \return If loading was successful, returns a plan. If loading was not + * successful this is a status containing a list of parse errors in the status's + * message. + */ +absl::StatusOr<::substrait::proto::Plan> loadPlan( + std::string_view input_filename); + +/* + * \brief Writes the provided plan to disk. + * + * savePlan writes the provided plan in the specified format to the specified + * location. + * + * This routine will consume more memory during the conversion to the text + * format as the original plan as well as the annotated parse tree will need to + * reside in memory during the process. + * + * \param plan + * \param output_filename + * \param format + * \return + */ +absl::Status savePlan( + const ::substrait::proto::Plan& plan, + std::string_view output_filename, + PlanFileFormat format); + +} // namespace io::substrait diff --git a/src/substrait/common/CMakeLists.txt b/src/substrait/common/CMakeLists.txt index 846d077f..dc05a11e 100644 --- a/src/substrait/common/CMakeLists.txt +++ b/src/substrait/common/CMakeLists.txt @@ -1,9 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 add_library(substrait_common Exceptions.cpp) - target_link_libraries(substrait_common fmt::fmt-header-only) +add_library(substrait_io Io.cpp) +add_dependencies( + substrait_io + substrait_proto + substrait_textplan_converter + substrait_textplan_loader + fmt::fmt-header-only + absl::status + absl::statusor) +target_link_libraries(substrait_io substrait_proto substrait_textplan_converter + substrait_textplan_loader absl::status absl::statusor) + if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) endif() diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp new file mode 100644 index 00000000..af06066c --- /dev/null +++ b/src/substrait/common/Io.cpp @@ -0,0 +1,77 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/converter/LoadBinary.h" +#include "substrait/textplan/converter/SaveBinary.h" +#include "substrait/textplan/parser/LoadText.h" + +namespace io::substrait { + +namespace { + +const std::regex kIsJson( + R"(("extensionUris"|"extension_uris"|"extensions"|"relations"))"); +const std::regex kIsProtoText( + R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))"); +const std::regex kIsText( + R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)"); + +PlanFileFormat detectFormat(std::string_view content) { + if (std::regex_search(content.begin(), content.end(), kIsJson)) { + return PlanFileFormat::kJson; + } + if (std::regex_search(content.begin(), content.end(), kIsProtoText)) { + return PlanFileFormat::kProtoText; + } + if (std::regex_search(content.begin(), content.end(), kIsText)) { + return PlanFileFormat::kText; + } + return PlanFileFormat::kBinary; +} + +} // namespace + +absl::StatusOr<::substrait::proto::Plan> loadPlan( + std::string_view input_filename) { + auto contentOrError = textplan::readFromFile(input_filename.data()); + if (!contentOrError.ok()) { + return contentOrError.status(); + } + + auto encoding = detectFormat(*contentOrError); + absl::StatusOr<::substrait::proto::Plan> planOrError; + switch (encoding) { + case PlanFileFormat::kBinary: + return textplan::loadFromBinary(*contentOrError); + case PlanFileFormat::kJson: + return textplan::loadFromJson(*contentOrError); + case PlanFileFormat::kProtoText: + return textplan::loadFromProtoText(*contentOrError); + case PlanFileFormat::kText: + return textplan::loadFromText(*contentOrError); + } +} + +absl::Status savePlan( + const ::substrait::proto::Plan& plan, + std::string_view output_filename, + PlanFileFormat format) { + switch (format) { + case PlanFileFormat::kBinary: + return textplan::savePlanToBinary(plan, output_filename); + case PlanFileFormat::kJson: + return textplan::savePlanToJson(plan, output_filename); + case PlanFileFormat::kProtoText: + return textplan::savePlanToProtoText(plan, output_filename); + case PlanFileFormat::kText: + return textplan::savePlanToText(plan, output_filename); + } + return absl::UnimplementedError("Unexpected format requested."); +} + +} // namespace io::substrait diff --git a/src/substrait/common/tests/CMakeLists.txt b/src/substrait/common/tests/CMakeLists.txt index 36c5f160..88b63b8f 100644 --- a/src/substrait/common/tests/CMakeLists.txt +++ b/src/substrait/common/tests/CMakeLists.txt @@ -9,3 +9,13 @@ add_test_case( substrait_common gtest gtest_main) + +add_test_case( + substrait_io_test + SOURCES + IoTest.cpp + EXTRA_LINK_LIBS + substrait_io + protobuf-matchers + gtest + gtest_main) diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp new file mode 100644 index 00000000..183594bd --- /dev/null +++ b/src/substrait/common/tests/IoTest.cpp @@ -0,0 +1,120 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include + +#include +#include +#include +#include + +using ::protobuf_matchers::EqualsProto; +using ::protobuf_matchers::Partially; + +namespace io::substrait { + +namespace { + +constexpr const char* planFileEncodingToString(PlanFileFormat e) noexcept { + switch (e) { + case PlanFileFormat::kBinary: + return "kBinary"; + case PlanFileFormat::kJson: + return "kJson"; + case PlanFileFormat::kProtoText: + return "kProtoText"; + case PlanFileFormat::kText: + return "kText"; + } + return "IMPOSSIBLE"; +} + +} // namespace + +class IoTest : public ::testing::Test {}; + +TEST_F(IoTest, LoadMissingFile) { + auto result = ::io::substrait::loadPlan("non-existent-file"); + ASSERT_FALSE(result.ok()); + ASSERT_THAT( + result.status().message(), + ::testing::ContainsRegex("Failed to open file non-existent-file")); +} + +class SaveAndLoadTestFixture : public ::testing::TestWithParam { + public: + void SetUp() override { + testFileDirectory_ = std::filesystem::temp_directory_path() / + std::filesystem::path("my_temp_dir"); + + if (!std::filesystem::create_directory(testFileDirectory_)) { + ASSERT_TRUE(false) << "Failed to create temporary directory."; + testFileDirectory_.clear(); + } + } + + void TearDown() override { + if (!testFileDirectory_.empty()) { + std::error_code err; + std::filesystem::remove_all(testFileDirectory_, err); + ASSERT_FALSE(err) << err.message(); + } + } + + static std::string makeTempFileName() { + static int tempFileNum = 0; + return "testfile" + std::to_string(++tempFileNum); + } + + protected: + std::string testFileDirectory_; +}; + +TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { + auto tempFilename = testFileDirectory_ + "/" + makeTempFileName(); + PlanFileFormat encoding = GetParam(); + + ::substrait::proto::Plan plan; + auto root = plan.add_relations()->mutable_root(); + auto read = root->mutable_input()->mutable_read(); + read->mutable_common()->mutable_direct(); + read->mutable_named_table()->add_names("table_name"); + auto status = ::io::substrait::savePlan(plan, tempFilename, encoding); + ASSERT_TRUE(status.ok()) << "Save failed.\n" << status; + + auto result = ::io::substrait::loadPlan(tempFilename); + ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status(); + ASSERT_THAT( + *result, + Partially(EqualsProto<::substrait::proto::Plan>( + R"(relations { + root { + input { + read { + common { + direct { + } + } + named_table { + names: "table_name" + } + } + } + } + })"))); +} + +INSTANTIATE_TEST_SUITE_P( + SaveAndLoadTests, + SaveAndLoadTestFixture, + testing::Values( + PlanFileFormat::kBinary, + PlanFileFormat::kJson, + PlanFileFormat::kProtoText, + PlanFileFormat::kText), + [](const testing::TestParamInfo& info) { + return planFileEncodingToString(info.param); + }); + +} // namespace io::substrait diff --git a/src/substrait/textplan/CMakeLists.txt b/src/substrait/textplan/CMakeLists.txt index 1e8b7dbb..eda37d58 100644 --- a/src/substrait/textplan/CMakeLists.txt +++ b/src/substrait/textplan/CMakeLists.txt @@ -20,10 +20,10 @@ add_library(error_listener SubstraitErrorListener.cpp SubstraitErrorListener.h) add_library(parse_result ParseResult.cpp ParseResult.h) -add_dependencies(symbol_table substrait_proto substrait_common +add_dependencies(symbol_table substrait_proto substrait_common absl::strings fmt::fmt-header-only) -target_link_libraries(symbol_table fmt::fmt-header-only +target_link_libraries(symbol_table fmt::fmt-header-only absl::strings substrait_textplan_converter) # Provide access to the generated protobuffer headers hierarchy. diff --git a/src/substrait/textplan/StringManipulation.cpp b/src/substrait/textplan/StringManipulation.cpp index eac3c56a..cb11e53a 100644 --- a/src/substrait/textplan/StringManipulation.cpp +++ b/src/substrait/textplan/StringManipulation.cpp @@ -2,15 +2,18 @@ #include "StringManipulation.h" +#include +#include +#include +#include + namespace io::substrait::textplan { -// Yields true if the string 'haystack' starts with the string 'needle'. bool startsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(0, needle.size()) == needle; } -// Returns true if the string 'haystack' ends with the string 'needle'. bool endsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; diff --git a/src/substrait/textplan/StringManipulation.h b/src/substrait/textplan/StringManipulation.h index 9c24418f..8edf7ea5 100644 --- a/src/substrait/textplan/StringManipulation.h +++ b/src/substrait/textplan/StringManipulation.h @@ -2,7 +2,9 @@ #pragma once +#include #include +#include namespace io::substrait::textplan { diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 19f96188..1b460624 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) { auto subtype = ANY_CAST(SourceType, info.subtype); switch (subtype) { case SourceType::kNamedTable: { - auto table = - ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob); + if (info.blob.has_value()) { + // We are using the proto as is in lieu of a disciplined structure. + auto table = ANY_CAST( + const ::substrait::proto::ReadRel_NamedTable*, info.blob); + text << "source named_table " << info.name << " {\n"; + text << " names = [\n"; + for (const auto& name : table->names()) { + text << " \"" << name << "\",\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + // We are using the new style data structure. text << "source named_table " << info.name << " {\n"; text << " names = [\n"; - for (const auto& name : table->names()) { - text << " \"" << name << "\",\n"; + for (const auto& sym : + symbolTable.lookupSymbolsByLocation(info.location)) { + if (sym->type == SymbolType::kSourceDetail) { + text << " \"" << sym->name << "\",\n"; + } } text << " ]\n"; text << "}\n"; hasPreviousText = true; + break; } case SourceType::kLocalFiles: { diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt index 5e1cf982..5ee27539 100644 --- a/src/substrait/textplan/converter/CMakeLists.txt +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -13,6 +13,8 @@ set(TEXTPLAN_SRCS PlanPrinterVisitor.h LoadBinary.cpp LoadBinary.h + SaveBinary.cpp + SaveBinary.h ParseBinary.cpp ParseBinary.h) @@ -22,10 +24,14 @@ target_link_libraries( substrait_textplan_converter substrait_common substrait_expression + substrait_io substrait_proto symbol_table error_listener - date::date) + date::date + fmt::fmt-header-only + absl::status + absl::statusor) if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) diff --git a/src/substrait/textplan/converter/LoadBinary.cpp b/src/substrait/textplan/converter/LoadBinary.cpp index 787ab054..c5d9f4ce 100644 --- a/src/substrait/textplan/converter/LoadBinary.cpp +++ b/src/substrait/textplan/converter/LoadBinary.cpp @@ -2,20 +2,20 @@ #include "substrait/textplan/converter/LoadBinary.h" +#include +#include #include #include #include - #include #include -#include #include #include #include #include -#include "substrait/common/Exceptions.h" #include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" namespace io::substrait::textplan { @@ -39,24 +39,23 @@ class StringErrorCollector : public google::protobuf::io::ErrorCollector { } // namespace -std::string readFromFile(std::string_view msgPath) { +absl::StatusOr readFromFile(std::string_view msgPath) { std::ifstream textFile(std::string{msgPath}); if (textFile.fail()) { - auto currdir = std::filesystem::current_path().string(); - SUBSTRAIT_FAIL( - "Failed to open file {} when running in {}: {}", - msgPath, - currdir, - strerror(errno)); + auto currDir = std::filesystem::current_path().string(); + return absl::ErrnoToStatus( + errno, + fmt::format( + "Failed to open file {} when running in {}", msgPath, currDir)); } std::stringstream buffer; buffer << textFile.rdbuf(); return buffer.str(); } -PlanOrErrors loadFromJson(std::string_view json) { +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json) { if (json.empty()) { - return PlanOrErrors({"Provided JSON string was empty."}); + return absl::InternalError("Provided JSON string was empty."); } std::string_view usableJson = json; if (json[0] == '#') { @@ -71,21 +70,32 @@ PlanOrErrors loadFromJson(std::string_view json) { std::string{usableJson}, &plan); if (!status.ok()) { std::string msg{status.message()}; - return PlanOrErrors( - {fmt::format("Failed to parse Substrait JSON: {}", msg)}); + return absl::InternalError( + fmt::format("Failed to parse Substrait JSON: {}", msg)); } - return PlanOrErrors(plan); + return plan; } -PlanOrErrors loadFromText(const std::string& text) { +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text) { ::substrait::proto::Plan plan; ::google::protobuf::TextFormat::Parser parser; StringErrorCollector collector; parser.RecordErrorsTo(&collector); if (!parser.ParseFromString(text, &plan)) { - return PlanOrErrors(collector.getErrors()); + auto errors = collector.getErrors(); + return absl::InternalError(absl::StrJoin(errors, "")); + } + return plan; +} + +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes) { + ::substrait::proto::Plan plan; + if (!plan.ParseFromString(bytes)) { + return absl::InternalError("Failed to parse as a binary Substrait plan."); } - return PlanOrErrors(plan); + return plan; } } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/LoadBinary.h b/src/substrait/textplan/converter/LoadBinary.h index b0b73a90..a4d4e38f 100644 --- a/src/substrait/textplan/converter/LoadBinary.h +++ b/src/substrait/textplan/converter/LoadBinary.h @@ -2,12 +2,9 @@ #pragma once +#include #include #include -#include -#include - -#include "substrait/proto/plan.pb.h" namespace substrait::proto { class Plan; @@ -15,41 +12,21 @@ class Plan; namespace io::substrait::textplan { -// PlanOrErrors behaves similarly to abseil::StatusOr. -class PlanOrErrors { - public: - explicit PlanOrErrors(::substrait::proto::Plan plan) - : plan_(std::move(plan)){}; - explicit PlanOrErrors(std::vector errors) - : errors_(std::move(errors)){}; - - bool ok() { - return errors_.empty(); - } - - const ::substrait::proto::Plan& operator*() { - return plan_; - } - - const std::vector& errors() { - return errors_; - } - - private: - ::substrait::proto::Plan plan_; - std::vector errors_; -}; - // Read the contents of a file from disk. -// Throws an exception if file cannot be read. -std::string readFromFile(std::string_view msgPath); +absl::StatusOr readFromFile(std::string_view msgPath); // Reads a plan from a json-encoded text proto. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromJson(std::string_view json); +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json); // Reads a plan encoded as a text protobuf. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromText(const std::string& text); +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text); + +// Reads a plan serialized as a binary protobuf. +// Returns a list of errors if the file cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes); } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/README.md b/src/substrait/textplan/converter/README.md index 607c0652..53a0e98e 100644 --- a/src/substrait/textplan/converter/README.md +++ b/src/substrait/textplan/converter/README.md @@ -1,7 +1,7 @@ # Using the Plan Converter Tool -The plan converter takes any number of JSON encoded Substrait plan files and converts them into the Substrait Text Plan -format. +The plan converter takes any number of Substrait plan files of any format and +converts them into the Substrait Text Plan format. ## Usage: ``` diff --git a/src/substrait/textplan/converter/SaveBinary.cpp b/src/substrait/textplan/converter/SaveBinary.cpp new file mode 100644 index 00000000..c8dd6c07 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.cpp @@ -0,0 +1,113 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/SaveBinary.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/converter/ParseBinary.h" + +namespace io::substrait::textplan { + +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + int outputFileDescriptor = + creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE); + if (outputFileDescriptor == -1) { + return absl::ErrnoToStatus( + errno, + fmt::format("Failed to open file {} for writing", output_filename)); + } + auto stream = + new google::protobuf::io::FileOutputStream(outputFileDescriptor); + + if (!plan.SerializeToZeroCopyStream(stream)) { + return ::absl::UnknownError("Failed to write plan to stream."); + } + + if (!stream->Close()) { + return absl::AbortedError("Failed to close file descriptor."); + } + delete stream; + return absl::OkStatus(); +} + +absl::Status savePlanToJson( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + std::ofstream stream(std::string{output_filename}); + if ((stream.fail())) { + return absl::UnavailableError( + fmt::format("Failed to open file {} for writing", output_filename)); + } + + std::string output; + auto status = ::google::protobuf::util::MessageToJsonString(plan, &output); + if (!status.ok()) { + return absl::UnknownError("Failed to save plan as a JSON protobuf."); + } + stream << output; + stream.close(); + if (stream.fail()) { + return absl::UnknownError("Failed to write the plan as a JSON protobuf."); + } + return absl::OkStatus(); +} + +absl::Status savePlanToText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + std::ofstream stream(std::string{output_filename}); + if ((stream.fail())) { + return absl::UnavailableError( + fmt::format("Failed to open file {} for writing", output_filename)); + } + + auto result = parseBinaryPlan(plan); + auto errors = result.getAllErrors(); + if (!errors.empty()) { + return absl::UnknownError(absl::StrJoin(errors, "")); + } + stream << SymbolTablePrinter::outputToText(result.getSymbolTable()); + stream.close(); + if (stream.fail()) { + return absl::UnknownError("Failed to write the plan as text."); + } + return absl::OkStatus(); +} + +absl::Status savePlanToProtoText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + int outputFileDescriptor = + creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE); + if (outputFileDescriptor == -1) { + return absl::ErrnoToStatus( + errno, + fmt::format("Failed to open file {} for writing", output_filename)); + } + auto stream = + new google::protobuf::io::FileOutputStream(outputFileDescriptor); + + if (!::google::protobuf::TextFormat::Print(plan, stream)) { + return absl::UnknownError("Failed to save plan as a text protobuf."); + } + + if (!stream->Close()) { + return absl::AbortedError("Failed to close file descriptor."); + } + delete stream; + return absl::OkStatus(); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/SaveBinary.h b/src/substrait/textplan/converter/SaveBinary.h new file mode 100644 index 00000000..d9158773 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.h @@ -0,0 +1,33 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include "absl/status/status.h" + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// Serializes a plan to disk as a binary protobuf. +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a JSON-encoded protobuf. +absl::Status savePlanToJson( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Calls the converter to store a plan on disk as a text-based substrait plan. +absl::Status savePlanToText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a text-encoded protobuf. +absl::Status savePlanToProtoText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/Tool.cpp b/src/substrait/textplan/converter/Tool.cpp index 9c3d652b..bdcc707a 100644 --- a/src/substrait/textplan/converter/Tool.cpp +++ b/src/substrait/textplan/converter/Tool.cpp @@ -6,6 +6,7 @@ #include +#include "substrait/common/Io.h" #include "substrait/textplan/SymbolTablePrinter.h" #include "substrait/textplan/converter/LoadBinary.h" #include "substrait/textplan/converter/ParseBinary.h" @@ -13,14 +14,10 @@ namespace io::substrait::textplan { namespace { -void convertJsonToText(const char* filename) { - std::string json = readFromFile(filename); - auto planOrError = loadFromJson(json); +void convertPlanToText(const char* filename) { + auto planOrError = loadPlan(filename); if (!planOrError.ok()) { - std::cerr << "An error occurred while reading: " << filename << std::endl; - for (const auto& err : planOrError.errors()) { - std::cerr << err << std::endl; - } + std::cerr << planOrError.status() << std::endl; return; } @@ -46,7 +43,7 @@ int main(int argc, char* argv[]) { #ifdef _WIN32 for (int currArg = 1; currArg < argc; currArg++) { printf("===== %s =====\n", argv[currArg]); - io::substrait::textplan::convertJsonToText(argv[currArg]); + io::substrait::textplan::convertPlanToText(argv[currArg]); } #else for (int currArg = 1; currArg < argc; currArg++) { @@ -54,7 +51,7 @@ int main(int argc, char* argv[]) { glob(argv[currArg], GLOB_TILDE, nullptr, &globResult); for (size_t i = 0; i < globResult.gl_pathc; i++) { printf("===== %s =====\n", globResult.gl_pathv[i]); - io::substrait::textplan::convertJsonToText(globResult.gl_pathv[i]); + io::substrait::textplan::convertPlanToText(globResult.gl_pathv[i]); } } #endif diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index cde3c92a..3b0ae97f 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -597,9 +597,10 @@ std::vector getTestCases() { TEST_P(BinaryToTextPlanConverterTestFixture, Parse) { auto [name, input, matcher] = GetParam(); - auto planOrError = loadFromText(input); + auto planOrError = loadFromProtoText(input); if (!planOrError.ok()) { - ParseResult result(SymbolTable(), planOrError.errors(), {}); + ParseResult result( + SymbolTable(), {std::string(planOrError.status().message())}, {}); ASSERT_THAT(result, matcher); return; } @@ -627,13 +628,15 @@ INSTANTIATE_TEST_SUITE_P( class BinaryToTextPlanConversionTest : public ::testing::Test {}; TEST_F(BinaryToTextPlanConversionTest, FullSample) { - std::string json = readFromFile("data/q6_first_stage.json"); - auto planOrError = loadFromJson(json); + auto jsonOrError = readFromFile("data/q6_first_stage.json"); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); ASSERT_TRUE(planOrError.ok()); auto plan = *planOrError; EXPECT_THAT(plan.extensions_size(), ::testing::Eq(7)); - std::string expectedOutput = readFromFile("data/q6_first_stage.golden.splan"); + auto expectedOutputOrError = readFromFile("data/q6_first_stage.golden.splan"); + ASSERT_TRUE(expectedOutputOrError.ok()); auto result = parseBinaryPlan(plan); auto symbols = result.getSymbolTable().getSymbols(); @@ -668,7 +671,7 @@ TEST_F(BinaryToTextPlanConversionTest, FullSample) { SymbolType::kSource, SymbolType::kSchema, }), - WhenSerialized(EqSquashingWhitespace(expectedOutput)))) + WhenSerialized(EqSquashingWhitespace(*expectedOutputOrError)))) << result.getSymbolTable().toDebugString(); } diff --git a/src/substrait/textplan/parser/CMakeLists.txt b/src/substrait/textplan/parser/CMakeLists.txt index 9d1b1143..fcc01413 100644 --- a/src/substrait/textplan/parser/CMakeLists.txt +++ b/src/substrait/textplan/parser/CMakeLists.txt @@ -14,6 +14,8 @@ add_library( SubstraitPlanRelationVisitor.h SubstraitPlanTypeVisitor.cpp SubstraitPlanTypeVisitor.h + LoadText.cpp + LoadText.h ParseText.cpp ParseText.h SubstraitParserErrorListener.cpp) @@ -28,7 +30,9 @@ target_link_libraries( textplan_grammar fmt::fmt-header-only date::date - date::date-tz) + date::date-tz + absl::status + absl::statusor) add_executable(planparser Tool.cpp) diff --git a/src/substrait/textplan/parser/LoadText.cpp b/src/substrait/textplan/parser/LoadText.cpp new file mode 100644 index 00000000..c76a7b31 --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.cpp @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/parser/LoadText.h" + +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/parser/ParseText.h" + +namespace io::substrait::textplan { + +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text) { + auto stream = loadTextString(text); + auto parseResult = io::substrait::textplan::parseStream(stream); + if (!parseResult.successful()) { + auto errors = parseResult.getAllErrors(); + return absl::UnknownError(absl::StrJoin(errors, "")); + } + + return SymbolTablePrinter::outputToBinaryPlan(parseResult.getSymbolTable()); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/LoadText.h b/src/substrait/textplan/parser/LoadText.h new file mode 100644 index 00000000..3b2f900b --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// Reads a plan encoded as a text protobuf. +// Returns a list of errors if the text cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/ParseText.cpp b/src/substrait/textplan/parser/ParseText.cpp index 3a19d3cc..a4ba32fc 100644 --- a/src/substrait/textplan/parser/ParseText.cpp +++ b/src/substrait/textplan/parser/ParseText.cpp @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: Apache-2.0 */ -#include "ParseText.h" +#include "substrait/textplan/parser/ParseText.h" #include #include diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index a8273f12..4a6d5567 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -974,14 +974,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( case SourceType::kNamedTable: { auto* source = parentRelationData->relation.mutable_read()->mutable_named_table(); - for (const auto& sym : *symbolTable_) { - if (sym.type != SymbolType::kSourceDetail) { - continue; - } - if (sym.location != symbol->location) { + for (const auto& sym : + symbolTable_->lookupSymbolsByLocation(symbol->location)) { + if (sym->type != SymbolType::kSourceDetail) { continue; } - source->add_names(sym.name); + source->add_names(sym->name); } break; } diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 7059b934..42f447c0 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -69,12 +69,12 @@ std::vector getTestCases() { TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { auto filename = GetParam(); - std::string json = readFromFile(filename); - auto planOrErrors = loadFromJson(json); - std::vector errors = planOrErrors.errors(); - ASSERT_THAT(errors, ::testing::ElementsAre()); + auto jsonOrError = readFromFile(filename); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); + ASSERT_TRUE(planOrError.ok()); - auto plan = *planOrErrors; + auto plan = *planOrError; auto textResult = parseBinaryPlan(plan); auto textSymbols = textResult.getSymbolTable().getSymbols();