From 2dfee3b7cdaeb6c95ea5d5f3242ac8d49350dacf Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 20 Jul 2023 10:54:52 -0700 Subject: [PATCH 01/11] feat: add api for loading plans of all types features: * supports reading/writing all of the major formats caveats: * only read/writes to filenames so in-memory use should use other interfaces * does not support compression * does not have a zero copy interface --- include/substrait/common/Io.h | 30 +++++ src/substrait/common/CMakeLists.txt | 13 ++- src/substrait/common/Io.cpp | 77 +++++++++++++ src/substrait/common/tests/CMakeLists.txt | 10 ++ src/substrait/common/tests/IoTest.cpp | 104 +++++++++++++++++ src/substrait/textplan/StringManipulation.cpp | 19 ++- src/substrait/textplan/StringManipulation.h | 7 ++ src/substrait/textplan/SymbolTablePrinter.cpp | 25 +++- .../textplan/converter/CMakeLists.txt | 8 +- .../textplan/converter/LoadBinary.cpp | 45 +++++--- src/substrait/textplan/converter/LoadBinary.h | 43 ++----- src/substrait/textplan/converter/README.md | 4 +- .../textplan/converter/SaveBinary.cpp | 109 ++++++++++++++++++ src/substrait/textplan/converter/SaveBinary.h | 33 ++++++ src/substrait/textplan/converter/Tool.cpp | 15 +-- .../tests/BinaryToTextPlanConversionTest.cpp | 15 ++- src/substrait/textplan/parser/CMakeLists.txt | 6 +- src/substrait/textplan/parser/LoadText.cpp | 23 ++++ src/substrait/textplan/parser/LoadText.h | 17 +++ src/substrait/textplan/parser/ParseText.cpp | 2 +- .../parser/SubstraitPlanRelationVisitor.cpp | 10 +- .../textplan/tests/RoundtripTest.cpp | 10 +- 22 files changed, 536 insertions(+), 89 deletions(-) create mode 100644 include/substrait/common/Io.h create mode 100644 src/substrait/common/Io.cpp create mode 100644 src/substrait/common/tests/IoTest.cpp create mode 100644 src/substrait/textplan/converter/SaveBinary.cpp create mode 100644 src/substrait/textplan/converter/SaveBinary.h create mode 100644 src/substrait/textplan/parser/LoadText.cpp create mode 100644 src/substrait/textplan/parser/LoadText.h diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h new file mode 100644 index 00000000..4227f7a3 --- /dev/null +++ b/include/substrait/common/Io.h @@ -0,0 +1,30 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "absl/status/statusor.h" +#include "substrait/proto/plan.pb.h" + +namespace io::substrait { + +enum PlanFileEncoding { + kBinary = 0, + kJson = 1, + kProtoText = 2, + kText = 3, +}; + +// Loads a Substrait plan consisting of any encoding type from the given file. +absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( + std::string_view input_filename); + +// Writes the provided plan to the specified location with the specified +// encoding type. +[[maybe_unused]] absl::Status savePlan( + const ::substrait::proto::Plan& plan, + std::string_view output_filename, + PlanFileEncoding encoding); + +} // namespace io::substrait diff --git a/src/substrait/common/CMakeLists.txt b/src/substrait/common/CMakeLists.txt index 846d077f..dc05a11e 100644 --- a/src/substrait/common/CMakeLists.txt +++ b/src/substrait/common/CMakeLists.txt @@ -1,9 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 add_library(substrait_common Exceptions.cpp) - target_link_libraries(substrait_common fmt::fmt-header-only) +add_library(substrait_io Io.cpp) +add_dependencies( + substrait_io + substrait_proto + substrait_textplan_converter + substrait_textplan_loader + fmt::fmt-header-only + absl::status + absl::statusor) +target_link_libraries(substrait_io substrait_proto substrait_textplan_converter + substrait_textplan_loader absl::status absl::statusor) + if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) endif() diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp new file mode 100644 index 00000000..d017a58c --- /dev/null +++ b/src/substrait/common/Io.cpp @@ -0,0 +1,77 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/converter/LoadBinary.h" +#include "substrait/textplan/converter/SaveBinary.h" +#include "substrait/textplan/parser/LoadText.h" + +namespace io::substrait { + +namespace { + +const std::regex kIsJson(R"(("extensionUris"|"extensions"|"relations"))"); +const std::regex kIsProtoText( + R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))"); +const std::regex kIsText( + R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)"); + +PlanFileEncoding detectEncoding(std::string_view content) { + if (std::regex_search(content.begin(), content.end(), kIsJson)) { + return kJson; + } + if (std::regex_search(content.begin(), content.end(), kIsProtoText)) { + return kProtoText; + } + if (std::regex_search(content.begin(), content.end(), kIsText)) { + return kText; + } + return kBinary; +} + +} // namespace + +absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( + std::string_view input_filename) { + auto contentOrError = textplan::readFromFile(input_filename.data()); + if (!contentOrError.ok()) { + return contentOrError.status(); + } + + auto encoding = detectEncoding(*contentOrError); + absl::StatusOr<::substrait::proto::Plan> planOrError; + switch (encoding) { + case kBinary: + return textplan::loadFromBinary(*contentOrError); + case kJson: + return textplan::loadFromJson(*contentOrError); + case kProtoText: + return textplan::loadFromProtoText(*contentOrError); + case kText: + return textplan::loadFromText(*contentOrError); + } + return absl::UnimplementedError("Unexpected encoding requested."); +} + +absl::Status savePlan( + const ::substrait::proto::Plan& plan, + std::string_view output_filename, + PlanFileEncoding encoding) { + switch (encoding) { + case kBinary: + return textplan::savePlanToBinary(plan, output_filename); + case kJson: + return textplan::savePlanToJson(plan, output_filename); + case kProtoText: + return textplan::savePlanToProtoText(plan, output_filename); + case kText: + return textplan::savePlanToText(plan, output_filename); + } + return absl::UnimplementedError("Unexpected encoding requested."); +} + +} // namespace io::substrait diff --git a/src/substrait/common/tests/CMakeLists.txt b/src/substrait/common/tests/CMakeLists.txt index 36c5f160..88b63b8f 100644 --- a/src/substrait/common/tests/CMakeLists.txt +++ b/src/substrait/common/tests/CMakeLists.txt @@ -9,3 +9,13 @@ add_test_case( substrait_common gtest gtest_main) + +add_test_case( + substrait_io_test + SOURCES + IoTest.cpp + EXTRA_LINK_LIBS + substrait_io + protobuf-matchers + gtest + gtest_main) diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp new file mode 100644 index 00000000..a9e7f13f --- /dev/null +++ b/src/substrait/common/tests/IoTest.cpp @@ -0,0 +1,104 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include +#include +#include +#include + +using ::protobuf_matchers::EqualsProto; +using ::protobuf_matchers::Partially; + +namespace io::substrait { + +namespace { + +constexpr const char* planFileEncodingToString(PlanFileEncoding e) noexcept { + switch (e) { + case PlanFileEncoding::kBinary: + return "kBinary"; + case PlanFileEncoding::kJson: + return "kJson"; + case PlanFileEncoding::kProtoText: + return "kProtoText"; + case PlanFileEncoding::kText: + return "kText"; + } + return "IMPOSSIBLE"; +} + +} // namespace + +class IoTest : public ::testing::Test {}; + +TEST_F(IoTest, LoadMissingFile) { + auto result = + ::io::substrait::loadPlanWithUnknownEncoding("non-existent-file"); + ASSERT_FALSE(result.ok()); + ASSERT_THAT( + result.status().message(), + ::testing::ContainsRegex("Failed to open file non-existent-file")); +} + +class SaveAndLoadTestFixture + : public ::testing::TestWithParam { + public: + ~SaveAndLoadTestFixture() override { + for (const auto& filename : testFiles_) { + unlink(filename.c_str()); + } + } + + void registerCleanup(const char* filename) { + testFiles_.emplace_back(filename); + } + + private: + std::vector testFiles_; +}; + +TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { + auto tempFilename = std::tmpnam(nullptr); + registerCleanup(tempFilename); + PlanFileEncoding encoding = GetParam(); + + ::substrait::proto::Plan plan; + auto root = plan.add_relations()->mutable_root(); + auto read = root->mutable_input()->mutable_read(); + read->mutable_common()->mutable_direct(); + read->mutable_named_table()->add_names("table_name"); + auto status = ::io::substrait::savePlan(plan, tempFilename, encoding); + ASSERT_TRUE(status.ok()) << "Save failed.\n" << status; + + auto result = ::io::substrait::loadPlanWithUnknownEncoding(tempFilename); + ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status(); + ASSERT_THAT( + *result, + Partially(EqualsProto<::substrait::proto::Plan>( + R"(relations { + root { + input { + read { + common { + direct { + } + } + named_table { + names: "table_name" + } + } + } + } + })"))); +} + +INSTANTIATE_TEST_SUITE_P( + SaveAndLoadTests, + SaveAndLoadTestFixture, + testing::Values(kBinary, kJson, kProtoText, kText), + [](const testing::TestParamInfo& info) { + return planFileEncodingToString(info.param); + }); + +} // namespace io::substrait diff --git a/src/substrait/textplan/StringManipulation.cpp b/src/substrait/textplan/StringManipulation.cpp index eac3c56a..ce37c9bb 100644 --- a/src/substrait/textplan/StringManipulation.cpp +++ b/src/substrait/textplan/StringManipulation.cpp @@ -2,18 +2,33 @@ #include "StringManipulation.h" +#include +#include +#include +#include + namespace io::substrait::textplan { -// Yields true if the string 'haystack' starts with the string 'needle'. bool startsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(0, needle.size()) == needle; } -// Returns true if the string 'haystack' ends with the string 'needle'. bool endsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; } +std::string joinLines( + std::vector lines, + std::string_view separator) { + auto concatWithSeparator = [separator](std::string a, const std::string& b) { + return std::move(a) + std::string(separator) + b; + }; + + auto result = std::accumulate( + std::next(lines.begin()), lines.end(), lines[0], concatWithSeparator); + return result; +} + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/StringManipulation.h b/src/substrait/textplan/StringManipulation.h index 9c24418f..d0d602e6 100644 --- a/src/substrait/textplan/StringManipulation.h +++ b/src/substrait/textplan/StringManipulation.h @@ -2,7 +2,9 @@ #pragma once +#include #include +#include namespace io::substrait::textplan { @@ -12,4 +14,9 @@ bool startsWith(std::string_view haystack, std::string_view needle); // Returns true if the string 'haystack' ends with the string 'needle'. bool endsWith(std::string_view haystack, std::string_view needle); +// Joins a vector of strings into a single string separated by separator. +std::string joinLines( + std::vector lines, + std::string_view separator = "\n"); + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 19f96188..1b460624 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) { auto subtype = ANY_CAST(SourceType, info.subtype); switch (subtype) { case SourceType::kNamedTable: { - auto table = - ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob); + if (info.blob.has_value()) { + // We are using the proto as is in lieu of a disciplined structure. + auto table = ANY_CAST( + const ::substrait::proto::ReadRel_NamedTable*, info.blob); + text << "source named_table " << info.name << " {\n"; + text << " names = [\n"; + for (const auto& name : table->names()) { + text << " \"" << name << "\",\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + // We are using the new style data structure. text << "source named_table " << info.name << " {\n"; text << " names = [\n"; - for (const auto& name : table->names()) { - text << " \"" << name << "\",\n"; + for (const auto& sym : + symbolTable.lookupSymbolsByLocation(info.location)) { + if (sym->type == SymbolType::kSourceDetail) { + text << " \"" << sym->name << "\",\n"; + } } text << " ]\n"; text << "}\n"; hasPreviousText = true; + break; } case SourceType::kLocalFiles: { diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt index 41f3b9ce..6f67271e 100644 --- a/src/substrait/textplan/converter/CMakeLists.txt +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -11,6 +11,8 @@ set(TEXTPLAN_SRCS PlanPrinterVisitor.h LoadBinary.cpp LoadBinary.h + SaveBinary.cpp + SaveBinary.h ParseBinary.cpp ParseBinary.h) @@ -20,10 +22,14 @@ target_link_libraries( substrait_textplan_converter substrait_common substrait_expression + substrait_io substrait_proto symbol_table error_listener - date::date) + date::date + fmt::fmt-header-only + absl::status + absl::statusor) if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) diff --git a/src/substrait/textplan/converter/LoadBinary.cpp b/src/substrait/textplan/converter/LoadBinary.cpp index 787ab054..65e5694c 100644 --- a/src/substrait/textplan/converter/LoadBinary.cpp +++ b/src/substrait/textplan/converter/LoadBinary.cpp @@ -2,20 +2,19 @@ #include "substrait/textplan/converter/LoadBinary.h" +#include #include #include #include - #include #include -#include #include #include #include #include -#include "substrait/common/Exceptions.h" #include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" namespace io::substrait::textplan { @@ -39,24 +38,23 @@ class StringErrorCollector : public google::protobuf::io::ErrorCollector { } // namespace -std::string readFromFile(std::string_view msgPath) { +absl::StatusOr readFromFile(std::string_view msgPath) { std::ifstream textFile(std::string{msgPath}); if (textFile.fail()) { - auto currdir = std::filesystem::current_path().string(); - SUBSTRAIT_FAIL( - "Failed to open file {} when running in {}: {}", - msgPath, - currdir, - strerror(errno)); + auto currDir = std::filesystem::current_path().string(); + return absl::ErrnoToStatus( + errno, + fmt::format( + "Failed to open file {} when running in {}", msgPath, currDir)); } std::stringstream buffer; buffer << textFile.rdbuf(); return buffer.str(); } -PlanOrErrors loadFromJson(std::string_view json) { +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json) { if (json.empty()) { - return PlanOrErrors({"Provided JSON string was empty."}); + return absl::InternalError("Provided JSON string was empty."); } std::string_view usableJson = json; if (json[0] == '#') { @@ -71,21 +69,32 @@ PlanOrErrors loadFromJson(std::string_view json) { std::string{usableJson}, &plan); if (!status.ok()) { std::string msg{status.message()}; - return PlanOrErrors( - {fmt::format("Failed to parse Substrait JSON: {}", msg)}); + return absl::InternalError( + fmt::format("Failed to parse Substrait JSON: {}", msg)); } - return PlanOrErrors(plan); + return plan; } -PlanOrErrors loadFromText(const std::string& text) { +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text) { ::substrait::proto::Plan plan; ::google::protobuf::TextFormat::Parser parser; StringErrorCollector collector; parser.RecordErrorsTo(&collector); if (!parser.ParseFromString(text, &plan)) { - return PlanOrErrors(collector.getErrors()); + auto errors = collector.getErrors(); + return absl::InternalError(joinLines(errors)); + } + return plan; +} + +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes) { + ::substrait::proto::Plan plan; + if (!plan.ParseFromString(bytes)) { + return absl::InternalError("Failed to parse as a binary Substrait plan."); } - return PlanOrErrors(plan); + return plan; } } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/LoadBinary.h b/src/substrait/textplan/converter/LoadBinary.h index b0b73a90..a4d4e38f 100644 --- a/src/substrait/textplan/converter/LoadBinary.h +++ b/src/substrait/textplan/converter/LoadBinary.h @@ -2,12 +2,9 @@ #pragma once +#include #include #include -#include -#include - -#include "substrait/proto/plan.pb.h" namespace substrait::proto { class Plan; @@ -15,41 +12,21 @@ class Plan; namespace io::substrait::textplan { -// PlanOrErrors behaves similarly to abseil::StatusOr. -class PlanOrErrors { - public: - explicit PlanOrErrors(::substrait::proto::Plan plan) - : plan_(std::move(plan)){}; - explicit PlanOrErrors(std::vector errors) - : errors_(std::move(errors)){}; - - bool ok() { - return errors_.empty(); - } - - const ::substrait::proto::Plan& operator*() { - return plan_; - } - - const std::vector& errors() { - return errors_; - } - - private: - ::substrait::proto::Plan plan_; - std::vector errors_; -}; - // Read the contents of a file from disk. -// Throws an exception if file cannot be read. -std::string readFromFile(std::string_view msgPath); +absl::StatusOr readFromFile(std::string_view msgPath); // Reads a plan from a json-encoded text proto. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromJson(std::string_view json); +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json); // Reads a plan encoded as a text protobuf. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromText(const std::string& text); +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text); + +// Reads a plan serialized as a binary protobuf. +// Returns a list of errors if the file cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes); } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/README.md b/src/substrait/textplan/converter/README.md index 607c0652..53a0e98e 100644 --- a/src/substrait/textplan/converter/README.md +++ b/src/substrait/textplan/converter/README.md @@ -1,7 +1,7 @@ # Using the Plan Converter Tool -The plan converter takes any number of JSON encoded Substrait plan files and converts them into the Substrait Text Plan -format. +The plan converter takes any number of Substrait plan files of any format and +converts them into the Substrait Text Plan format. ## Usage: ``` diff --git a/src/substrait/textplan/converter/SaveBinary.cpp b/src/substrait/textplan/converter/SaveBinary.cpp new file mode 100644 index 00000000..e7347269 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.cpp @@ -0,0 +1,109 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/SaveBinary.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/converter/ParseBinary.h" + +namespace io::substrait::textplan { + +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + int outputFileDescriptor = + creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE); + if (outputFileDescriptor == -1) { + return absl::ErrnoToStatus( + errno, + fmt::format("Failed to open file {} for writing", output_filename)); + } + auto stream = + new google::protobuf::io::FileOutputStream(outputFileDescriptor); + + if (!plan.SerializeToZeroCopyStream(stream)) { + return ::absl::UnknownError("Failed to write plan to stream."); + } + + delete stream; + close(outputFileDescriptor); + return absl::OkStatus(); +} + +absl::Status savePlanToJson( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + std::ofstream stream(std::string{output_filename}); + if ((stream.fail())) { + return absl::UnavailableError( + fmt::format("Failed to open file {} for writing", output_filename)); + } + + std::string output; + auto status = ::google::protobuf::util::MessageToJsonString(plan, &output); + if (!status.ok()) { + return absl::UnknownError("Failed to save plan as a JSON protobuf."); + } + stream << output; + if (stream.fail()) { + return absl::UnknownError("Failed to write the plan as a JSON protobuf."); + } + + stream.close(); + return absl::OkStatus(); +} + +absl::Status savePlanToText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + std::ofstream stream(std::string{output_filename}); + if ((stream.fail())) { + return absl::UnavailableError( + fmt::format("Failed to open file {} for writing", output_filename)); + } + + auto result = parseBinaryPlan(plan); + auto errors = result.getAllErrors(); + if (!errors.empty()) { + return absl::UnknownError(joinLines(errors)); + } + stream << SymbolTablePrinter::outputToText(result.getSymbolTable()); + if (stream.fail()) { + return absl::UnknownError("Failed to write the plan as text."); + } + stream.close(); + return absl::OkStatus(); +} + +absl::Status savePlanToProtoText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + int outputFileDescriptor = + creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE); + if (outputFileDescriptor == -1) { + return absl::ErrnoToStatus( + errno, + fmt::format("Failed to open file {} for writing", output_filename)); + } + auto stream = + new google::protobuf::io::FileOutputStream(outputFileDescriptor); + + if (!::google::protobuf::TextFormat::Print(plan, stream)) { + return absl::UnknownError("Failed to save plan as a text protobuf."); + } + + delete stream; + close(outputFileDescriptor); + return absl::OkStatus(); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/SaveBinary.h b/src/substrait/textplan/converter/SaveBinary.h new file mode 100644 index 00000000..d9158773 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.h @@ -0,0 +1,33 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include "absl/status/status.h" + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// Serializes a plan to disk as a binary protobuf. +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a JSON-encoded protobuf. +absl::Status savePlanToJson( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Calls the converter to store a plan on disk as a text-based substrait plan. +absl::Status savePlanToText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a text-encoded protobuf. +absl::Status savePlanToProtoText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/Tool.cpp b/src/substrait/textplan/converter/Tool.cpp index 9c3d652b..67d6ca82 100644 --- a/src/substrait/textplan/converter/Tool.cpp +++ b/src/substrait/textplan/converter/Tool.cpp @@ -6,6 +6,7 @@ #include +#include "substrait/common/Io.h" #include "substrait/textplan/SymbolTablePrinter.h" #include "substrait/textplan/converter/LoadBinary.h" #include "substrait/textplan/converter/ParseBinary.h" @@ -13,14 +14,10 @@ namespace io::substrait::textplan { namespace { -void convertJsonToText(const char* filename) { - std::string json = readFromFile(filename); - auto planOrError = loadFromJson(json); +void convertPlanToText(const char* filename) { + auto planOrError = loadPlanWithUnknownEncoding(filename); if (!planOrError.ok()) { - std::cerr << "An error occurred while reading: " << filename << std::endl; - for (const auto& err : planOrError.errors()) { - std::cerr << err << std::endl; - } + std::cerr << planOrError.status() << std::endl; return; } @@ -46,7 +43,7 @@ int main(int argc, char* argv[]) { #ifdef _WIN32 for (int currArg = 1; currArg < argc; currArg++) { printf("===== %s =====\n", argv[currArg]); - io::substrait::textplan::convertJsonToText(argv[currArg]); + io::substrait::textplan::convertPlanToText(argv[currArg]); } #else for (int currArg = 1; currArg < argc; currArg++) { @@ -54,7 +51,7 @@ int main(int argc, char* argv[]) { glob(argv[currArg], GLOB_TILDE, nullptr, &globResult); for (size_t i = 0; i < globResult.gl_pathc; i++) { printf("===== %s =====\n", globResult.gl_pathv[i]); - io::substrait::textplan::convertJsonToText(globResult.gl_pathv[i]); + io::substrait::textplan::convertPlanToText(globResult.gl_pathv[i]); } } #endif diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index cde3c92a..3b0ae97f 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -597,9 +597,10 @@ std::vector getTestCases() { TEST_P(BinaryToTextPlanConverterTestFixture, Parse) { auto [name, input, matcher] = GetParam(); - auto planOrError = loadFromText(input); + auto planOrError = loadFromProtoText(input); if (!planOrError.ok()) { - ParseResult result(SymbolTable(), planOrError.errors(), {}); + ParseResult result( + SymbolTable(), {std::string(planOrError.status().message())}, {}); ASSERT_THAT(result, matcher); return; } @@ -627,13 +628,15 @@ INSTANTIATE_TEST_SUITE_P( class BinaryToTextPlanConversionTest : public ::testing::Test {}; TEST_F(BinaryToTextPlanConversionTest, FullSample) { - std::string json = readFromFile("data/q6_first_stage.json"); - auto planOrError = loadFromJson(json); + auto jsonOrError = readFromFile("data/q6_first_stage.json"); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); ASSERT_TRUE(planOrError.ok()); auto plan = *planOrError; EXPECT_THAT(plan.extensions_size(), ::testing::Eq(7)); - std::string expectedOutput = readFromFile("data/q6_first_stage.golden.splan"); + auto expectedOutputOrError = readFromFile("data/q6_first_stage.golden.splan"); + ASSERT_TRUE(expectedOutputOrError.ok()); auto result = parseBinaryPlan(plan); auto symbols = result.getSymbolTable().getSymbols(); @@ -668,7 +671,7 @@ TEST_F(BinaryToTextPlanConversionTest, FullSample) { SymbolType::kSource, SymbolType::kSchema, }), - WhenSerialized(EqSquashingWhitespace(expectedOutput)))) + WhenSerialized(EqSquashingWhitespace(*expectedOutputOrError)))) << result.getSymbolTable().toDebugString(); } diff --git a/src/substrait/textplan/parser/CMakeLists.txt b/src/substrait/textplan/parser/CMakeLists.txt index 498d572c..4c89fa83 100644 --- a/src/substrait/textplan/parser/CMakeLists.txt +++ b/src/substrait/textplan/parser/CMakeLists.txt @@ -12,6 +12,8 @@ add_library( SubstraitPlanRelationVisitor.h SubstraitPlanTypeVisitor.cpp SubstraitPlanTypeVisitor.h + LoadText.cpp + LoadText.h ParseText.cpp ParseText.h SubstraitParserErrorListener.cpp) @@ -26,7 +28,9 @@ target_link_libraries( textplan_grammar fmt::fmt-header-only date::date - date::date-tz) + date::date-tz + absl::status + absl::statusor) add_executable(planparser Tool.cpp) diff --git a/src/substrait/textplan/parser/LoadText.cpp b/src/substrait/textplan/parser/LoadText.cpp new file mode 100644 index 00000000..0850befe --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.cpp @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/parser/LoadText.h" + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/parser/ParseText.h" + +namespace io::substrait::textplan { + +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text) { + auto stream = loadTextString(text); + auto parseResult = io::substrait::textplan::parseStream(stream); + if (!parseResult.successful()) { + auto errors = parseResult.getAllErrors(); + return absl::UnknownError(joinLines(errors)); + } + + return SymbolTablePrinter::outputToBinaryPlan(parseResult.getSymbolTable()); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/LoadText.h b/src/substrait/textplan/parser/LoadText.h new file mode 100644 index 00000000..3b2f900b --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// Reads a plan encoded as a text protobuf. +// Returns a list of errors if the text cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/ParseText.cpp b/src/substrait/textplan/parser/ParseText.cpp index 3a19d3cc..a4ba32fc 100644 --- a/src/substrait/textplan/parser/ParseText.cpp +++ b/src/substrait/textplan/parser/ParseText.cpp @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: Apache-2.0 */ -#include "ParseText.h" +#include "substrait/textplan/parser/ParseText.h" #include #include diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index a8273f12..4a6d5567 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -974,14 +974,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( case SourceType::kNamedTable: { auto* source = parentRelationData->relation.mutable_read()->mutable_named_table(); - for (const auto& sym : *symbolTable_) { - if (sym.type != SymbolType::kSourceDetail) { - continue; - } - if (sym.location != symbol->location) { + for (const auto& sym : + symbolTable_->lookupSymbolsByLocation(symbol->location)) { + if (sym->type != SymbolType::kSourceDetail) { continue; } - source->add_names(sym.name); + source->add_names(sym->name); } break; } diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 3b22071c..09aeeca3 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -70,12 +70,12 @@ std::vector getTestCases() { TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { auto filename = GetParam(); - std::string json = readFromFile(filename); - auto planOrErrors = loadFromJson(json); - std::vector errors = planOrErrors.errors(); - ASSERT_THAT(errors, ::testing::ElementsAre()); + auto jsonOrError = readFromFile(filename); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); + ASSERT_TRUE(planOrError.ok()); - auto plan = *planOrErrors; + auto plan = *planOrError; auto textResult = parseBinaryPlan(plan); auto textSymbols = textResult.getSymbolTable().getSymbols(); From 6fa011e03d152c619c223d9552b5bf848a8212b8 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 31 Jul 2023 10:22:09 -0700 Subject: [PATCH 02/11] Addressed some of the review notes. --- include/substrait/common/Io.h | 4 ++-- src/substrait/common/Io.cpp | 29 ++++++++++++----------- src/substrait/common/tests/IoTest.cpp | 11 +++++---- src/substrait/textplan/converter/Tool.cpp | 2 +- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h index 4227f7a3..3e983151 100644 --- a/include/substrait/common/Io.h +++ b/include/substrait/common/Io.h @@ -9,7 +9,7 @@ namespace io::substrait { -enum PlanFileEncoding { +enum class PlanFileEncoding { kBinary = 0, kJson = 1, kProtoText = 2, @@ -17,7 +17,7 @@ enum PlanFileEncoding { }; // Loads a Substrait plan consisting of any encoding type from the given file. -absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( +absl::StatusOr<::substrait::proto::Plan> loadPlan( std::string_view input_filename); // Writes the provided plan to the specified location with the specified diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp index d017a58c..776a6de4 100644 --- a/src/substrait/common/Io.cpp +++ b/src/substrait/common/Io.cpp @@ -14,7 +14,8 @@ namespace io::substrait { namespace { -const std::regex kIsJson(R"(("extensionUris"|"extensions"|"relations"))"); +const std::regex kIsJson( + R"(("extensionUris"|"extension_uris"|"extensions"|"relations"))"); const std::regex kIsProtoText( R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))"); const std::regex kIsText( @@ -22,20 +23,20 @@ const std::regex kIsText( PlanFileEncoding detectEncoding(std::string_view content) { if (std::regex_search(content.begin(), content.end(), kIsJson)) { - return kJson; + return PlanFileEncoding::kJson; } if (std::regex_search(content.begin(), content.end(), kIsProtoText)) { - return kProtoText; + return PlanFileEncoding::kProtoText; } if (std::regex_search(content.begin(), content.end(), kIsText)) { - return kText; + return PlanFileEncoding::kText; } - return kBinary; + return PlanFileEncoding::kBinary; } } // namespace -absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( +absl::StatusOr<::substrait::proto::Plan> loadPlan( std::string_view input_filename) { auto contentOrError = textplan::readFromFile(input_filename.data()); if (!contentOrError.ok()) { @@ -45,13 +46,13 @@ absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( auto encoding = detectEncoding(*contentOrError); absl::StatusOr<::substrait::proto::Plan> planOrError; switch (encoding) { - case kBinary: + case PlanFileEncoding::kBinary: return textplan::loadFromBinary(*contentOrError); - case kJson: + case PlanFileEncoding::kJson: return textplan::loadFromJson(*contentOrError); - case kProtoText: + case PlanFileEncoding::kProtoText: return textplan::loadFromProtoText(*contentOrError); - case kText: + case PlanFileEncoding::kText: return textplan::loadFromText(*contentOrError); } return absl::UnimplementedError("Unexpected encoding requested."); @@ -62,13 +63,13 @@ absl::Status savePlan( std::string_view output_filename, PlanFileEncoding encoding) { switch (encoding) { - case kBinary: + case PlanFileEncoding::kBinary: return textplan::savePlanToBinary(plan, output_filename); - case kJson: + case PlanFileEncoding::kJson: return textplan::savePlanToJson(plan, output_filename); - case kProtoText: + case PlanFileEncoding::kProtoText: return textplan::savePlanToProtoText(plan, output_filename); - case kText: + case PlanFileEncoding::kText: return textplan::savePlanToText(plan, output_filename); } return absl::UnimplementedError("Unexpected encoding requested."); diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp index a9e7f13f..5a961f0e 100644 --- a/src/substrait/common/tests/IoTest.cpp +++ b/src/substrait/common/tests/IoTest.cpp @@ -33,8 +33,7 @@ constexpr const char* planFileEncodingToString(PlanFileEncoding e) noexcept { class IoTest : public ::testing::Test {}; TEST_F(IoTest, LoadMissingFile) { - auto result = - ::io::substrait::loadPlanWithUnknownEncoding("non-existent-file"); + auto result = ::io::substrait::loadPlan("non-existent-file"); ASSERT_FALSE(result.ok()); ASSERT_THAT( result.status().message(), @@ -71,7 +70,7 @@ TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { auto status = ::io::substrait::savePlan(plan, tempFilename, encoding); ASSERT_TRUE(status.ok()) << "Save failed.\n" << status; - auto result = ::io::substrait::loadPlanWithUnknownEncoding(tempFilename); + auto result = ::io::substrait::loadPlan(tempFilename); ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status(); ASSERT_THAT( *result, @@ -96,7 +95,11 @@ TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { INSTANTIATE_TEST_SUITE_P( SaveAndLoadTests, SaveAndLoadTestFixture, - testing::Values(kBinary, kJson, kProtoText, kText), + testing::Values( + PlanFileEncoding::kBinary, + PlanFileEncoding::kJson, + PlanFileEncoding::kProtoText, + PlanFileEncoding::kText), [](const testing::TestParamInfo& info) { return planFileEncodingToString(info.param); }); diff --git a/src/substrait/textplan/converter/Tool.cpp b/src/substrait/textplan/converter/Tool.cpp index 67d6ca82..bdcc707a 100644 --- a/src/substrait/textplan/converter/Tool.cpp +++ b/src/substrait/textplan/converter/Tool.cpp @@ -15,7 +15,7 @@ namespace io::substrait::textplan { namespace { void convertPlanToText(const char* filename) { - auto planOrError = loadPlanWithUnknownEncoding(filename); + auto planOrError = loadPlan(filename); if (!planOrError.ok()) { std::cerr << planOrError.status() << std::endl; return; From cceac7961e4f5b524d867e0acab60f8ca0c5687b Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 31 Jul 2023 15:55:27 -0700 Subject: [PATCH 03/11] Addressed the rest of this set of review notes. --- include/substrait/common/Io.h | 42 +++++++++++++++++++++++---- src/substrait/common/Io.cpp | 34 +++++++++++----------- src/substrait/common/tests/IoTest.cpp | 23 +++++++-------- 3 files changed, 64 insertions(+), 35 deletions(-) diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h index 3e983151..a2c18726 100644 --- a/include/substrait/common/Io.h +++ b/include/substrait/common/Io.h @@ -9,22 +9,52 @@ namespace io::substrait { -enum class PlanFileEncoding { +/* + * \brief The four different ways plans can be represented on disk. + */ +enum class PlanFileFormat { kBinary = 0, kJson = 1, kProtoText = 2, kText = 3, }; -// Loads a Substrait plan consisting of any encoding type from the given file. +/* + * \\brief Loads a Substrait plan of any format from the given file. + * + * loadPlan determines which file type the specified file is and then calls + * the appropriate load/parse method to consume it preserving any error + * messages. + * + * This will load the plan into memory and then convert it consuming twice the + * amount of memory that it consumed on disk. + * + * \\param input_filename The filename containing the plan to convert. + * \\return If loading was successful, returns a plan. If loading was not + * successful this is a status containing a list of parse errors in the status's + * message. + */ absl::StatusOr<::substrait::proto::Plan> loadPlan( std::string_view input_filename); -// Writes the provided plan to the specified location with the specified -// encoding type. -[[maybe_unused]] absl::Status savePlan( +/* + * \\brief Writes the provided plan to disk. + * + * savePlan writes the provided plan in the specified format to the specified + * location. + * + * This routine will consume more memory during the conversion to the text + * format as the original plan as well as the annotated parse tree will need to + * reside in memory during the process. + * + * \\param plan + * \\param output_filename + * \\param format + * \\return + */ +absl::Status savePlan( const ::substrait::proto::Plan& plan, std::string_view output_filename, - PlanFileEncoding encoding); + PlanFileFormat format); } // namespace io::substrait diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp index 776a6de4..4c5b1c24 100644 --- a/src/substrait/common/Io.cpp +++ b/src/substrait/common/Io.cpp @@ -21,17 +21,17 @@ const std::regex kIsProtoText( const std::regex kIsText( R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)"); -PlanFileEncoding detectEncoding(std::string_view content) { +PlanFileFormat detectFormat(std::string_view content) { if (std::regex_search(content.begin(), content.end(), kIsJson)) { - return PlanFileEncoding::kJson; + return PlanFileFormat::kJson; } if (std::regex_search(content.begin(), content.end(), kIsProtoText)) { - return PlanFileEncoding::kProtoText; + return PlanFileFormat::kProtoText; } if (std::regex_search(content.begin(), content.end(), kIsText)) { - return PlanFileEncoding::kText; + return PlanFileFormat::kText; } - return PlanFileEncoding::kBinary; + return PlanFileFormat::kBinary; } } // namespace @@ -43,16 +43,16 @@ absl::StatusOr<::substrait::proto::Plan> loadPlan( return contentOrError.status(); } - auto encoding = detectEncoding(*contentOrError); + auto encoding = detectFormat(*contentOrError); absl::StatusOr<::substrait::proto::Plan> planOrError; switch (encoding) { - case PlanFileEncoding::kBinary: + case PlanFileFormat::kBinary: return textplan::loadFromBinary(*contentOrError); - case PlanFileEncoding::kJson: + case PlanFileFormat::kJson: return textplan::loadFromJson(*contentOrError); - case PlanFileEncoding::kProtoText: + case PlanFileFormat::kProtoText: return textplan::loadFromProtoText(*contentOrError); - case PlanFileEncoding::kText: + case PlanFileFormat::kText: return textplan::loadFromText(*contentOrError); } return absl::UnimplementedError("Unexpected encoding requested."); @@ -61,18 +61,18 @@ absl::StatusOr<::substrait::proto::Plan> loadPlan( absl::Status savePlan( const ::substrait::proto::Plan& plan, std::string_view output_filename, - PlanFileEncoding encoding) { - switch (encoding) { - case PlanFileEncoding::kBinary: + PlanFileFormat format) { + switch (format) { + case PlanFileFormat::kBinary: return textplan::savePlanToBinary(plan, output_filename); - case PlanFileEncoding::kJson: + case PlanFileFormat::kJson: return textplan::savePlanToJson(plan, output_filename); - case PlanFileEncoding::kProtoText: + case PlanFileFormat::kProtoText: return textplan::savePlanToProtoText(plan, output_filename); - case PlanFileEncoding::kText: + case PlanFileFormat::kText: return textplan::savePlanToText(plan, output_filename); } - return absl::UnimplementedError("Unexpected encoding requested."); + return absl::UnimplementedError("Unexpected format requested."); } } // namespace io::substrait diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp index 5a961f0e..bf8dd00d 100644 --- a/src/substrait/common/tests/IoTest.cpp +++ b/src/substrait/common/tests/IoTest.cpp @@ -14,15 +14,15 @@ namespace io::substrait { namespace { -constexpr const char* planFileEncodingToString(PlanFileEncoding e) noexcept { +constexpr const char* planFileEncodingToString(PlanFileFormat e) noexcept { switch (e) { - case PlanFileEncoding::kBinary: + case PlanFileFormat::kBinary: return "kBinary"; - case PlanFileEncoding::kJson: + case PlanFileFormat::kJson: return "kJson"; - case PlanFileEncoding::kProtoText: + case PlanFileFormat::kProtoText: return "kProtoText"; - case PlanFileEncoding::kText: + case PlanFileFormat::kText: return "kText"; } return "IMPOSSIBLE"; @@ -40,8 +40,7 @@ TEST_F(IoTest, LoadMissingFile) { ::testing::ContainsRegex("Failed to open file non-existent-file")); } -class SaveAndLoadTestFixture - : public ::testing::TestWithParam { +class SaveAndLoadTestFixture : public ::testing::TestWithParam { public: ~SaveAndLoadTestFixture() override { for (const auto& filename : testFiles_) { @@ -60,7 +59,7 @@ class SaveAndLoadTestFixture TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { auto tempFilename = std::tmpnam(nullptr); registerCleanup(tempFilename); - PlanFileEncoding encoding = GetParam(); + PlanFileFormat encoding = GetParam(); ::substrait::proto::Plan plan; auto root = plan.add_relations()->mutable_root(); @@ -96,10 +95,10 @@ INSTANTIATE_TEST_SUITE_P( SaveAndLoadTests, SaveAndLoadTestFixture, testing::Values( - PlanFileEncoding::kBinary, - PlanFileEncoding::kJson, - PlanFileEncoding::kProtoText, - PlanFileEncoding::kText), + PlanFileFormat::kBinary, + PlanFileFormat::kJson, + PlanFileFormat::kProtoText, + PlanFileFormat::kText), [](const testing::TestParamInfo& info) { return planFileEncodingToString(info.param); }); From c29dbbb50fb36ae9d22ea8b4b357f5b91c96d55f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 16 Aug 2023 13:55:48 -0700 Subject: [PATCH 04/11] feat: enable address and leak detection on all debug builds (#82) --- CMakeLists.txt | 6 ++++++ src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp | 6 +++++- src/substrait/type/Type.cpp | 7 ++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bd7fe7e6..d4f78ef9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,12 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED True) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +add_compile_options($<$:-fsanitize=undefined>) +add_link_options($<$:-fsanitize=undefined>) + +add_compile_options($<$:-fsanitize=address>) +add_link_options($<$:-fsanitize=address>) + option( SUBSTRAIT_CPP_BUILD_TESTING "Enable substrait-cpp tests. This will enable all other build options automatically." diff --git a/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp index 403caa22..5b42d817 100644 --- a/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp @@ -109,11 +109,15 @@ ::substrait::proto::Type SubstraitPlanTypeVisitor::typeToProto( } case TypeKind::kVarchar: { auto varChar = - reinterpret_cast(&decodedType); + reinterpret_cast(&decodedType); if (varChar == nullptr) { break; } try { + if (!varChar->length()->isInteger()) { + errorListener_->addError(ctx->getStart(), "Missing varchar length."); + break; + } int32_t length = std::stoi(varChar->length()->value()); type.mutable_varchar()->set_length(length); } catch (...) { diff --git a/src/substrait/type/Type.cpp b/src/substrait/type/Type.cpp index 1fa21115..5a0e3cff 100644 --- a/src/substrait/type/Type.cpp +++ b/src/substrait/type/Type.cpp @@ -197,7 +197,12 @@ ParameterizedTypePtr ParameterizedType::decode( const auto& leftAngleBracketPos = matchingType.find('<'); if (leftAngleBracketPos == std::string::npos) { - bool nullable = matchingType.back() == '?'; + bool nullable; + if (matchingType.empty()) { + nullable = false; + } else { + nullable = matchingType.back() == '?'; + } // deal with type and with a question mask like "i32?". const auto& baseType = nullable ? matchingType = matchingType.substr(0, questionMaskPos) From 9aae89753dff7fbe83ff088f7bd2f316c6688a4c Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 22 Aug 2023 01:10:19 -0700 Subject: [PATCH 05/11] chore: Add EpsilonPrime to the list of code owners for Substrait textplans. (#86) --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a413bbe8..0bbf37fb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 * @westonpace +/src/substrait/textplan @EpsilonPrime From f8c383a5e6f43cf03367d9a30fb7cd0c87ab321b Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 23 Aug 2023 16:53:51 -0700 Subject: [PATCH 06/11] chore: Update version of substrait core to 0.32.0. (#84) --- .../converter/BasePlanProtoVisitor.cpp | 92 ++++++++++++++++++- .../textplan/converter/BasePlanProtoVisitor.h | 12 +++ .../converter/InitialPlanProtoVisitor.cpp | 24 +++++ .../textplan/converter/PipelineVisitor.cpp | 18 ++++ third_party/substrait | 2 +- 5 files changed, 146 insertions(+), 2 deletions(-) diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp index a5473641..b39b3464 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp @@ -411,6 +411,21 @@ std::any BasePlanProtoVisitor::visitWindowFunction( return std::nullopt; } +std::any BasePlanProtoVisitor::visitWindowRelFunction( + const ::substrait::proto::ConsistentPartitionWindowRel::WindowRelFunction& + function) { + for (const auto& arg : function.arguments()) { + visitFunctionArgument(arg); + } + for (const auto& arg : function.options()) { + visitFunctionOption(arg); + } + if (function.has_output_type()) { + visitType(function.output_type()); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitIfThen( const ::substrait::proto::Expression::IfThen& ifthen) { for (const auto& ifThenIf : ifthen.ifs()) { @@ -649,7 +664,6 @@ std::any BasePlanProtoVisitor::visitExpression( case ::substrait::proto::Expression::RexTypeCase::REX_TYPE_NOT_SET: break; } - // TODO -- Use an error listener instead. SUBSTRAIT_UNSUPPORTED( "Unsupported expression type encountered: " + std::to_string(expression.rex_type_case())); @@ -736,6 +750,25 @@ std::any BasePlanProtoVisitor::visitFieldReference( return std::nullopt; } +std::any BasePlanProtoVisitor::visitExpandField( + const ::substrait::proto::ExpandRel::ExpandField& field) { + switch (field.field_type_case()) { + case ::substrait::proto::ExpandRel_ExpandField::kSwitchingField: + for (const auto& switchingField : field.switching_field().duplicates()) { + visitExpression(switchingField); + } + break; + case ::substrait::proto::ExpandRel_ExpandField::kConsistentField: + if (field.has_consistent_field()) { + visitExpression(field.consistent_field()); + } + break; + case ::substrait::proto::ExpandRel_ExpandField::FIELD_TYPE_NOT_SET: + break; + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitReadRelation( const ::substrait::proto::ReadRel& relation) { if (relation.has_common()) { @@ -992,6 +1025,57 @@ std::any BasePlanProtoVisitor::visitMergeJoinRelation( return std::nullopt; } +std::any BasePlanProtoVisitor::visitWindowRelation( + const ::substrait::proto::ConsistentPartitionWindowRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + for (const auto& func : relation.window_functions()) { + visitWindowRelFunction(func); + } + for (const auto& exp : relation.partition_expressions()) { + visitExpression(exp); + } + for (const auto& sort : relation.sorts()) { + visitSortField(sort); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExchangeRelation( + const ::substrait::proto::ExchangeRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExpandRelation( + const ::substrait::proto::ExpandRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + for (const auto& expandField : relation.fields()) { + visitExpandField(expandField); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitRelation( const ::substrait::proto::Rel& relation) { switch (relation.rel_type_case()) { @@ -1023,6 +1107,12 @@ std::any BasePlanProtoVisitor::visitRelation( return visitHashJoinRelation(relation.hash_join()); case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: return visitMergeJoinRelation(relation.merge_join()); + case ::substrait::proto::Rel::kWindow: + return visitWindowRelation(relation.window()); + case ::substrait::proto::Rel::kExchange: + return visitExchangeRelation(relation.exchange()); + case ::substrait::proto::Rel::kExpand: + return visitExpandRelation(relation.expand()); case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.h b/src/substrait/textplan/converter/BasePlanProtoVisitor.h index af1fb138..0eef3e00 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.h +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.h @@ -4,6 +4,7 @@ #include +#include "substrait/proto/algebra.pb.h" #include "substrait/proto/plan.pb.h" namespace io::substrait::textplan { @@ -84,6 +85,9 @@ class BasePlanProtoVisitor { const ::substrait::proto::Expression::ScalarFunction& function); virtual std::any visitWindowFunction( const ::substrait::proto::Expression::WindowFunction& function); + virtual std::any visitWindowRelFunction( + const ::substrait::proto::ConsistentPartitionWindowRel::WindowRelFunction& + function); virtual std::any visitIfThen( const ::substrait::proto::Expression::IfThen& ifthen); virtual std::any visitSwitchExpression( @@ -140,6 +144,8 @@ class BasePlanProtoVisitor { virtual std::any visitSortField(const ::substrait::proto::SortField& sort); virtual std::any visitFieldReference( const ::substrait::proto::Expression::FieldReference& ref); + virtual std::any visitExpandField( + const ::substrait::proto::ExpandRel::ExpandField& field); virtual std::any visitReadRelation( const ::substrait::proto::ReadRel& relation); @@ -168,6 +174,12 @@ class BasePlanProtoVisitor { const ::substrait::proto::HashJoinRel& relation); virtual std::any visitMergeJoinRelation( const ::substrait::proto::MergeJoinRel& relation); + virtual std::any visitWindowRelation( + const ::substrait::proto::ConsistentPartitionWindowRel& relation); + virtual std::any visitExchangeRelation( + const ::substrait::proto::ExchangeRel& relation); + virtual std::any visitExpandRelation( + const ::substrait::proto::ExpandRel& relation); virtual std::any visitRelation(const ::substrait::proto::Rel& relation); virtual std::any visitRelationRoot( diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index dfbdf38f..7fd3eb0a 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -76,6 +76,15 @@ void eraseInputs(::substrait::proto::Rel* relation) { relation->mutable_merge_join()->clear_left(); relation->mutable_merge_join()->clear_right(); break; + case ::substrait::proto::Rel::kWindow: + relation->mutable_window()->clear_input(); + break; + case ::substrait::proto::Rel::kExchange: + relation->mutable_exchange()->clear_input(); + break; + case ::substrait::proto::Rel::kExpand: + relation->mutable_expand()->clear_input(); + break; case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } @@ -112,6 +121,12 @@ ::google::protobuf::RepeatedField getOutputMapping( return relation.hash_join().common().emit().output_mapping(); case ::substrait::proto::Rel::kMergeJoin: return relation.merge_join().common().emit().output_mapping(); + case ::substrait::proto::Rel::kWindow: + return relation.window().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExchange: + return relation.exchange().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExpand: + return relation.expand().common().emit().output_mapping(); case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } @@ -521,6 +536,15 @@ void InitialPlanProtoVisitor::updateLocalSchema( relation.merge_join().left(), relation.merge_join().right()); break; + case ::substrait::proto::Rel::kWindow: + addFieldsToRelation(relationData, relation.window().input()); + break; + case ::substrait::proto::Rel::kExchange: + addFieldsToRelation(relationData, relation.exchange().input()); + break; + case ::substrait::proto::Rel::kExpand: + addFieldsToRelation(relationData, relation.expand().input()); + break; case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index d63a6f56..57d5dcea 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -107,6 +107,24 @@ std::any PipelineVisitor::visitRelation( relationData->newPipelines.push_back(rightSymbol); break; } + case ::substrait::proto::Rel::kWindow: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.window().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kExchange: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.exchange().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kExpand: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.expand().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/third_party/substrait b/third_party/substrait index 07e4feb5..31b99906 160000 --- a/third_party/substrait +++ b/third_party/substrait @@ -1 +1 @@ -Subproject commit 07e4feb5983478e7d0d95dc1d9b5e176685dbdc3 +Subproject commit 31b999060a6e014717f9ae3e6716986ad3066aaf From 01ac34bd5175e77a617d86cf0b8f602adfd1855a Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 23 Aug 2023 17:10:40 -0700 Subject: [PATCH 07/11] chore: Update antlr4 to version 4.13.0. (#85) --- src/substrait/textplan/parser/grammar/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/substrait/textplan/parser/grammar/CMakeLists.txt b/src/substrait/textplan/parser/grammar/CMakeLists.txt index 02430e0c..e04ea81a 100644 --- a/src/substrait/textplan/parser/grammar/CMakeLists.txt +++ b/src/substrait/textplan/parser/grammar/CMakeLists.txt @@ -11,15 +11,15 @@ set(GRAMMAR_DIR ${CMAKE_BINARY_DIR}/src/substrait/textplan/parser/grammar) # using /MD flag for antlr4_runtime (for Visual C++ compilers only) set(ANTLR4_WITH_STATIC_CRT OFF) -set(ANTLR4_TAG 4.12.0) +set(ANTLR4_TAG 4.13.0) set(ANTLR4_ZIP_REPOSITORY - https://github.com/antlr/antlr4/archive/refs/tags/4.12.0.zip) + https://github.com/antlr/antlr4/archive/refs/tags/${ANTLR4_TAG}.zip) include(ExternalAntlr4Cpp) include_directories(${ANTLR4_INCLUDE_DIRS}) -file(DOWNLOAD https://www.antlr.org/download/antlr-4.12.0-complete.jar +file(DOWNLOAD https://www.antlr.org/download/antlr-4.13.0-complete.jar "${GRAMMAR_DIR}/antlr.jar") set(ANTLR_EXECUTABLE "${GRAMMAR_DIR}/antlr.jar") find_package(ANTLR REQUIRED) From fe0ddad89c45f8c0914179d2f38db05850aae18e Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 24 Aug 2023 09:28:34 -0700 Subject: [PATCH 08/11] feat: enable roundtrip testing of textplan conversion (#81) --- src/substrait/textplan/tests/CMakeLists.txt | 80 ++++++++----------- .../textplan/tests/RoundtripTest.cpp | 12 +-- 2 files changed, 41 insertions(+), 51 deletions(-) diff --git a/src/substrait/textplan/tests/CMakeLists.txt b/src/substrait/textplan/tests/CMakeLists.txt index 24be6792..116ec4e7 100644 --- a/src/substrait/textplan/tests/CMakeLists.txt +++ b/src/substrait/textplan/tests/CMakeLists.txt @@ -19,50 +19,38 @@ add_test_case( gtest gtest_main) -option(SUBSTRAIT_CPP_ROUNDTRIP_TESTING - "Enable substrait-cpp textplan roundtrip tests." OFF) - -if(${SUBSTRAIT_CPP_ROUNDTRIP_TESTING}) - add_test_case( - round_trip_test - SOURCES - RoundtripTest.cpp - EXTRA_LINK_LIBS - substrait_textplan_converter - substrait_textplan_loader - substrait_textplan_normalizer - substrait_common - substrait_proto - parse_result_matchers - protobuf-matchers - fmt::fmt-header-only - gmock - gtest - gtest_main) - - cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH TEXTPLAN_SOURCE_DIR) - - add_custom_command( - TARGET round_trip_test - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E echo "Copying unit test data.." - COMMAND ${CMAKE_COMMAND} -E make_directory - "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" - COMMAND - ${CMAKE_COMMAND} -E copy - "${TEXTPLAN_SOURCE_DIR}/converter/data/q6_first_stage.json" - "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json" - COMMAND ${CMAKE_COMMAND} -E copy "${TEXTPLAN_SOURCE_DIR}/data/*.json" - "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/") - - message( - STATUS - "test data will be here: ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data") -else() - - message( - STATUS - "Round trip testing is turned off. Add SUBSTRAIT_CPP_ROUNDTRIP_TESTING=on to enable." - ) +add_test_case( + substrait_textplan_round_trip_test + SOURCES + RoundtripTest.cpp + EXTRA_LINK_LIBS + substrait_textplan_converter + substrait_textplan_loader + substrait_textplan_normalizer + substrait_common + substrait_proto + parse_result_matchers + protobuf-matchers + fmt::fmt-header-only + gmock + gtest + gtest_main) -endif() +cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH TEXTPLAN_SOURCE_DIR) + +add_custom_command( + TARGET substrait_textplan_round_trip_test + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E echo "Copying unit test data.." + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" + COMMAND + ${CMAKE_COMMAND} -E copy + "${TEXTPLAN_SOURCE_DIR}/converter/data/q6_first_stage.json" + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json" + COMMAND ${CMAKE_COMMAND} -E copy "${TEXTPLAN_SOURCE_DIR}/data/*.json" + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/") + +message( + STATUS "test data will be here: ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" +) diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 09aeeca3..42f447c0 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -19,8 +19,7 @@ #include "substrait/textplan/tests/ParseResultMatchers.h" using ::protobuf_matchers::EqualsProto; -using ::protobuf_matchers::IgnoringFieldPaths; -using ::protobuf_matchers::Partially; +using ::protobuf_matchers::IgnoringFields; using ::testing::AllOf; namespace io::substrait::textplan { @@ -98,7 +97,11 @@ TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { ASSERT_THAT( result, ::testing::AllOf( - ParsesOk(), HasErrors({}), AsBinaryPlan(EqualsProto(normalizedPlan)))) + ParsesOk(), + HasErrors({}), + AsBinaryPlan(IgnoringFields( + {"substrait.proto.RelCommon.Emit.output_mapping"}, + EqualsProto(normalizedPlan))))) << std::endl << "Intermediate result:" << std::endl << addLineNumbers(outputText) << std::endl @@ -115,8 +118,7 @@ INSTANTIATE_TEST_SUITE_P( if (lastSlash != std::string::npos) { identifier = identifier.substr(lastSlash); } - if (identifier.length() > 5 && - identifier.substr(identifier.length() - 5) == ".json") { + if (endsWith(identifier, ".json")) { identifier = identifier.substr(0, identifier.length() - 5); } From 28c9e39f96190df5643028f21fb1ce5309eca438 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Thu, 24 Aug 2023 09:29:42 -0700 Subject: [PATCH 09/11] feat: switch datetime package to an external dependency (#83) By switching to an external dependency controlled by a central file we can more cleanly reference date::date and date::tz. We could do the same by "externally loading" a submodule but either way we end up downloading the datetime repo. --- src/substrait/textplan/converter/CMakeLists.txt | 2 ++ src/substrait/textplan/parser/CMakeLists.txt | 2 ++ third_party/CMakeLists.txt | 3 +-- third_party/datetime | 1 - third_party/datetime.cmake | 11 +++++++++++ 5 files changed, 16 insertions(+), 3 deletions(-) delete mode 160000 third_party/datetime create mode 100644 third_party/datetime.cmake diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt index 6f67271e..5ee27539 100644 --- a/src/substrait/textplan/converter/CMakeLists.txt +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +include(../../../../third_party/datetime.cmake) + set(TEXTPLAN_SRCS InitialPlanProtoVisitor.cpp InitialPlanProtoVisitor.h diff --git a/src/substrait/textplan/parser/CMakeLists.txt b/src/substrait/textplan/parser/CMakeLists.txt index 4c89fa83..fcc01413 100644 --- a/src/substrait/textplan/parser/CMakeLists.txt +++ b/src/substrait/textplan/parser/CMakeLists.txt @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +include(../../../../third_party/datetime.cmake) + add_subdirectory(grammar) add_library( diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 7b8d8e60..2124c76d 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -5,8 +5,7 @@ if(NOT ${ABSL_INCLUDED_WITH_PROTOBUF}) add_subdirectory(abseil-cpp) endif() -set(BUILD_TZ_LIB ON) -add_subdirectory(datetime) +include(datetime.cmake) add_subdirectory(fmt) add_subdirectory(googletest) diff --git a/third_party/datetime b/third_party/datetime deleted file mode 160000 index cc4685a2..00000000 --- a/third_party/datetime +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cc4685a21e4a4fdae707ad1233c61bbaff241f93 diff --git a/third_party/datetime.cmake b/third_party/datetime.cmake new file mode 100644 index 00000000..f96aa06c --- /dev/null +++ b/third_party/datetime.cmake @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +include_guard(GLOBAL) + +set (BUILD_TZ_LIB ON CACHE BOOL "timezone library is a dependency" FORCE) +include(FetchContent) +FetchContent_Declare(date_src + GIT_REPOSITORY https://github.com/HowardHinnant/date.git + GIT_TAG v3.0.1 + ) +FetchContent_MakeAvailable(date_src) From b0a956f443597111f3e5ee5311343aba2db13973 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 25 Aug 2023 17:19:33 -0700 Subject: [PATCH 10/11] Handled review notes. --- include/substrait/common/Io.h | 16 +++++----- src/substrait/common/Io.cpp | 1 - src/substrait/common/tests/IoTest.cpp | 30 +++++++++++++------ src/substrait/textplan/CMakeLists.txt | 4 +-- src/substrait/textplan/StringManipulation.cpp | 12 -------- src/substrait/textplan/StringManipulation.h | 5 ---- .../textplan/converter/LoadBinary.cpp | 3 +- .../textplan/converter/SaveBinary.cpp | 16 ++++++---- src/substrait/textplan/parser/LoadText.cpp | 4 ++- 9 files changed, 46 insertions(+), 45 deletions(-) diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h index a2c18726..a53697d8 100644 --- a/include/substrait/common/Io.h +++ b/include/substrait/common/Io.h @@ -20,7 +20,7 @@ enum class PlanFileFormat { }; /* - * \\brief Loads a Substrait plan of any format from the given file. + * \brief Loads a Substrait plan of any format from the given file. * * loadPlan determines which file type the specified file is and then calls * the appropriate load/parse method to consume it preserving any error @@ -29,8 +29,8 @@ enum class PlanFileFormat { * This will load the plan into memory and then convert it consuming twice the * amount of memory that it consumed on disk. * - * \\param input_filename The filename containing the plan to convert. - * \\return If loading was successful, returns a plan. If loading was not + * \param input_filename The filename containing the plan to convert. + * \return If loading was successful, returns a plan. If loading was not * successful this is a status containing a list of parse errors in the status's * message. */ @@ -38,7 +38,7 @@ absl::StatusOr<::substrait::proto::Plan> loadPlan( std::string_view input_filename); /* - * \\brief Writes the provided plan to disk. + * \brief Writes the provided plan to disk. * * savePlan writes the provided plan in the specified format to the specified * location. @@ -47,10 +47,10 @@ absl::StatusOr<::substrait::proto::Plan> loadPlan( * format as the original plan as well as the annotated parse tree will need to * reside in memory during the process. * - * \\param plan - * \\param output_filename - * \\param format - * \\return + * \param plan + * \param output_filename + * \param format + * \return */ absl::Status savePlan( const ::substrait::proto::Plan& plan, diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp index 4c5b1c24..af06066c 100644 --- a/src/substrait/common/Io.cpp +++ b/src/substrait/common/Io.cpp @@ -55,7 +55,6 @@ absl::StatusOr<::substrait::proto::Plan> loadPlan( case PlanFileFormat::kText: return textplan::loadFromText(*contentOrError); } - return absl::UnimplementedError("Unexpected encoding requested."); } absl::Status savePlan( diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp index bf8dd00d..09a7283e 100644 --- a/src/substrait/common/tests/IoTest.cpp +++ b/src/substrait/common/tests/IoTest.cpp @@ -2,6 +2,8 @@ #include "substrait/common/Io.h" +#include + #include #include #include @@ -42,23 +44,33 @@ TEST_F(IoTest, LoadMissingFile) { class SaveAndLoadTestFixture : public ::testing::TestWithParam { public: - ~SaveAndLoadTestFixture() override { - for (const auto& filename : testFiles_) { - unlink(filename.c_str()); + void SetUp() override { + testFileDirectory_ = std::filesystem::temp_directory_path() / + std::filesystem::path("my_temp_dir"); + + if (!std::filesystem::create_directory(testFileDirectory_)) { + std::cerr << "Failed to create temporary directory." << std::endl; + testFileDirectory_.clear(); + } + } + + void TearDown() override { + if (!testFileDirectory_.empty()) { + std::filesystem::remove_all(testFileDirectory_); } } - void registerCleanup(const char* filename) { - testFiles_.emplace_back(filename); + static std::string makeTempFileName() { + static int tempFileNum = 0; + return "testfile" + std::to_string(++tempFileNum); } - private: - std::vector testFiles_; + protected: + std::string testFileDirectory_; }; TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { - auto tempFilename = std::tmpnam(nullptr); - registerCleanup(tempFilename); + auto tempFilename = testFileDirectory_ + "/" + makeTempFileName(); PlanFileFormat encoding = GetParam(); ::substrait::proto::Plan plan; diff --git a/src/substrait/textplan/CMakeLists.txt b/src/substrait/textplan/CMakeLists.txt index 1e8b7dbb..eda37d58 100644 --- a/src/substrait/textplan/CMakeLists.txt +++ b/src/substrait/textplan/CMakeLists.txt @@ -20,10 +20,10 @@ add_library(error_listener SubstraitErrorListener.cpp SubstraitErrorListener.h) add_library(parse_result ParseResult.cpp ParseResult.h) -add_dependencies(symbol_table substrait_proto substrait_common +add_dependencies(symbol_table substrait_proto substrait_common absl::strings fmt::fmt-header-only) -target_link_libraries(symbol_table fmt::fmt-header-only +target_link_libraries(symbol_table fmt::fmt-header-only absl::strings substrait_textplan_converter) # Provide access to the generated protobuffer headers hierarchy. diff --git a/src/substrait/textplan/StringManipulation.cpp b/src/substrait/textplan/StringManipulation.cpp index ce37c9bb..cb11e53a 100644 --- a/src/substrait/textplan/StringManipulation.cpp +++ b/src/substrait/textplan/StringManipulation.cpp @@ -19,16 +19,4 @@ bool endsWith(std::string_view haystack, std::string_view needle) { haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; } -std::string joinLines( - std::vector lines, - std::string_view separator) { - auto concatWithSeparator = [separator](std::string a, const std::string& b) { - return std::move(a) + std::string(separator) + b; - }; - - auto result = std::accumulate( - std::next(lines.begin()), lines.end(), lines[0], concatWithSeparator); - return result; -} - } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/StringManipulation.h b/src/substrait/textplan/StringManipulation.h index d0d602e6..8edf7ea5 100644 --- a/src/substrait/textplan/StringManipulation.h +++ b/src/substrait/textplan/StringManipulation.h @@ -14,9 +14,4 @@ bool startsWith(std::string_view haystack, std::string_view needle); // Returns true if the string 'haystack' ends with the string 'needle'. bool endsWith(std::string_view haystack, std::string_view needle); -// Joins a vector of strings into a single string separated by separator. -std::string joinLines( - std::vector lines, - std::string_view separator = "\n"); - } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/LoadBinary.cpp b/src/substrait/textplan/converter/LoadBinary.cpp index 65e5694c..c5d9f4ce 100644 --- a/src/substrait/textplan/converter/LoadBinary.cpp +++ b/src/substrait/textplan/converter/LoadBinary.cpp @@ -2,6 +2,7 @@ #include "substrait/textplan/converter/LoadBinary.h" +#include #include #include #include @@ -83,7 +84,7 @@ absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( parser.RecordErrorsTo(&collector); if (!parser.ParseFromString(text, &plan)) { auto errors = collector.getErrors(); - return absl::InternalError(joinLines(errors)); + return absl::InternalError(absl::StrJoin(errors, "")); } return plan; } diff --git a/src/substrait/textplan/converter/SaveBinary.cpp b/src/substrait/textplan/converter/SaveBinary.cpp index e7347269..c8dd6c07 100644 --- a/src/substrait/textplan/converter/SaveBinary.cpp +++ b/src/substrait/textplan/converter/SaveBinary.cpp @@ -2,6 +2,7 @@ #include "substrait/textplan/converter/SaveBinary.h" +#include #include #include #include @@ -34,8 +35,10 @@ absl::Status savePlanToBinary( return ::absl::UnknownError("Failed to write plan to stream."); } + if (!stream->Close()) { + return absl::AbortedError("Failed to close file descriptor."); + } delete stream; - close(outputFileDescriptor); return absl::OkStatus(); } @@ -54,11 +57,10 @@ absl::Status savePlanToJson( return absl::UnknownError("Failed to save plan as a JSON protobuf."); } stream << output; + stream.close(); if (stream.fail()) { return absl::UnknownError("Failed to write the plan as a JSON protobuf."); } - - stream.close(); return absl::OkStatus(); } @@ -74,13 +76,13 @@ absl::Status savePlanToText( auto result = parseBinaryPlan(plan); auto errors = result.getAllErrors(); if (!errors.empty()) { - return absl::UnknownError(joinLines(errors)); + return absl::UnknownError(absl::StrJoin(errors, "")); } stream << SymbolTablePrinter::outputToText(result.getSymbolTable()); + stream.close(); if (stream.fail()) { return absl::UnknownError("Failed to write the plan as text."); } - stream.close(); return absl::OkStatus(); } @@ -101,8 +103,10 @@ absl::Status savePlanToProtoText( return absl::UnknownError("Failed to save plan as a text protobuf."); } + if (!stream->Close()) { + return absl::AbortedError("Failed to close file descriptor."); + } delete stream; - close(outputFileDescriptor); return absl::OkStatus(); } diff --git a/src/substrait/textplan/parser/LoadText.cpp b/src/substrait/textplan/parser/LoadText.cpp index 0850befe..c76a7b31 100644 --- a/src/substrait/textplan/parser/LoadText.cpp +++ b/src/substrait/textplan/parser/LoadText.cpp @@ -2,6 +2,8 @@ #include "substrait/textplan/parser/LoadText.h" +#include + #include "substrait/proto/plan.pb.h" #include "substrait/textplan/StringManipulation.h" #include "substrait/textplan/SymbolTablePrinter.h" @@ -14,7 +16,7 @@ absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text) { auto parseResult = io::substrait::textplan::parseStream(stream); if (!parseResult.successful()) { auto errors = parseResult.getAllErrors(); - return absl::UnknownError(joinLines(errors)); + return absl::UnknownError(absl::StrJoin(errors, "")); } return SymbolTablePrinter::outputToBinaryPlan(parseResult.getSymbolTable()); From 77ecaf3d39c6a1ec367155106b2a24556316fc99 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Fri, 25 Aug 2023 17:54:38 -0700 Subject: [PATCH 11/11] Added better error detection. --- src/substrait/common/tests/IoTest.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp index 09a7283e..183594bd 100644 --- a/src/substrait/common/tests/IoTest.cpp +++ b/src/substrait/common/tests/IoTest.cpp @@ -49,14 +49,16 @@ class SaveAndLoadTestFixture : public ::testing::TestWithParam { std::filesystem::path("my_temp_dir"); if (!std::filesystem::create_directory(testFileDirectory_)) { - std::cerr << "Failed to create temporary directory." << std::endl; + ASSERT_TRUE(false) << "Failed to create temporary directory."; testFileDirectory_.clear(); } } void TearDown() override { if (!testFileDirectory_.empty()) { - std::filesystem::remove_all(testFileDirectory_); + std::error_code err; + std::filesystem::remove_all(testFileDirectory_, err); + ASSERT_FALSE(err) << err.message(); } }