diff --git a/src/substrait/common/Io.cpp b/src/substrait/common/Io.cpp
index 475c49f1..7055160c 100644
--- a/src/substrait/common/Io.cpp
+++ b/src/substrait/common/Io.cpp
@@ -2,6 +2,7 @@
 
 #include "substrait/common/Io.h"
 
+#include <regex>
 #include <string_view>
 
 #include "substrait/textplan/converter/LoadBinary.h"
@@ -10,35 +11,58 @@
 
 namespace io::substrait {
 
+namespace {
+
+// Matches any of the quoted field names that appear in a JSON-encoded plan.
+const std::regex kIsJson(R"(("extensionUris"|"extensions"|"relations"))");
+// Matches a line opening one of the plan's repeated messages in protobuf text
+// format, e.g. "relations {".
+const std::regex kIsProtoText(
+    R"((^|\n) *(relations|extensions|extension_uris) *\{)");
+// Matches a line starting with one of the keywords used to recognize the
+// text plan format.
+const std::regex kIsText(
+    R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)");
+
+// Guesses the encoding of the provided content by testing the JSON, proto
+// text, and text plan patterns in that order.  Content matching none of them
+// is assumed to be binary.
+PlanFileEncoding detectEncoding(std::string_view content) {
+  if (std::regex_search(content.begin(), content.end(), kIsJson)) {
+    return kJson;
+  }
+  if (std::regex_search(content.begin(), content.end(), kIsProtoText)) {
+    return kProtoText;
+  }
+  if (std::regex_search(content.begin(), content.end(), kIsText)) {
+    return kText;
+  }
+  return kBinary;
+}
+
+} // namespace
+
 absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding(
     std::string_view input_filename) {
-  // TODO -- Add logic to detect the file type before trying to load it.
   auto contentOrError = textplan::readFromFile(input_filename);
   if (!contentOrError.ok()) {
     return contentOrError.status();
   }
 
-  auto planOrError = textplan::loadFromJson(*contentOrError);
-  if (planOrError.ok()) {
-    return *planOrError;
-  }
-
-  planOrError = textplan::loadFromProtoText(*contentOrError);
-  if (planOrError.ok()) {
-    return *planOrError;
-  }
-
-  planOrError = textplan::loadFromText(*contentOrError);
-  if (planOrError.ok()) {
-    return *planOrError;
-  }
-
-  planOrError = textplan::loadFromBinary(*contentOrError);
-  if (planOrError.ok()) {
-    return *planOrError;
+  // Load with the single decoder selected by inspecting the file contents.
+  switch (detectEncoding(*contentOrError)) {
+    case kBinary:
+      return textplan::loadFromBinary(*contentOrError);
+    case kJson:
+      return textplan::loadFromJson(*contentOrError);
+    case kProtoText:
+      return textplan::loadFromProtoText(*contentOrError);
+    case kText:
+      return textplan::loadFromText(*contentOrError);
   }
-
-  return planOrError.status();
+  // Every enumerator returns above; this only guards against a value outside
+  // the enumeration so control never runs off the end of the function.
+  return absl::InternalError("Unexpected plan file encoding detected.");
 }
 
 absl::Status savePlan(
diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp
index 61c7631f..eec65562 100644
--- a/src/substrait/common/tests/IoTest.cpp
+++ b/src/substrait/common/tests/IoTest.cpp
@@ -5,6 +5,7 @@
 #include <gmock/gmock-matchers.h>
 #include <gtest/gtest.h>
 #include <protobuf-matchers/protocol-buffer-matchers.h>
+#include <unistd.h>
 
 using ::protobuf_matchers::EqualsProto;
 using ::protobuf_matchers::Partially;
@@ -40,10 +41,25 @@ TEST_F(IoTest, LoadMissingFile) {
 }
 
 class SaveAndLoadTestFixture
-    : public ::testing::TestWithParam<PlanFileEncoding> {};
+    : public ::testing::TestWithParam<PlanFileEncoding> {
+ public:
+  ~SaveAndLoadTestFixture() override {
+    for (const auto& filename : testFiles_) {
+      unlink(filename.c_str());
+    }
+  }
+
+  void registerCleanup(const char* filename) {
+    testFiles_.emplace_back(filename);
+  }
+
+ private:
+  std::vector<std::string> testFiles_;
+};
 
 TEST_P(SaveAndLoadTestFixture, SaveAndLoad) {
   auto tempFilename = std::tmpnam(nullptr);
+  registerCleanup(tempFilename);
   PlanFileEncoding encoding = GetParam();
 
   ::substrait::proto::Plan plan;
@@ -51,10 +67,10 @@ TEST_P(SaveAndLoadTestFixture, SaveAndLoad) {
   auto read = root->mutable_input()->mutable_read();
   read->mutable_named_table()->add_names("table_name");
   auto status = ::io::substrait::savePlan(plan, tempFilename, encoding);
-  ASSERT_TRUE(status.ok()) << status;
+  ASSERT_TRUE(status.ok()) << "Save failed.\n" << status;
 
   auto result = ::io::substrait::loadPlanWithUnknownEncoding(tempFilename);
-  ASSERT_TRUE(result.ok()) << result.status();
+  ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status();
   ASSERT_THAT(
       *result,
       Partially(EqualsProto<::substrait::proto::Plan>(