diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a413bbe8..0bbf37fb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 * @westonpace +/src/substrait/textplan @EpsilonPrime diff --git a/CMakeLists.txt b/CMakeLists.txt index bd7fe7e6..d4f78ef9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,12 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED True) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +add_compile_options($<$:-fsanitize=undefined>) +add_link_options($<$:-fsanitize=undefined>) + +add_compile_options($<$:-fsanitize=address>) +add_link_options($<$:-fsanitize=address>) + option( SUBSTRAIT_CPP_BUILD_TESTING "Enable substrait-cpp tests. This will enable all other build options automatically." diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h new file mode 100644 index 00000000..a53697d8 --- /dev/null +++ b/include/substrait/common/Io.h @@ -0,0 +1,60 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +#include "absl/status/statusor.h" +#include "substrait/proto/plan.pb.h" + +namespace io::substrait { + +/* + * \brief The four different ways plans can be represented on disk. + */ +enum class PlanFileFormat { + kBinary = 0, + kJson = 1, + kProtoText = 2, + kText = 3, +}; + +/* + * \brief Loads a Substrait plan of any format from the given file. + * + * loadPlan determines which file type the specified file is and then calls + * the appropriate load/parse method to consume it preserving any error + * messages. + * + * This will load the plan into memory and then convert it consuming twice the + * amount of memory that it consumed on disk. + * + * \param input_filename The filename containing the plan to convert. + * \return If loading was successful, returns a plan. If loading was not + * successful this is a status containing a list of parse errors in the status's + * message. 
+ */ +absl::StatusOr<::substrait::proto::Plan> loadPlan( + std::string_view input_filename); + +/* + * \brief Writes the provided plan to disk. + * + * savePlan writes the provided plan in the specified format to the specified + * location. + * + * This routine will consume more memory during the conversion to the text + * format as the original plan as well as the annotated parse tree will need to + * reside in memory during the process. + * + * \param plan + * \param output_filename + * \param format + * \return + */ +absl::Status savePlan( + const ::substrait::proto::Plan& plan, + std::string_view output_filename, + PlanFileFormat format); + +} // namespace io::substrait diff --git a/src/substrait/common/CMakeLists.txt b/src/substrait/common/CMakeLists.txt index 846d077f..dc05a11e 100644 --- a/src/substrait/common/CMakeLists.txt +++ b/src/substrait/common/CMakeLists.txt @@ -1,9 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 add_library(substrait_common Exceptions.cpp) - target_link_libraries(substrait_common fmt::fmt-header-only) +add_library(substrait_io Io.cpp) +add_dependencies( + substrait_io + substrait_proto + substrait_textplan_converter + substrait_textplan_loader + fmt::fmt-header-only + absl::status + absl::statusor) +target_link_libraries(substrait_io substrait_proto substrait_textplan_converter + substrait_textplan_loader absl::status absl::statusor) + if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) endif() diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp new file mode 100644 index 00000000..af06066c --- /dev/null +++ b/src/substrait/common/Io.cpp @@ -0,0 +1,77 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/converter/LoadBinary.h" +#include "substrait/textplan/converter/SaveBinary.h" +#include "substrait/textplan/parser/LoadText.h" + +namespace io::substrait { + +namespace { + +const 
std::regex kIsJson(
+    R"(("extensionUris"|"extension_uris"|"extensions"|"relations"))");
+// Matches a top-level proto text field opener at the start of a line.
+const std::regex kIsProtoText(
+    R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))");
+// Matches a text plan section keyword at the start of a line.
+const std::regex kIsText(
+    R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)");
+
+// Classifies the provided content by searching it with the regexes above.
+// The checks run in the order JSON, proto text, text plan; content matching
+// none of them is reported as binary.
+PlanFileFormat detectFormat(std::string_view content) {
+  if (std::regex_search(content.begin(), content.end(), kIsJson)) {
+    return PlanFileFormat::kJson;
+  }
+  if (std::regex_search(content.begin(), content.end(), kIsProtoText)) {
+    return PlanFileFormat::kProtoText;
+  }
+  if (std::regex_search(content.begin(), content.end(), kIsText)) {
+    return PlanFileFormat::kText;
+  }
+  return PlanFileFormat::kBinary;
+}
+
+} // namespace
+
+// Reads the named file, detects its format, and dispatches to the matching
+// loader.  Any read or parse failure is returned as the status.
+absl::StatusOr<::substrait::proto::Plan> loadPlan(
+    std::string_view input_filename) {
+  // Pass the view itself: readFromFile accepts a std::string_view, and
+  // input_filename.data() is not guaranteed to be NUL-terminated.
+  auto contentOrError = textplan::readFromFile(input_filename);
+  if (!contentOrError.ok()) {
+    return contentOrError.status();
+  }
+
+  switch (detectFormat(*contentOrError)) {
+    case PlanFileFormat::kBinary:
+      return textplan::loadFromBinary(*contentOrError);
+    case PlanFileFormat::kJson:
+      return textplan::loadFromJson(*contentOrError);
+    case PlanFileFormat::kProtoText:
+      return textplan::loadFromProtoText(*contentOrError);
+    case PlanFileFormat::kText:
+      return textplan::loadFromText(*contentOrError);
+  }
+  // Every enumerator is handled above; this keeps the function from flowing
+  // off its end if an out-of-range value is ever produced.
+  return absl::InternalError("Unexpected plan file format detected.");
+}
+
+// Writes the plan to output_filename using the writer for the requested
+// format.
+absl::Status savePlan(
+    const ::substrait::proto::Plan& plan,
+    std::string_view output_filename,
+    PlanFileFormat format) {
+  switch (format) {
+    case PlanFileFormat::kBinary:
+      return textplan::savePlanToBinary(plan, output_filename);
+    case PlanFileFormat::kJson:
+      return textplan::savePlanToJson(plan, output_filename);
+    case PlanFileFormat::kProtoText:
+      return textplan::savePlanToProtoText(plan, output_filename);
+    case PlanFileFormat::kText:
+      return textplan::savePlanToText(plan, output_filename);
+  }
+  return 
absl::UnimplementedError("Unexpected format requested."); +} + +} // namespace io::substrait diff --git a/src/substrait/common/tests/CMakeLists.txt b/src/substrait/common/tests/CMakeLists.txt index 36c5f160..88b63b8f 100644 --- a/src/substrait/common/tests/CMakeLists.txt +++ b/src/substrait/common/tests/CMakeLists.txt @@ -9,3 +9,13 @@ add_test_case( substrait_common gtest gtest_main) + +add_test_case( + substrait_io_test + SOURCES + IoTest.cpp + EXTRA_LINK_LIBS + substrait_io + protobuf-matchers + gtest + gtest_main) diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp new file mode 100644 index 00000000..183594bd --- /dev/null +++ b/src/substrait/common/tests/IoTest.cpp @@ -0,0 +1,120 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/common/Io.h" + +#include + +#include +#include +#include +#include + +using ::protobuf_matchers::EqualsProto; +using ::protobuf_matchers::Partially; + +namespace io::substrait { + +namespace { + +constexpr const char* planFileEncodingToString(PlanFileFormat e) noexcept { + switch (e) { + case PlanFileFormat::kBinary: + return "kBinary"; + case PlanFileFormat::kJson: + return "kJson"; + case PlanFileFormat::kProtoText: + return "kProtoText"; + case PlanFileFormat::kText: + return "kText"; + } + return "IMPOSSIBLE"; +} + +} // namespace + +class IoTest : public ::testing::Test {}; + +TEST_F(IoTest, LoadMissingFile) { + auto result = ::io::substrait::loadPlan("non-existent-file"); + ASSERT_FALSE(result.ok()); + ASSERT_THAT( + result.status().message(), + ::testing::ContainsRegex("Failed to open file non-existent-file")); +} + +class SaveAndLoadTestFixture : public ::testing::TestWithParam { + public: + void SetUp() override { + testFileDirectory_ = std::filesystem::temp_directory_path() / + std::filesystem::path("my_temp_dir"); + + if (!std::filesystem::create_directory(testFileDirectory_)) { + ASSERT_TRUE(false) << "Failed to create temporary directory."; + 
testFileDirectory_.clear(); + } + } + + void TearDown() override { + if (!testFileDirectory_.empty()) { + std::error_code err; + std::filesystem::remove_all(testFileDirectory_, err); + ASSERT_FALSE(err) << err.message(); + } + } + + static std::string makeTempFileName() { + static int tempFileNum = 0; + return "testfile" + std::to_string(++tempFileNum); + } + + protected: + std::string testFileDirectory_; +}; + +TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { + auto tempFilename = testFileDirectory_ + "/" + makeTempFileName(); + PlanFileFormat encoding = GetParam(); + + ::substrait::proto::Plan plan; + auto root = plan.add_relations()->mutable_root(); + auto read = root->mutable_input()->mutable_read(); + read->mutable_common()->mutable_direct(); + read->mutable_named_table()->add_names("table_name"); + auto status = ::io::substrait::savePlan(plan, tempFilename, encoding); + ASSERT_TRUE(status.ok()) << "Save failed.\n" << status; + + auto result = ::io::substrait::loadPlan(tempFilename); + ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status(); + ASSERT_THAT( + *result, + Partially(EqualsProto<::substrait::proto::Plan>( + R"(relations { + root { + input { + read { + common { + direct { + } + } + named_table { + names: "table_name" + } + } + } + } + })"))); +} + +INSTANTIATE_TEST_SUITE_P( + SaveAndLoadTests, + SaveAndLoadTestFixture, + testing::Values( + PlanFileFormat::kBinary, + PlanFileFormat::kJson, + PlanFileFormat::kProtoText, + PlanFileFormat::kText), + [](const testing::TestParamInfo& info) { + return planFileEncodingToString(info.param); + }); + +} // namespace io::substrait diff --git a/src/substrait/textplan/CMakeLists.txt b/src/substrait/textplan/CMakeLists.txt index 1e8b7dbb..eda37d58 100644 --- a/src/substrait/textplan/CMakeLists.txt +++ b/src/substrait/textplan/CMakeLists.txt @@ -20,10 +20,10 @@ add_library(error_listener SubstraitErrorListener.cpp SubstraitErrorListener.h) add_library(parse_result ParseResult.cpp ParseResult.h) 
-add_dependencies(symbol_table substrait_proto substrait_common +add_dependencies(symbol_table substrait_proto substrait_common absl::strings fmt::fmt-header-only) -target_link_libraries(symbol_table fmt::fmt-header-only +target_link_libraries(symbol_table fmt::fmt-header-only absl::strings substrait_textplan_converter) # Provide access to the generated protobuffer headers hierarchy. diff --git a/src/substrait/textplan/StringManipulation.cpp b/src/substrait/textplan/StringManipulation.cpp index eac3c56a..cb11e53a 100644 --- a/src/substrait/textplan/StringManipulation.cpp +++ b/src/substrait/textplan/StringManipulation.cpp @@ -2,15 +2,18 @@ #include "StringManipulation.h" +#include +#include +#include +#include + namespace io::substrait::textplan { -// Yields true if the string 'haystack' starts with the string 'needle'. bool startsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(0, needle.size()) == needle; } -// Returns true if the string 'haystack' ends with the string 'needle'. 
bool endsWith(std::string_view haystack, std::string_view needle) { return haystack.size() > needle.size() && haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; diff --git a/src/substrait/textplan/StringManipulation.h b/src/substrait/textplan/StringManipulation.h index 9c24418f..8edf7ea5 100644 --- a/src/substrait/textplan/StringManipulation.h +++ b/src/substrait/textplan/StringManipulation.h @@ -2,7 +2,9 @@ #pragma once +#include #include +#include namespace io::substrait::textplan { diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 19f96188..1b460624 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) { auto subtype = ANY_CAST(SourceType, info.subtype); switch (subtype) { case SourceType::kNamedTable: { - auto table = - ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob); + if (info.blob.has_value()) { + // We are using the proto as is in lieu of a disciplined structure. + auto table = ANY_CAST( + const ::substrait::proto::ReadRel_NamedTable*, info.blob); + text << "source named_table " << info.name << " {\n"; + text << " names = [\n"; + for (const auto& name : table->names()) { + text << " \"" << name << "\",\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + break; + } + // We are using the new style data structure. 
text << "source named_table " << info.name << " {\n"; text << " names = [\n"; - for (const auto& name : table->names()) { - text << " \"" << name << "\",\n"; + for (const auto& sym : + symbolTable.lookupSymbolsByLocation(info.location)) { + if (sym->type == SymbolType::kSourceDetail) { + text << " \"" << sym->name << "\",\n"; + } } text << " ]\n"; text << "}\n"; hasPreviousText = true; + break; } case SourceType::kLocalFiles: { diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp index a5473641..b39b3464 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp @@ -411,6 +411,21 @@ std::any BasePlanProtoVisitor::visitWindowFunction( return std::nullopt; } +std::any BasePlanProtoVisitor::visitWindowRelFunction( + const ::substrait::proto::ConsistentPartitionWindowRel::WindowRelFunction& + function) { + for (const auto& arg : function.arguments()) { + visitFunctionArgument(arg); + } + for (const auto& arg : function.options()) { + visitFunctionOption(arg); + } + if (function.has_output_type()) { + visitType(function.output_type()); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitIfThen( const ::substrait::proto::Expression::IfThen& ifthen) { for (const auto& ifThenIf : ifthen.ifs()) { @@ -649,7 +664,6 @@ std::any BasePlanProtoVisitor::visitExpression( case ::substrait::proto::Expression::RexTypeCase::REX_TYPE_NOT_SET: break; } - // TODO -- Use an error listener instead. 
SUBSTRAIT_UNSUPPORTED( "Unsupported expression type encountered: " + std::to_string(expression.rex_type_case())); @@ -736,6 +750,25 @@ std::any BasePlanProtoVisitor::visitFieldReference( return std::nullopt; } +std::any BasePlanProtoVisitor::visitExpandField( + const ::substrait::proto::ExpandRel::ExpandField& field) { + switch (field.field_type_case()) { + case ::substrait::proto::ExpandRel_ExpandField::kSwitchingField: + for (const auto& switchingField : field.switching_field().duplicates()) { + visitExpression(switchingField); + } + break; + case ::substrait::proto::ExpandRel_ExpandField::kConsistentField: + if (field.has_consistent_field()) { + visitExpression(field.consistent_field()); + } + break; + case ::substrait::proto::ExpandRel_ExpandField::FIELD_TYPE_NOT_SET: + break; + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitReadRelation( const ::substrait::proto::ReadRel& relation) { if (relation.has_common()) { @@ -992,6 +1025,57 @@ std::any BasePlanProtoVisitor::visitMergeJoinRelation( return std::nullopt; } +std::any BasePlanProtoVisitor::visitWindowRelation( + const ::substrait::proto::ConsistentPartitionWindowRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + for (const auto& func : relation.window_functions()) { + visitWindowRelFunction(func); + } + for (const auto& exp : relation.partition_expressions()) { + visitExpression(exp); + } + for (const auto& sort : relation.sorts()) { + visitSortField(sort); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExchangeRelation( + const ::substrait::proto::ExchangeRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + if 
(relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExpandRelation( + const ::substrait::proto::ExpandRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + for (const auto& expandField : relation.fields()) { + visitExpandField(expandField); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitRelation( const ::substrait::proto::Rel& relation) { switch (relation.rel_type_case()) { @@ -1023,6 +1107,12 @@ std::any BasePlanProtoVisitor::visitRelation( return visitHashJoinRelation(relation.hash_join()); case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: return visitMergeJoinRelation(relation.merge_join()); + case ::substrait::proto::Rel::kWindow: + return visitWindowRelation(relation.window()); + case ::substrait::proto::Rel::kExchange: + return visitExchangeRelation(relation.exchange()); + case ::substrait::proto::Rel::kExpand: + return visitExpandRelation(relation.expand()); case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.h b/src/substrait/textplan/converter/BasePlanProtoVisitor.h index af1fb138..0eef3e00 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.h +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.h @@ -4,6 +4,7 @@ #include +#include "substrait/proto/algebra.pb.h" #include "substrait/proto/plan.pb.h" namespace io::substrait::textplan { @@ -84,6 +85,9 @@ class BasePlanProtoVisitor { const ::substrait::proto::Expression::ScalarFunction& function); virtual std::any visitWindowFunction( const ::substrait::proto::Expression::WindowFunction& function); + virtual std::any visitWindowRelFunction( + const ::substrait::proto::ConsistentPartitionWindowRel::WindowRelFunction& + function); virtual std::any visitIfThen( const 
::substrait::proto::Expression::IfThen& ifthen); virtual std::any visitSwitchExpression( @@ -140,6 +144,8 @@ class BasePlanProtoVisitor { virtual std::any visitSortField(const ::substrait::proto::SortField& sort); virtual std::any visitFieldReference( const ::substrait::proto::Expression::FieldReference& ref); + virtual std::any visitExpandField( + const ::substrait::proto::ExpandRel::ExpandField& field); virtual std::any visitReadRelation( const ::substrait::proto::ReadRel& relation); @@ -168,6 +174,12 @@ class BasePlanProtoVisitor { const ::substrait::proto::HashJoinRel& relation); virtual std::any visitMergeJoinRelation( const ::substrait::proto::MergeJoinRel& relation); + virtual std::any visitWindowRelation( + const ::substrait::proto::ConsistentPartitionWindowRel& relation); + virtual std::any visitExchangeRelation( + const ::substrait::proto::ExchangeRel& relation); + virtual std::any visitExpandRelation( + const ::substrait::proto::ExpandRel& relation); virtual std::any visitRelation(const ::substrait::proto::Rel& relation); virtual std::any visitRelationRoot( diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt index 41f3b9ce..5ee27539 100644 --- a/src/substrait/textplan/converter/CMakeLists.txt +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +include(../../../../third_party/datetime.cmake) + set(TEXTPLAN_SRCS InitialPlanProtoVisitor.cpp InitialPlanProtoVisitor.h @@ -11,6 +13,8 @@ set(TEXTPLAN_SRCS PlanPrinterVisitor.h LoadBinary.cpp LoadBinary.h + SaveBinary.cpp + SaveBinary.h ParseBinary.cpp ParseBinary.h) @@ -20,10 +24,14 @@ target_link_libraries( substrait_textplan_converter substrait_common substrait_expression + substrait_io substrait_proto symbol_table error_listener - date::date) + date::date + fmt::fmt-header-only + absl::status + absl::statusor) if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) diff --git 
a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index dfbdf38f..7fd3eb0a 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -76,6 +76,15 @@ void eraseInputs(::substrait::proto::Rel* relation) { relation->mutable_merge_join()->clear_left(); relation->mutable_merge_join()->clear_right(); break; + case ::substrait::proto::Rel::kWindow: + relation->mutable_window()->clear_input(); + break; + case ::substrait::proto::Rel::kExchange: + relation->mutable_exchange()->clear_input(); + break; + case ::substrait::proto::Rel::kExpand: + relation->mutable_expand()->clear_input(); + break; case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } @@ -112,6 +121,12 @@ ::google::protobuf::RepeatedField getOutputMapping( return relation.hash_join().common().emit().output_mapping(); case ::substrait::proto::Rel::kMergeJoin: return relation.merge_join().common().emit().output_mapping(); + case ::substrait::proto::Rel::kWindow: + return relation.window().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExchange: + return relation.exchange().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExpand: + return relation.expand().common().emit().output_mapping(); case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } @@ -521,6 +536,15 @@ void InitialPlanProtoVisitor::updateLocalSchema( relation.merge_join().left(), relation.merge_join().right()); break; + case ::substrait::proto::Rel::kWindow: + addFieldsToRelation(relationData, relation.window().input()); + break; + case ::substrait::proto::Rel::kExchange: + addFieldsToRelation(relationData, relation.exchange().input()); + break; + case ::substrait::proto::Rel::kExpand: + addFieldsToRelation(relationData, relation.expand().input()); + break; case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git 
a/src/substrait/textplan/converter/LoadBinary.cpp b/src/substrait/textplan/converter/LoadBinary.cpp index 787ab054..c5d9f4ce 100644 --- a/src/substrait/textplan/converter/LoadBinary.cpp +++ b/src/substrait/textplan/converter/LoadBinary.cpp @@ -2,20 +2,20 @@ #include "substrait/textplan/converter/LoadBinary.h" +#include +#include #include #include #include - #include #include -#include #include #include #include #include -#include "substrait/common/Exceptions.h" #include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" namespace io::substrait::textplan { @@ -39,24 +39,23 @@ class StringErrorCollector : public google::protobuf::io::ErrorCollector { } // namespace -std::string readFromFile(std::string_view msgPath) { +absl::StatusOr readFromFile(std::string_view msgPath) { std::ifstream textFile(std::string{msgPath}); if (textFile.fail()) { - auto currdir = std::filesystem::current_path().string(); - SUBSTRAIT_FAIL( - "Failed to open file {} when running in {}: {}", - msgPath, - currdir, - strerror(errno)); + auto currDir = std::filesystem::current_path().string(); + return absl::ErrnoToStatus( + errno, + fmt::format( + "Failed to open file {} when running in {}", msgPath, currDir)); } std::stringstream buffer; buffer << textFile.rdbuf(); return buffer.str(); } -PlanOrErrors loadFromJson(std::string_view json) { +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json) { if (json.empty()) { - return PlanOrErrors({"Provided JSON string was empty."}); + return absl::InternalError("Provided JSON string was empty."); } std::string_view usableJson = json; if (json[0] == '#') { @@ -71,21 +70,32 @@ PlanOrErrors loadFromJson(std::string_view json) { std::string{usableJson}, &plan); if (!status.ok()) { std::string msg{status.message()}; - return PlanOrErrors( - {fmt::format("Failed to parse Substrait JSON: {}", msg)}); + return absl::InternalError( + fmt::format("Failed to parse Substrait JSON: {}", msg)); } - return 
PlanOrErrors(plan); + return plan; } -PlanOrErrors loadFromText(const std::string& text) { +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text) { ::substrait::proto::Plan plan; ::google::protobuf::TextFormat::Parser parser; StringErrorCollector collector; parser.RecordErrorsTo(&collector); if (!parser.ParseFromString(text, &plan)) { - return PlanOrErrors(collector.getErrors()); + auto errors = collector.getErrors(); + return absl::InternalError(absl::StrJoin(errors, "")); + } + return plan; +} + +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes) { + ::substrait::proto::Plan plan; + if (!plan.ParseFromString(bytes)) { + return absl::InternalError("Failed to parse as a binary Substrait plan."); } - return PlanOrErrors(plan); + return plan; } } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/LoadBinary.h b/src/substrait/textplan/converter/LoadBinary.h index b0b73a90..a4d4e38f 100644 --- a/src/substrait/textplan/converter/LoadBinary.h +++ b/src/substrait/textplan/converter/LoadBinary.h @@ -2,12 +2,9 @@ #pragma once +#include #include #include -#include -#include - -#include "substrait/proto/plan.pb.h" namespace substrait::proto { class Plan; @@ -15,41 +12,21 @@ class Plan; namespace io::substrait::textplan { -// PlanOrErrors behaves similarly to abseil::StatusOr. -class PlanOrErrors { - public: - explicit PlanOrErrors(::substrait::proto::Plan plan) - : plan_(std::move(plan)){}; - explicit PlanOrErrors(std::vector errors) - : errors_(std::move(errors)){}; - - bool ok() { - return errors_.empty(); - } - - const ::substrait::proto::Plan& operator*() { - return plan_; - } - - const std::vector& errors() { - return errors_; - } - - private: - ::substrait::proto::Plan plan_; - std::vector errors_; -}; - // Read the contents of a file from disk. -// Throws an exception if file cannot be read. 
-std::string readFromFile(std::string_view msgPath); +absl::StatusOr readFromFile(std::string_view msgPath); // Reads a plan from a json-encoded text proto. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromJson(std::string_view json); +absl::StatusOr<::substrait::proto::Plan> loadFromJson(const std::string& json); // Reads a plan encoded as a text protobuf. // Returns a list of errors if the file cannot be parsed. -PlanOrErrors loadFromText(const std::string& text); +absl::StatusOr<::substrait::proto::Plan> loadFromProtoText( + const std::string& text); + +// Reads a plan serialized as a binary protobuf. +// Returns a list of errors if the file cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromBinary( + const std::string& bytes); } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index d63a6f56..57d5dcea 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -107,6 +107,24 @@ std::any PipelineVisitor::visitRelation( relationData->newPipelines.push_back(rightSymbol); break; } + case ::substrait::proto::Rel::kWindow: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.window().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kExchange: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.exchange().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kExpand: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.expand().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } case 
::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/src/substrait/textplan/converter/README.md b/src/substrait/textplan/converter/README.md index 607c0652..53a0e98e 100644 --- a/src/substrait/textplan/converter/README.md +++ b/src/substrait/textplan/converter/README.md @@ -1,7 +1,7 @@ # Using the Plan Converter Tool -The plan converter takes any number of JSON encoded Substrait plan files and converts them into the Substrait Text Plan -format. +The plan converter takes any number of Substrait plan files of any format and +converts them into the Substrait Text Plan format. ## Usage: ``` diff --git a/src/substrait/textplan/converter/SaveBinary.cpp b/src/substrait/textplan/converter/SaveBinary.cpp new file mode 100644 index 00000000..c8dd6c07 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.cpp @@ -0,0 +1,113 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/SaveBinary.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/converter/ParseBinary.h" + +namespace io::substrait::textplan { + +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename) { + int outputFileDescriptor = + creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE); + if (outputFileDescriptor == -1) { + return absl::ErrnoToStatus( + errno, + fmt::format("Failed to open file {} for writing", output_filename)); + } + auto stream = + new google::protobuf::io::FileOutputStream(outputFileDescriptor); + + if (!plan.SerializeToZeroCopyStream(stream)) { + return ::absl::UnknownError("Failed to write plan to stream."); + } + + if (!stream->Close()) { + return absl::AbortedError("Failed to close file descriptor."); + } + delete stream; + return absl::OkStatus(); +} + +absl::Status 
savePlanToJson(
+    const ::substrait::proto::Plan& plan,
+    std::string_view output_filename) {
+  // Serialize before touching the file system so that a failed conversion
+  // does not create or truncate the output file.
+  std::string output;
+  auto status = ::google::protobuf::util::MessageToJsonString(plan, &output);
+  if (!status.ok()) {
+    return absl::UnknownError("Failed to save plan as a JSON protobuf.");
+  }
+
+  std::ofstream stream(std::string{output_filename});
+  if (stream.fail()) {
+    return absl::UnavailableError(
+        fmt::format("Failed to open file {} for writing", output_filename));
+  }
+  stream << output;
+  stream.close();
+  if (stream.fail()) {
+    return absl::UnknownError("Failed to write the plan as a JSON protobuf.");
+  }
+  return absl::OkStatus();
+}
+
+absl::Status savePlanToText(
+    const ::substrait::proto::Plan& plan,
+    std::string_view output_filename) {
+  // Convert first: if the plan cannot be converted there is nothing to
+  // write, and an existing output file is left untouched.
+  auto result = parseBinaryPlan(plan);
+  auto errors = result.getAllErrors();
+  if (!errors.empty()) {
+    return absl::UnknownError(absl::StrJoin(errors, ""));
+  }
+
+  std::ofstream stream(std::string{output_filename});
+  if (stream.fail()) {
+    return absl::UnavailableError(
+        fmt::format("Failed to open file {} for writing", output_filename));
+  }
+  stream << SymbolTablePrinter::outputToText(result.getSymbolTable());
+  stream.close();
+  if (stream.fail()) {
+    return absl::UnknownError("Failed to write the plan as text.");
+  }
+  return absl::OkStatus();
+}
+
+absl::Status savePlanToProtoText(
+    const ::substrait::proto::Plan& plan,
+    std::string_view output_filename) {
+  int outputFileDescriptor =
+      creat(std::string{output_filename}.c_str(), S_IREAD | S_IWRITE);
+  if (outputFileDescriptor == -1) {
+    return absl::ErrnoToStatus(
+        errno,
+        fmt::format("Failed to open file {} for writing", output_filename));
+  }
+  auto stream =
+      new google::protobuf::io::FileOutputStream(outputFileDescriptor);
+
+  if (!::google::protobuf::TextFormat::Print(plan, stream)) {
+    // Release the descriptor and the stream before bailing out.  The result
+    // of Close() is ignored because the print failure is what gets reported.
+    stream->Close();
+    delete stream;
+    return absl::UnknownError("Failed to save plan as a text protobuf.");
+  }
+
+  if (!stream->Close()) {
+    // The stream object is still owned here and must not outlive the call.
+    delete stream;
+    return absl::AbortedError("Failed to close file 
descriptor."); + } + delete stream; + return absl::OkStatus(); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/SaveBinary.h b/src/substrait/textplan/converter/SaveBinary.h new file mode 100644 index 00000000..d9158773 --- /dev/null +++ b/src/substrait/textplan/converter/SaveBinary.h @@ -0,0 +1,33 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include "absl/status/status.h" + +namespace substrait::proto { +class Plan; +} + +namespace io::substrait::textplan { + +// Serializes a plan to disk as a binary protobuf. +absl::Status savePlanToBinary( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a JSON-encoded protobuf. +absl::Status savePlanToJson( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Calls the converter to store a plan on disk as a text-based substrait plan. +absl::Status savePlanToText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +// Serializes a plan to disk as a text-encoded protobuf. 
+absl::Status savePlanToProtoText( + const ::substrait::proto::Plan& plan, + std::string_view output_filename); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/Tool.cpp b/src/substrait/textplan/converter/Tool.cpp index 9c3d652b..bdcc707a 100644 --- a/src/substrait/textplan/converter/Tool.cpp +++ b/src/substrait/textplan/converter/Tool.cpp @@ -6,6 +6,7 @@ #include +#include "substrait/common/Io.h" #include "substrait/textplan/SymbolTablePrinter.h" #include "substrait/textplan/converter/LoadBinary.h" #include "substrait/textplan/converter/ParseBinary.h" @@ -13,14 +14,10 @@ namespace io::substrait::textplan { namespace { -void convertJsonToText(const char* filename) { - std::string json = readFromFile(filename); - auto planOrError = loadFromJson(json); +void convertPlanToText(const char* filename) { + auto planOrError = loadPlan(filename); if (!planOrError.ok()) { - std::cerr << "An error occurred while reading: " << filename << std::endl; - for (const auto& err : planOrError.errors()) { - std::cerr << err << std::endl; - } + std::cerr << planOrError.status() << std::endl; return; } @@ -46,7 +43,7 @@ int main(int argc, char* argv[]) { #ifdef _WIN32 for (int currArg = 1; currArg < argc; currArg++) { printf("===== %s =====\n", argv[currArg]); - io::substrait::textplan::convertJsonToText(argv[currArg]); + io::substrait::textplan::convertPlanToText(argv[currArg]); } #else for (int currArg = 1; currArg < argc; currArg++) { @@ -54,7 +51,7 @@ int main(int argc, char* argv[]) { glob(argv[currArg], GLOB_TILDE, nullptr, &globResult); for (size_t i = 0; i < globResult.gl_pathc; i++) { printf("===== %s =====\n", globResult.gl_pathv[i]); - io::substrait::textplan::convertJsonToText(globResult.gl_pathv[i]); + io::substrait::textplan::convertPlanToText(globResult.gl_pathv[i]); } } #endif diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp 
b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index cde3c92a..3b0ae97f 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -597,9 +597,10 @@ std::vector getTestCases() { TEST_P(BinaryToTextPlanConverterTestFixture, Parse) { auto [name, input, matcher] = GetParam(); - auto planOrError = loadFromText(input); + auto planOrError = loadFromProtoText(input); if (!planOrError.ok()) { - ParseResult result(SymbolTable(), planOrError.errors(), {}); + ParseResult result( + SymbolTable(), {std::string(planOrError.status().message())}, {}); ASSERT_THAT(result, matcher); return; } @@ -627,13 +628,15 @@ INSTANTIATE_TEST_SUITE_P( class BinaryToTextPlanConversionTest : public ::testing::Test {}; TEST_F(BinaryToTextPlanConversionTest, FullSample) { - std::string json = readFromFile("data/q6_first_stage.json"); - auto planOrError = loadFromJson(json); + auto jsonOrError = readFromFile("data/q6_first_stage.json"); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); ASSERT_TRUE(planOrError.ok()); auto plan = *planOrError; EXPECT_THAT(plan.extensions_size(), ::testing::Eq(7)); - std::string expectedOutput = readFromFile("data/q6_first_stage.golden.splan"); + auto expectedOutputOrError = readFromFile("data/q6_first_stage.golden.splan"); + ASSERT_TRUE(expectedOutputOrError.ok()); auto result = parseBinaryPlan(plan); auto symbols = result.getSymbolTable().getSymbols(); @@ -668,7 +671,7 @@ TEST_F(BinaryToTextPlanConversionTest, FullSample) { SymbolType::kSource, SymbolType::kSchema, }), - WhenSerialized(EqSquashingWhitespace(expectedOutput)))) + WhenSerialized(EqSquashingWhitespace(*expectedOutputOrError)))) << result.getSymbolTable().toDebugString(); } diff --git a/src/substrait/textplan/parser/CMakeLists.txt b/src/substrait/textplan/parser/CMakeLists.txt index 498d572c..fcc01413 100644 --- 
a/src/substrait/textplan/parser/CMakeLists.txt +++ b/src/substrait/textplan/parser/CMakeLists.txt @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +include(../../../../third_party/datetime.cmake) + add_subdirectory(grammar) add_library( @@ -12,6 +14,8 @@ add_library( SubstraitPlanRelationVisitor.h SubstraitPlanTypeVisitor.cpp SubstraitPlanTypeVisitor.h + LoadText.cpp + LoadText.h ParseText.cpp ParseText.h SubstraitParserErrorListener.cpp) @@ -26,7 +30,9 @@ target_link_libraries( textplan_grammar fmt::fmt-header-only date::date - date::date-tz) + date::date-tz + absl::status + absl::statusor) add_executable(planparser Tool.cpp) diff --git a/src/substrait/textplan/parser/LoadText.cpp b/src/substrait/textplan/parser/LoadText.cpp new file mode 100644 index 00000000..c76a7b31 --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.cpp @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/parser/LoadText.h" + +#include + +#include "substrait/proto/plan.pb.h" +#include "substrait/textplan/StringManipulation.h" +#include "substrait/textplan/SymbolTablePrinter.h" +#include "substrait/textplan/parser/ParseText.h" + +namespace io::substrait::textplan { + +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text) { + auto stream = loadTextString(text); + auto parseResult = io::substrait::textplan::parseStream(stream); + if (!parseResult.successful()) { + auto errors = parseResult.getAllErrors(); + return absl::UnknownError(absl::StrJoin(errors, "")); + } + + return SymbolTablePrinter::outputToBinaryPlan(parseResult.getSymbolTable()); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/LoadText.h b/src/substrait/textplan/parser/LoadText.h new file mode 100644 index 00000000..3b2f900b --- /dev/null +++ b/src/substrait/textplan/parser/LoadText.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +namespace substrait::proto { +class Plan; +} 
+ +namespace io::substrait::textplan { + +// Reads a plan encoded as a text protobuf. +// Returns a list of errors if the text cannot be parsed. +absl::StatusOr<::substrait::proto::Plan> loadFromText(const std::string& text); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/ParseText.cpp b/src/substrait/textplan/parser/ParseText.cpp index 3a19d3cc..a4ba32fc 100644 --- a/src/substrait/textplan/parser/ParseText.cpp +++ b/src/substrait/textplan/parser/ParseText.cpp @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: Apache-2.0 */ -#include "ParseText.h" +#include "substrait/textplan/parser/ParseText.h" #include #include diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index a8273f12..4a6d5567 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -974,14 +974,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( case SourceType::kNamedTable: { auto* source = parentRelationData->relation.mutable_read()->mutable_named_table(); - for (const auto& sym : *symbolTable_) { - if (sym.type != SymbolType::kSourceDetail) { - continue; - } - if (sym.location != symbol->location) { + for (const auto& sym : + symbolTable_->lookupSymbolsByLocation(symbol->location)) { + if (sym->type != SymbolType::kSourceDetail) { continue; } - source->add_names(sym.name); + source->add_names(sym->name); } break; } diff --git a/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp index 403caa22..5b42d817 100644 --- a/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp @@ -109,11 +109,15 @@ ::substrait::proto::Type SubstraitPlanTypeVisitor::typeToProto( } case TypeKind::kVarchar: { auto varChar = - reinterpret_cast(&decodedType); + 
reinterpret_cast(&decodedType); if (varChar == nullptr) { break; } try { + if (!varChar->length()->isInteger()) { + errorListener_->addError(ctx->getStart(), "Missing varchar length."); + break; + } int32_t length = std::stoi(varChar->length()->value()); type.mutable_varchar()->set_length(length); } catch (...) { diff --git a/src/substrait/textplan/parser/grammar/CMakeLists.txt b/src/substrait/textplan/parser/grammar/CMakeLists.txt index 02430e0c..e04ea81a 100644 --- a/src/substrait/textplan/parser/grammar/CMakeLists.txt +++ b/src/substrait/textplan/parser/grammar/CMakeLists.txt @@ -11,15 +11,15 @@ set(GRAMMAR_DIR ${CMAKE_BINARY_DIR}/src/substrait/textplan/parser/grammar) # using /MD flag for antlr4_runtime (for Visual C++ compilers only) set(ANTLR4_WITH_STATIC_CRT OFF) -set(ANTLR4_TAG 4.12.0) +set(ANTLR4_TAG 4.13.0) set(ANTLR4_ZIP_REPOSITORY - https://github.com/antlr/antlr4/archive/refs/tags/4.12.0.zip) + https://github.com/antlr/antlr4/archive/refs/tags/${ANTLR4_TAG}.zip) include(ExternalAntlr4Cpp) include_directories(${ANTLR4_INCLUDE_DIRS}) -file(DOWNLOAD https://www.antlr.org/download/antlr-4.12.0-complete.jar +file(DOWNLOAD https://www.antlr.org/download/antlr-4.13.0-complete.jar "${GRAMMAR_DIR}/antlr.jar") set(ANTLR_EXECUTABLE "${GRAMMAR_DIR}/antlr.jar") find_package(ANTLR REQUIRED) diff --git a/src/substrait/textplan/tests/CMakeLists.txt b/src/substrait/textplan/tests/CMakeLists.txt index 24be6792..116ec4e7 100644 --- a/src/substrait/textplan/tests/CMakeLists.txt +++ b/src/substrait/textplan/tests/CMakeLists.txt @@ -19,50 +19,38 @@ add_test_case( gtest gtest_main) -option(SUBSTRAIT_CPP_ROUNDTRIP_TESTING - "Enable substrait-cpp textplan roundtrip tests." 
OFF) - -if(${SUBSTRAIT_CPP_ROUNDTRIP_TESTING}) - add_test_case( - round_trip_test - SOURCES - RoundtripTest.cpp - EXTRA_LINK_LIBS - substrait_textplan_converter - substrait_textplan_loader - substrait_textplan_normalizer - substrait_common - substrait_proto - parse_result_matchers - protobuf-matchers - fmt::fmt-header-only - gmock - gtest - gtest_main) - - cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH TEXTPLAN_SOURCE_DIR) - - add_custom_command( - TARGET round_trip_test - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E echo "Copying unit test data.." - COMMAND ${CMAKE_COMMAND} -E make_directory - "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" - COMMAND - ${CMAKE_COMMAND} -E copy - "${TEXTPLAN_SOURCE_DIR}/converter/data/q6_first_stage.json" - "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json" - COMMAND ${CMAKE_COMMAND} -E copy "${TEXTPLAN_SOURCE_DIR}/data/*.json" - "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/") - - message( - STATUS - "test data will be here: ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data") -else() - - message( - STATUS - "Round trip testing is turned off. Add SUBSTRAIT_CPP_ROUNDTRIP_TESTING=on to enable." - ) +add_test_case( + substrait_textplan_round_trip_test + SOURCES + RoundtripTest.cpp + EXTRA_LINK_LIBS + substrait_textplan_converter + substrait_textplan_loader + substrait_textplan_normalizer + substrait_common + substrait_proto + parse_result_matchers + protobuf-matchers + fmt::fmt-header-only + gmock + gtest + gtest_main) -endif() +cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH TEXTPLAN_SOURCE_DIR) + +add_custom_command( + TARGET substrait_textplan_round_trip_test + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E echo "Copying unit test data.." 
+ COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" + COMMAND + ${CMAKE_COMMAND} -E copy + "${TEXTPLAN_SOURCE_DIR}/converter/data/q6_first_stage.json" + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json" + COMMAND ${CMAKE_COMMAND} -E copy "${TEXTPLAN_SOURCE_DIR}/data/*.json" + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/") + +message( + STATUS "test data will be here: ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" +) diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 3b22071c..42f447c0 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -19,8 +19,7 @@ #include "substrait/textplan/tests/ParseResultMatchers.h" using ::protobuf_matchers::EqualsProto; -using ::protobuf_matchers::IgnoringFieldPaths; -using ::protobuf_matchers::Partially; +using ::protobuf_matchers::IgnoringFields; using ::testing::AllOf; namespace io::substrait::textplan { @@ -70,12 +69,12 @@ std::vector getTestCases() { TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { auto filename = GetParam(); - std::string json = readFromFile(filename); - auto planOrErrors = loadFromJson(json); - std::vector errors = planOrErrors.errors(); - ASSERT_THAT(errors, ::testing::ElementsAre()); + auto jsonOrError = readFromFile(filename); + ASSERT_TRUE(jsonOrError.ok()); + auto planOrError = loadFromJson(*jsonOrError); + ASSERT_TRUE(planOrError.ok()); - auto plan = *planOrErrors; + auto plan = *planOrError; auto textResult = parseBinaryPlan(plan); auto textSymbols = textResult.getSymbolTable().getSymbols(); @@ -98,7 +97,11 @@ TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { ASSERT_THAT( result, ::testing::AllOf( - ParsesOk(), HasErrors({}), AsBinaryPlan(EqualsProto(normalizedPlan)))) + ParsesOk(), + HasErrors({}), + AsBinaryPlan(IgnoringFields( + {"substrait.proto.RelCommon.Emit.output_mapping"}, + EqualsProto(normalizedPlan))))) << 
std::endl << "Intermediate result:" << std::endl << addLineNumbers(outputText) << std::endl @@ -115,8 +118,7 @@ INSTANTIATE_TEST_SUITE_P( if (lastSlash != std::string::npos) { identifier = identifier.substr(lastSlash); } - if (identifier.length() > 5 && - identifier.substr(identifier.length() - 5) == ".json") { + if (endsWith(identifier, ".json")) { identifier = identifier.substr(0, identifier.length() - 5); } diff --git a/src/substrait/type/Type.cpp b/src/substrait/type/Type.cpp index 1fa21115..5a0e3cff 100644 --- a/src/substrait/type/Type.cpp +++ b/src/substrait/type/Type.cpp @@ -197,7 +197,12 @@ ParameterizedTypePtr ParameterizedType::decode( const auto& leftAngleBracketPos = matchingType.find('<'); if (leftAngleBracketPos == std::string::npos) { - bool nullable = matchingType.back() == '?'; + bool nullable; + if (matchingType.empty()) { + nullable = false; + } else { + nullable = matchingType.back() == '?'; + } // deal with type and with a question mask like "i32?". const auto& baseType = nullable ? 
matchingType = matchingType.substr(0, questionMaskPos) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 7b8d8e60..2124c76d 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -5,8 +5,7 @@ if(NOT ${ABSL_INCLUDED_WITH_PROTOBUF}) add_subdirectory(abseil-cpp) endif() -set(BUILD_TZ_LIB ON) -add_subdirectory(datetime) +include(datetime.cmake) add_subdirectory(fmt) add_subdirectory(googletest) diff --git a/third_party/datetime b/third_party/datetime deleted file mode 160000 index cc4685a2..00000000 --- a/third_party/datetime +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cc4685a21e4a4fdae707ad1233c61bbaff241f93 diff --git a/third_party/datetime.cmake b/third_party/datetime.cmake new file mode 100644 index 00000000..f96aa06c --- /dev/null +++ b/third_party/datetime.cmake @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +include_guard(GLOBAL) + +set (BUILD_TZ_LIB ON CACHE BOOL "timezone library is a dependency" FORCE) +include(FetchContent) +FetchContent_Declare(date_src + GIT_REPOSITORY https://github.com/HowardHinnant/date.git + GIT_TAG v3.0.1 + ) +FetchContent_MakeAvailable(date_src) diff --git a/third_party/substrait b/third_party/substrait index 07e4feb5..31b99906 160000 --- a/third_party/substrait +++ b/third_party/substrait @@ -1 +1 @@ -Subproject commit 07e4feb5983478e7d0d95dc1d9b5e176685dbdc3 +Subproject commit 31b999060a6e014717f9ae3e6716986ad3066aaf