diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h new file mode 100644 index 00000000..4227f7a3 --- /dev/null +++ b/include/substrait/common/Io.h @@ -0,0 +1,30 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "absl/status/statusor.h" +#include "substrait/proto/plan.pb.h" + +namespace io::substrait { + +enum PlanFileEncoding { + kBinary = 0, + kJson = 1, + kProtoText = 2, + kText = 3, +}; + +// Loads a Substrait plan consisting of any encoding type from the given file. +absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( + std::string_view input_filename); + +// Writes the provided plan to the specified location with the specified +// encoding type. +[[maybe_unused]] absl::Status savePlan( + const ::substrait::proto::Plan& plan, + std::string_view output_filename, + PlanFileEncoding encoding); + +} // namespace io::substrait diff --git a/src/substrait/common/CMakeLists.txt b/src/substrait/common/CMakeLists.txt index 846d077f..dc05a11e 100644 --- a/src/substrait/common/CMakeLists.txt +++ b/src/substrait/common/CMakeLists.txt @@ -1,9 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 add_library(substrait_common Exceptions.cpp) - target_link_libraries(substrait_common fmt::fmt-header-only) +add_library(substrait_io Io.cpp) +add_dependencies( + substrait_io + substrait_proto + substrait_textplan_converter + substrait_textplan_loader + fmt::fmt-header-only + absl::status + absl::statusor) +target_link_libraries(substrait_io substrait_proto substrait_textplan_converter + substrait_textplan_loader absl::status absl::statusor) + if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) endif() diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp new file mode 100644 index 00000000..d017a58c --- /dev/null +++ b/src/substrait/common/Io.cpp @@ -0,0 +1,77 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/converter/LoadBinary.h" +#include "substrait/textplan/converter/SaveBinary.h" +#include "substrait/textplan/parser/LoadText.h" + +namespace io::substrait { + +namespace { + +const std::regex kIsJson(R"(("extensionUris"|"extensions"|"relations"))"); +const std::regex kIsProtoText( + R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))"); +const std::regex kIsText( + R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)"); + +PlanFileEncoding detectEncoding(std::string_view content) { + if (std::regex_search(content.begin(), content.end(), kIsJson)) { + return kJson; + } + if (std::regex_search(content.begin(), content.end(), kIsProtoText)) { + return kProtoText; + } + if (std::regex_search(content.begin(), content.end(), kIsText)) { + return kText; + } + return kBinary; +} + +} // namespace + +absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( + std::string_view input_filename) { + auto contentOrError = textplan::readFromFile(input_filename.data()); + if (!contentOrError.ok()) { + return contentOrError.status(); + } + + auto encoding = detectEncoding(*contentOrError); + absl::StatusOr<::substrait::proto::Plan> planOrError; + switch (encoding) { + case kBinary: + return textplan::loadFromBinary(*contentOrError); + case kJson: + return textplan::loadFromJson(*contentOrError); + case kProtoText: + return textplan::loadFromProtoText(*contentOrError); + case kText: + return textplan::loadFromText(*contentOrError); + } + return absl::UnimplementedError("Unexpected encoding requested."); +} + +absl::Status savePlan( + const ::substrait::proto::Plan& plan, + std::string_view output_filename, + PlanFileEncoding encoding) { + switch (encoding) { + case kBinary: + return textplan::savePlanToBinary(plan, output_filename); + case kJson: + return textplan::savePlanToJson(plan, output_filename); + case kProtoText: + return textplan::savePlanToProtoText(plan, output_filename); + case kText: + return textplan::savePlanToText(plan, output_filename); + } + return absl::UnimplementedError("Unexpected encoding requested."); +} + +} // namespace io::substrait diff --git a/src/substrait/common/tests/CMakeLists.txt b/src/substrait/common/tests/CMakeLists.txt index 36c5f160..88b63b8f 100644 --- a/src/substrait/common/tests/CMakeLists.txt +++ b/src/substrait/common/tests/CMakeLists.txt @@ -9,3 +9,13 @@ add_test_case( substrait_common gtest gtest_main) + +add_test_case( + substrait_io_test + SOURCES + IoTest.cpp + EXTRA_LINK_LIBS + substrait_io + protobuf-matchers + gtest + gtest_main) diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp new file mode 100644 index 00000000..a9e7f13f --- /dev/null +++ b/src/substrait/common/tests/IoTest.cpp @@ -0,0 +1,104 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include +#include +#include +#include + +using ::protobuf_matchers::EqualsProto; +using ::protobuf_matchers::Partially; + +namespace io::substrait { + +namespace { + +constexpr const char* planFileEncodingToString(PlanFileEncoding e) noexcept { + switch (e) { + case PlanFileEncoding::kBinary: + return "kBinary"; + case PlanFileEncoding::kJson: + return "kJson"; + case PlanFileEncoding::kProtoText: + return "kProtoText"; + case PlanFileEncoding::kText: + return "kText"; + } + return "IMPOSSIBLE"; +} + +} // namespace + +class IoTest : public ::testing::Test {}; + +TEST_F(IoTest, LoadMissingFile) { + auto result = + ::io::substrait::loadPlanWithUnknownEncoding("non-existent-file"); + ASSERT_FALSE(result.ok()); + ASSERT_THAT( + result.status().message(), + ::testing::ContainsRegex("Failed to open file non-existent-file")); +} + +class SaveAndLoadTestFixture + : public ::testing::TestWithParam { + public: + ~SaveAndLoadTestFixture() override { + for (const auto& filename : testFiles_) { + unlink(filename.c_str()); + } + } + + void registerCleanup(const char* filename) { + testFiles_.emplace_back(filename); + } + + private: + std::vector testFiles_; +}; + +TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { + auto tempFilename = std::tmpnam(nullptr); + registerCleanup(tempFilename); + PlanFileEncoding encoding = GetParam(); + + ::substrait::proto::Plan plan; + auto root = plan.add_relations()->mutable_root(); + auto read = root->mutable_input()->mutable_read(); + read->mutable_common()->mutable_direct(); + read->mutable_named_table()->add_names("table_name"); + auto status = ::io::substrait::savePlan(plan, tempFilename, encoding); + ASSERT_TRUE(status.ok()) << "Save failed.\n" << status; + + auto result = ::io::substrait::loadPlanWithUnknownEncoding(tempFilename); + ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status(); + ASSERT_THAT( + *result, + Partially(EqualsProto<::substrait::proto::Plan>( + R"(relations { + root { + input { + read { + common { + direct { + } + } + named_table { + names: "table_name" + } + } + } + } + })"))); +} + +INSTANTIATE_TEST_SUITE_P( + SaveAndLoadTests, + SaveAndLoadTestFixture, + testing::Values(kBinary, kJson, kProtoText, kText), + [](const testing::TestParamInfo& info) { + return planFileEncodingToString(info.param); + }); + +} // namespace io::substrait diff --git a/src/substrait/textplan/StringManipulation.cpp b/src/substrait/textplan/StringManipulation.cpp index eac3c56a..ce37c9bb 100644 --- a/src/substrait/textplan/StringManipulation.cpp +++ b/src/substrait/textplan/StringManipulation.cpp @@ -2,18 +2,33 @@ #include "StringManipulation.h" +#include +#include +#include +#include + namespace io::substrait::textplan { -// Yields true if the string 'haystack' starts with the string 'needle'. bool startsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(0, needle.size()) == needle; } -// Returns true if the string 'haystack' ends with the string 'needle'. bool endsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; } +std::string joinLines( + std::vector lines, + std::string_view separator) { + auto concatWithSeparator = [separator](std::string a, const std::string& b) { + return std::move(a) + std::string(separator) + b; + }; + + auto result = std::accumulate( + std::next(lines.begin()), lines.end(), lines[0], concatWithSeparator); + return result; +} + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/StringManipulation.h b/src/substrait/textplan/StringManipulation.h index 9c24418f..d0d602e6 100644 --- a/src/substrait/textplan/StringManipulation.h +++ b/src/substrait/textplan/StringManipulation.h @@ -2,7 +2,9 @@ #pragma once +#include #include +#include namespace io::substrait::textplan { @@ -12,4 +14,9 @@ bool startsWith(std::string_view haystack, std::string_view needle); // Returns true if the string 'haystack' ends with the string 'needle'. bool endsWith(std::string_view haystack, std::string_view needle); +// Joins a vector of strings into a single string separated by separator. +std::string joinLines( + std::vector lines, + std::string_view separator = "\n"); + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 19f96188..1b460624 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) { auto subtype = ANY_CAST(SourceType, info.subtype); switch (subtype) { case SourceType::kNamedTable: { - auto table = - ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob); + if (info.blob.has_value()) { + // We are using the proto as is in lieu of a disciplined structure. + auto table = ANY_CAST( + const ::substrait::proto::ReadRel_NamedTable*, info.blob); + text << "source named_table " << info.name << " {\n"; + text << " names = [\n"; + for (const auto& name : table->names()) { + text << " \"" << name << "\",\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + // We are using the new style data structure. text << "source named_table " << info.name << " {\n"; text << " names = [\n"; - for (const auto& name : table->names()) { - text << " \"" << name << "\",\n"; + for (const auto& sym : + symbolTable.lookupSymbolsByLocation(info.location)) { + if (sym->type == SymbolType::kSourceDetail) { + text << " \"" << sym->name << "\",\n"; + } } text << " ]\n"; text << "}\n"; hasPreviousText = true; + break; } case SourceType::kLocalFiles: { diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt index 41f3b9ce..6f67271e 100644 --- a/src/substrait/textplan/converter/CMakeLists.txt +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -11,6 +11,8 @@ set(TEXTPLAN_SRCS PlanPrinterVisitor.h LoadBinary.cpp LoadBinary.h + SaveBinary.cpp + SaveBinary.h ParseBinary.cpp ParseBinary.h) @@ -20,10 +22,14 @@ target_link_libraries( substrait_textplan_converter substrait_common substrait_expression + substrait_io substrait_proto symbol_table error_listener - date::date) + date::date + fmt::fmt-header-only + absl::status + absl::statusor) if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) diff --git a/src/substrait/textplan/converter/LoadBinary.cpp b/src/substrait/textplan/converter/LoadBinary.cpp index 787ab054..65e5694c 100644 --- a/src/substrait/textplan/converter/LoadBinary.cpp +++ b/src/substrait/textplan/converter/LoadBinary.cpp @@ -2,20 +2,19 @@ #include "substrait/textplan/converter/LoadBinary.h" +#include #include #include #include - #include #include -#include #include #include #include #include -#include "substrait/common/Exceptions.h" #include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" namespace io::substrait::textplan { @@ -39,24 +38,23 @@ class StringErrorCollector : public google::protobuf::io::ErrorCollector { } // namespace -std::string readFromFile(std::string_view msgPath) { +absl::StatusOr readFromFile(std::string_view msgPath) { std::ifstream textFile(std::string{msgPath}); if (textFile.fail()) { - auto currdir = std::filesystem::current_path().string(); - SUBSTRAIT_FAIL( - "Failed to open file {} when running in {}: {}", - msgPath, - currdir, - strerror(errno)); + auto currDir = std::filesystem::current_path().string(); + return absl::ErrnoToStatus( + errno, + fmt::format( + "Failed to open file {} when running in {}", msgPath, currDir)); } std::stringstream buffer; buffer << textFile.rdbuf(); return buffer.str(); } -PlanOrErrors loadFromJson(std::string_view json) { +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json) { if (json.empty()) { - return PlanOrErrors({"Provided JSON string was empty."}); + return absl::InternalError("Provided JSON string was empty."); } std::string_view usableJson = json; if (json[0] == '#') { @@ -71,21 +69,32 @@ PlanOrErrors loadFromJson(std::string_view json) { std::string{usableJson}, &plan); if (!status.ok()) { std::string msg{status.message()}; - return PlanOrErrors( - {fmt::format("Failed to parse Substrait JSON: {}", msg)}); + return absl::InternalError( + fmt::format("Failed to parse Substrait JSON: {}", msg)); } - return PlanOrErrors(plan); + return plan; } -PlanOrErrors loadFromText(const std::string& text) { +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text) { ::substrait::proto::Plan plan; ::google::protobuf::TextFormat::Parser parser; StringErrorCollector collector; parser.RecordErrorsTo(&collector); if (!parser.ParseFromString(text, &plan)) { - return PlanOrErrors(collector.getErrors()); + auto errors = collector.getErrors(); + return absl::InternalError(joinLines(errors)); + } + return plan; +} + +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes) { + ::substrait::proto::Plan plan; + if (!plan.ParseFromString(bytes)) { + return absl::InternalError("Failed to parse as a binary Substrait plan."); } - return PlanOrErrors(plan); + return plan; } } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/LoadBinary.h b/src/substrait/textplan/converter/LoadBinary.h index b0b73a90..a4d4e38f 100644 --- a/src/substrait/textplan/converter/LoadBinary.h +++ b/src/substrait/textplan/converter/LoadBinary.h @@ -2,12 +2,9 @@ #pragma once +#include #include #include -#include -#include - -#include "substrait/proto/plan.pb.h" namespace substrait::proto { class Plan; @@ -15,41 +12,21 @@ class Plan; namespace io::substrait::textplan { -// PlanOrErrors behaves similarly to abseil::StatusOr. -class PlanOrErrors { - public: - explicit PlanOrErrors(::substrait::proto::Plan plan) - : plan_(std::move(plan)){}; - explicit PlanOrErrors(std::vector errors) - : errors_(std::move(errors)){}; - - bool ok() { - return errors_.empty(); - } - - const ::substrait::proto::Plan& operator*() { - return plan_; - } - - const std::vector& errors() { - return errors_; - } - - private: - ::substrait::proto::Plan plan_; - std::vector errors_; -}; - // Read the contents of a file from disk. -// Throws an exception if file cannot be read. -std::string readFromFile(std::string_view msgPath); +absl::StatusOr readFromFile(std::string_view msgPath); // Reads a plan from a json-encoded text proto. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromJson(std::string_view json); +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json); // Reads a plan encoded as a text protobuf. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromText(const std::string& text); +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text); + +// Reads a plan serialized as a binary protobuf. +// Returns a list of errors if the file cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes); } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/README.md b/src/substrait/textplan/converter/README.md index 607c0652..53a0e98e 100644 --- a/src/substrait/textplan/converter/README.md +++ b/src/substrait/textplan/converter/README.md @@ -1,7 +1,7 @@ # Using the Plan Converter Tool -The plan converter takes any number of JSON encoded Substrait plan files and converts them into the Substrait Text Plan -format. +The plan converter takes any number of Substrait plan files of any format and +converts them into the Substrait Text Plan format. ## Usage: ``` diff --git a/src/substrait/textplan/converter/SaveBinary.cpp b/src/substrait/textplan/converter/SaveBinary.cpp new file mode 100644 index 00000000..e7347269 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.cpp @@ -0,0 +1,109 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/SaveBinary.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/converter/ParseBinary.h" + +namespace io::substrait::textplan { + +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + int outputFileDescriptor = + creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE); + if (outputFileDescriptor == -1) { + return absl::ErrnoToStatus( + errno, + fmt::format("Failed to open file {} for writing", output_filename)); + } + auto stream = + new google::protobuf::io::FileOutputStream(outputFileDescriptor); + + if (!plan.SerializeToZeroCopyStream(stream)) { + return ::absl::UnknownError("Failed to write plan to stream."); + } + + delete stream; + close(outputFileDescriptor); + return absl::OkStatus(); +} + +absl::Status savePlanToJson( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + std::ofstream stream(std::string{output_filename}); + if ((stream.fail())) { + return absl::UnavailableError( + fmt::format("Failed to open file {} for writing", output_filename)); + } + + std::string output; + auto status = ::google::protobuf::util::MessageToJsonString(plan, &output); + if (!status.ok()) { + return absl::UnknownError("Failed to save plan as a JSON protobuf."); + } + stream << output; + if (stream.fail()) { + return absl::UnknownError("Failed to write the plan as a JSON protobuf."); + } + + stream.close(); + return absl::OkStatus(); +} + +absl::Status savePlanToText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + std::ofstream stream(std::string{output_filename}); + if ((stream.fail())) { + return absl::UnavailableError( + fmt::format("Failed to open file {} for writing", output_filename)); + } + + auto result = parseBinaryPlan(plan); + auto errors = result.getAllErrors(); + if (!errors.empty()) { + return absl::UnknownError(joinLines(errors)); + } + stream << SymbolTablePrinter::outputToText(result.getSymbolTable()); + if (stream.fail()) { + return absl::UnknownError("Failed to write the plan as text."); + } + stream.close(); + return absl::OkStatus(); +} + +absl::Status savePlanToProtoText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + int outputFileDescriptor = + creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE); + if (outputFileDescriptor == -1) { + return absl::ErrnoToStatus( + errno, + fmt::format("Failed to open file {} for writing", output_filename)); + } + auto stream = + new google::protobuf::io::FileOutputStream(outputFileDescriptor); + + if (!::google::protobuf::TextFormat::Print(plan, stream)) { + return absl::UnknownError("Failed to save plan as a text protobuf."); + } + + delete stream; + close(outputFileDescriptor); + return absl::OkStatus(); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/SaveBinary.h b/src/substrait/textplan/converter/SaveBinary.h new file mode 100644 index 00000000..d9158773 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.h @@ -0,0 +1,33 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include "absl/status/status.h" + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// Serializes a plan to disk as a binary protobuf. +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a JSON-encoded protobuf. +absl::Status savePlanToJson( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Calls the converter to store a plan on disk as a text-based substrait plan. +absl::Status savePlanToText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a text-encoded protobuf. +absl::Status savePlanToProtoText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/Tool.cpp b/src/substrait/textplan/converter/Tool.cpp index 9c3d652b..67d6ca82 100644 --- a/src/substrait/textplan/converter/Tool.cpp +++ b/src/substrait/textplan/converter/Tool.cpp @@ -6,6 +6,7 @@ #include +#include "substrait/common/Io.h" #include "substrait/textplan/SymbolTablePrinter.h" #include "substrait/textplan/converter/LoadBinary.h" #include "substrait/textplan/converter/ParseBinary.h" @@ -13,14 +14,10 @@ namespace io::substrait::textplan { namespace { -void convertJsonToText(const char* filename) { - std::string json = readFromFile(filename); - auto planOrError = loadFromJson(json); +void convertPlanToText(const char* filename) { + auto planOrError = loadPlanWithUnknownEncoding(filename); if (!planOrError.ok()) { - std::cerr << "An error occurred while reading: " << filename << std::endl; - for (const auto& err : planOrError.errors()) { - std::cerr << err << std::endl; - } + std::cerr << planOrError.status() << std::endl; return; } @@ -46,7 +43,7 @@ int main(int argc, char* argv[]) { #ifdef _WIN32 for (int currArg = 1; currArg < argc; currArg++) { printf("===== %s =====\n", argv[currArg]); - io::substrait::textplan::convertJsonToText(argv[currArg]); + io::substrait::textplan::convertPlanToText(argv[currArg]); } #else for (int currArg = 1; currArg < argc; currArg++) { @@ -54,7 +51,7 @@ int main(int argc, char* argv[]) { glob(argv[currArg], GLOB_TILDE, nullptr, &globResult); for (size_t i = 0; i < globResult.gl_pathc; i++) { printf("===== %s =====\n", globResult.gl_pathv[i]); - io::substrait::textplan::convertJsonToText(globResult.gl_pathv[i]); + io::substrait::textplan::convertPlanToText(globResult.gl_pathv[i]); } } #endif diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index cde3c92a..3b0ae97f 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -597,9 +597,10 @@ std::vector getTestCases() { TEST_P(BinaryToTextPlanConverterTestFixture, Parse) { auto [name, input, matcher] = GetParam(); - auto planOrError = loadFromText(input); + auto planOrError = loadFromProtoText(input); if (!planOrError.ok()) { - ParseResult result(SymbolTable(), planOrError.errors(), {}); + ParseResult result( + SymbolTable(), {std::string(planOrError.status().message())}, {}); ASSERT_THAT(result, matcher); return; } @@ -627,13 +628,15 @@ INSTANTIATE_TEST_SUITE_P( class BinaryToTextPlanConversionTest : public ::testing::Test {}; TEST_F(BinaryToTextPlanConversionTest, FullSample) { - std::string json = readFromFile("data/q6_first_stage.json"); - auto planOrError = loadFromJson(json); + auto jsonOrError = readFromFile("data/q6_first_stage.json"); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); ASSERT_TRUE(planOrError.ok()); auto plan = *planOrError; EXPECT_THAT(plan.extensions_size(), ::testing::Eq(7)); - std::string expectedOutput = readFromFile("data/q6_first_stage.golden.splan"); + auto expectedOutputOrError = readFromFile("data/q6_first_stage.golden.splan"); + ASSERT_TRUE(expectedOutputOrError.ok()); auto result = parseBinaryPlan(plan); auto symbols = result.getSymbolTable().getSymbols(); @@ -668,7 +671,7 @@ TEST_F(BinaryToTextPlanConversionTest, FullSample) { SymbolType::kSource, SymbolType::kSchema, }), - WhenSerialized(EqSquashingWhitespace(expectedOutput)))) + WhenSerialized(EqSquashingWhitespace(*expectedOutputOrError)))) << result.getSymbolTable().toDebugString(); } diff --git a/src/substrait/textplan/parser/CMakeLists.txt b/src/substrait/textplan/parser/CMakeLists.txt index 498d572c..4c89fa83 100644 --- a/src/substrait/textplan/parser/CMakeLists.txt +++ b/src/substrait/textplan/parser/CMakeLists.txt @@ -12,6 +12,8 @@ add_library( SubstraitPlanRelationVisitor.h SubstraitPlanTypeVisitor.cpp SubstraitPlanTypeVisitor.h + LoadText.cpp + LoadText.h ParseText.cpp ParseText.h SubstraitParserErrorListener.cpp) @@ -26,7 +28,9 @@ target_link_libraries( textplan_grammar fmt::fmt-header-only date::date - date::date-tz) + date::date-tz + absl::status + absl::statusor) add_executable(planparser Tool.cpp) diff --git a/src/substrait/textplan/parser/LoadText.cpp b/src/substrait/textplan/parser/LoadText.cpp new file mode 100644 index 00000000..0850befe --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.cpp @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/parser/LoadText.h" + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/parser/ParseText.h" + +namespace io::substrait::textplan { + +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text) { + auto stream = loadTextString(text); + auto parseResult = io::substrait::textplan::parseStream(stream); + if (!parseResult.successful()) { + auto errors = parseResult.getAllErrors(); + return absl::UnknownError(joinLines(errors)); + } + + return SymbolTablePrinter::outputToBinaryPlan(parseResult.getSymbolTable()); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/LoadText.h b/src/substrait/textplan/parser/LoadText.h new file mode 100644 index 00000000..3b2f900b --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// Reads a plan encoded as a text protobuf. +// Returns a list of errors if the text cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/ParseText.cpp b/src/substrait/textplan/parser/ParseText.cpp index 3a19d3cc..a4ba32fc 100644 --- a/src/substrait/textplan/parser/ParseText.cpp +++ b/src/substrait/textplan/parser/ParseText.cpp @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: Apache-2.0 */ -#include "ParseText.h" +#include "substrait/textplan/parser/ParseText.h" #include #include diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index a8273f12..4a6d5567 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -974,14 +974,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( case SourceType::kNamedTable: { auto* source = parentRelationData->relation.mutable_read()->mutable_named_table(); - for (const auto& sym : *symbolTable_) { - if (sym.type != SymbolType::kSourceDetail) { - continue; - } - if (sym.location != symbol->location) { + for (const auto& sym : + symbolTable_->lookupSymbolsByLocation(symbol->location)) { + if (sym->type != SymbolType::kSourceDetail) { continue; } - source->add_names(sym.name); + source->add_names(sym->name); } break; } diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 3b22071c..09aeeca3 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -70,12 +70,12 @@ std::vector getTestCases() { TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { auto filename = GetParam(); - std::string json = readFromFile(filename); - auto planOrErrors = loadFromJson(json); - std::vector errors = planOrErrors.errors(); - ASSERT_THAT(errors, ::testing::ElementsAre()); + auto jsonOrError = readFromFile(filename); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); + ASSERT_TRUE(planOrError.ok()); - auto plan = *planOrErrors; + auto plan = *planOrError; auto textResult = parseBinaryPlan(plan); auto textSymbols = textResult.getSymbolTable().getSymbols();