diff --git a/include/substrait/common/Io.h b/include/substrait/common/Io.h index 4227f7a3..d090cf75 100644 --- a/include/substrait/common/Io.h +++ b/include/substrait/common/Io.h @@ -16,6 +16,19 @@ enum PlanFileEncoding { kText = 3, }; +constexpr const char* PlanFileEncodingToString(PlanFileEncoding e) noexcept { + switch (e) { + case PlanFileEncoding::kBinary: + return "kBinary"; + case PlanFileEncoding::kJson: + return "kJson"; + case PlanFileEncoding::kProtoText: + return "kProtoText"; + case PlanFileEncoding::kText: + return "kText"; + } +} + // Loads a Substrait plan consisting of any encoding type from the given file. absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding( std::string_view input_filename); diff --git a/src/substrait/common/tests/IoTest.cpp b/src/substrait/common/tests/IoTest.cpp index 4916fbfa..c55b2dab 100644 --- a/src/substrait/common/tests/IoTest.cpp +++ b/src/substrait/common/tests/IoTest.cpp @@ -9,6 +9,8 @@ using ::protobuf_matchers::EqualsProto; using ::protobuf_matchers::Partially; +namespace io::substrait { + class IoTest : public ::testing::Test {}; TEST_F(IoTest, LoadMissingFile) { @@ -20,78 +22,21 @@ TEST_F(IoTest, LoadMissingFile) { ::testing::ContainsRegex("Failed to open file non-existent-file")); } -TEST_F(IoTest, SaveAndLoadBinary) { - ::substrait::proto::Plan plan; - auto root = plan.add_relations()->mutable_root(); - auto read = root->mutable_input()->mutable_read(); - read->mutable_named_table()->add_names("table_name"); - auto status = - ::io::substrait::savePlan(plan, "rwtest.plan", io::substrait::kBinary); - ASSERT_TRUE(status.ok()) << status; +class SaveAndLoadTestFixture + : public ::testing::TestWithParam {}; - auto result = ::io::substrait::loadPlanWithUnknownEncoding("rwtest.plan"); - ASSERT_TRUE(result.ok()) << result.status(); - ASSERT_THAT( - *result, - Partially(EqualsProto<::substrait::proto::Plan>( - R"(relations { - root { - input { - read { - common { - direct { - } - } - named_table { - names: "table_name" - } - } - } - } - })"))); -} - -TEST_F(IoTest, SaveAndLoadJson) { - ::substrait::proto::Plan plan; - auto root = plan.add_relations()->mutable_root(); - auto read = root->mutable_input()->mutable_read(); - read->mutable_named_table()->add_names("table_name"); - auto status = - ::io::substrait::savePlan(plan, "rwtest.json", io::substrait::kJson); - ASSERT_TRUE(status.ok()) << status; - - auto result = ::io::substrait::loadPlanWithUnknownEncoding("rwtest.json"); - ASSERT_TRUE(result.ok()) << result.status(); - ASSERT_THAT( - *result, - Partially(EqualsProto<::substrait::proto::Plan>( - R"(relations { - root { - input { - read { - common { - direct { - } - } - named_table { - names: "table_name" - } - } - } - } - })"))); -} +TEST_P(SaveAndLoadTestFixture, SaveAndLoad) { + auto tempFilename = std::tmpnam(nullptr); + PlanFileEncoding encoding = GetParam(); -TEST_F(IoTest, SaveAndLoadProtoText) { ::substrait::proto::Plan plan; auto root = plan.add_relations()->mutable_root(); auto read = root->mutable_input()->mutable_read(); read->mutable_named_table()->add_names("table_name"); - auto status = ::io::substrait::savePlan( - plan, "rwtest.protobuf", io::substrait::kProtoText); + auto status = ::io::substrait::savePlan(plan, tempFilename, encoding); ASSERT_TRUE(status.ok()) << status; - auto result = ::io::substrait::loadPlanWithUnknownEncoding("rwtest.protobuf"); + auto result = ::io::substrait::loadPlanWithUnknownEncoding(tempFilename); ASSERT_TRUE(result.ok()) << result.status(); ASSERT_THAT( *result, @@ -113,33 +58,12 @@ TEST_F(IoTest, SaveAndLoadProtoText) { })"))); } -TEST_F(IoTest, SaveAndLoadText) { - ::substrait::proto::Plan plan; - auto root = plan.add_relations()->mutable_root(); - auto read = root->mutable_input()->mutable_read(); - read->mutable_named_table()->add_names("table_name"); - auto status = - ::io::substrait::savePlan(plan, "rwtest.splan", io::substrait::kText); - ASSERT_TRUE(status.ok()) << status; +INSTANTIATE_TEST_SUITE_P( + SaveAndLoadTests, + SaveAndLoadTestFixture, + testing::Values(kBinary, kJson, kProtoText, kText), + [](const testing::TestParamInfo& info) { + return PlanFileEncodingToString(info.param); + }); - auto result = ::io::substrait::loadPlanWithUnknownEncoding("rwtest.splan"); - ASSERT_TRUE(result.ok()) << result.status(); - ASSERT_THAT( - *result, - Partially(EqualsProto<::substrait::proto::Plan>( - R"(relations { - root { - input { - read { - common { - direct { - } - } - named_table { - names: "table_name" - } - } - } - } - })"))); -} +} // namespace io::substrait