Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add api for loading plans of all types #80

Merged
merged 11 commits into from
Aug 26, 2023
Merged
30 changes: 30 additions & 0 deletions include/substrait/common/Io.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <string_view>

#include "absl/status/statusor.h"
#include "substrait/proto/plan.pb.h"

namespace io::substrait {

enum PlanFileEncoding {
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
kBinary = 0,
kJson = 1,
kProtoText = 2,
kText = 3,
};

// Loads a Substrait plan consisting of any encoding type from the given file.
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding(
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
std::string_view input_filename);

// Writes the provided plan to the specified location with the specified
// encoding type.
[[maybe_unused]] absl::Status savePlan(
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
const ::substrait::proto::Plan& plan,
std::string_view output_filename,
PlanFileEncoding encoding);

} // namespace io::substrait
13 changes: 12 additions & 1 deletion src/substrait/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# SPDX-License-Identifier: Apache-2.0

add_library(substrait_common Exceptions.cpp)

target_link_libraries(substrait_common fmt::fmt-header-only)

add_library(substrait_io Io.cpp)
add_dependencies(
substrait_io
substrait_proto
substrait_textplan_converter
substrait_textplan_loader
fmt::fmt-header-only
absl::status
absl::statusor)
target_link_libraries(substrait_io substrait_proto substrait_textplan_converter
substrait_textplan_loader absl::status absl::statusor)

if(${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
endif()
77 changes: 77 additions & 0 deletions src/substrait/common/Io.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <regex>
#include <string_view>

#include "substrait/proto/plan.pb.h"
#include "substrait/textplan/converter/LoadBinary.h"
#include "substrait/textplan/converter/SaveBinary.h"
#include "substrait/textplan/parser/LoadText.h"

namespace io::substrait {

namespace {

const std::regex kIsJson(R"(("extensionUris"|"extensions"|"relations"))");
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
const std::regex kIsProtoText(
R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))");
const std::regex kIsText(
R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)");

PlanFileEncoding detectEncoding(std::string_view content) {
if (std::regex_search(content.begin(), content.end(), kIsJson)) {
return kJson;
}
if (std::regex_search(content.begin(), content.end(), kIsProtoText)) {
return kProtoText;
}
if (std::regex_search(content.begin(), content.end(), kIsText)) {
return kText;
}
return kBinary;
}

} // namespace

absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding(
std::string_view input_filename) {
auto contentOrError = textplan::readFromFile(input_filename.data());
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
if (!contentOrError.ok()) {
return contentOrError.status();
}

auto encoding = detectEncoding(*contentOrError);
absl::StatusOr<::substrait::proto::Plan> planOrError;
switch (encoding) {
case kBinary:
return textplan::loadFromBinary(*contentOrError);
case kJson:
return textplan::loadFromJson(*contentOrError);
case kProtoText:
return textplan::loadFromProtoText(*contentOrError);
case kText:
return textplan::loadFromText(*contentOrError);
}
return absl::UnimplementedError("Unexpected encoding requested.");
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
}

absl::Status savePlan(
const ::substrait::proto::Plan& plan,
std::string_view output_filename,
PlanFileEncoding encoding) {
switch (encoding) {
case kBinary:
return textplan::savePlanToBinary(plan, output_filename);
case kJson:
return textplan::savePlanToJson(plan, output_filename);
case kProtoText:
return textplan::savePlanToProtoText(plan, output_filename);
case kText:
return textplan::savePlanToText(plan, output_filename);
}
return absl::UnimplementedError("Unexpected encoding requested.");
}

} // namespace io::substrait
10 changes: 10 additions & 0 deletions src/substrait/common/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ add_test_case(
substrait_common
gtest
gtest_main)

add_test_case(
substrait_io_test
SOURCES
IoTest.cpp
EXTRA_LINK_LIBS
substrait_io
protobuf-matchers
gtest
gtest_main)
104 changes: 104 additions & 0 deletions src/substrait/common/tests/IoTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <protobuf-matchers/protocol-buffer-matchers.h>
#include <unistd.h>

using ::protobuf_matchers::EqualsProto;
using ::protobuf_matchers::Partially;

namespace io::substrait {

namespace {

constexpr const char* planFileEncodingToString(PlanFileEncoding e) noexcept {
switch (e) {
case PlanFileEncoding::kBinary:
return "kBinary";
case PlanFileEncoding::kJson:
return "kJson";
case PlanFileEncoding::kProtoText:
return "kProtoText";
case PlanFileEncoding::kText:
return "kText";
}
return "IMPOSSIBLE";
}

} // namespace

class IoTest : public ::testing::Test {};

TEST_F(IoTest, LoadMissingFile) {
auto result =
::io::substrait::loadPlanWithUnknownEncoding("non-existent-file");
ASSERT_FALSE(result.ok());
ASSERT_THAT(
result.status().message(),
::testing::ContainsRegex("Failed to open file non-existent-file"));
}

class SaveAndLoadTestFixture
: public ::testing::TestWithParam<PlanFileEncoding> {
public:
~SaveAndLoadTestFixture() override {
for (const auto& filename : testFiles_) {
unlink(filename.c_str());
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
}
}

void registerCleanup(const char* filename) {
testFiles_.emplace_back(filename);
}

private:
std::vector<std::string> testFiles_;
};

TEST_P(SaveAndLoadTestFixture, SaveAndLoad) {
auto tempFilename = std::tmpnam(nullptr);
registerCleanup(tempFilename);
PlanFileEncoding encoding = GetParam();

::substrait::proto::Plan plan;
auto root = plan.add_relations()->mutable_root();
auto read = root->mutable_input()->mutable_read();
read->mutable_common()->mutable_direct();
read->mutable_named_table()->add_names("table_name");
auto status = ::io::substrait::savePlan(plan, tempFilename, encoding);
ASSERT_TRUE(status.ok()) << "Save failed.\n" << status;

auto result = ::io::substrait::loadPlanWithUnknownEncoding(tempFilename);
ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status();
ASSERT_THAT(
*result,
Partially(EqualsProto<::substrait::proto::Plan>(
R"(relations {
root {
input {
read {
common {
direct {
}
}
named_table {
names: "table_name"
}
}
}
}
})")));
}

INSTANTIATE_TEST_SUITE_P(
SaveAndLoadTests,
SaveAndLoadTestFixture,
testing::Values(kBinary, kJson, kProtoText, kText),
[](const testing::TestParamInfo<SaveAndLoadTestFixture::ParamType>& info) {
return planFileEncodingToString(info.param);
});

} // namespace io::substrait
19 changes: 17 additions & 2 deletions src/substrait/textplan/StringManipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,33 @@

#include "StringManipulation.h"

#include <numeric>
#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

// Yields true if the string 'haystack' starts with the string 'needle'.
bool startsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(0, needle.size()) == needle;
}

// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(haystack.size() - needle.size(), needle.size()) == needle;
}

std::string joinLines(
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::string> lines,
std::string_view separator) {
auto concatWithSeparator = [separator](std::string a, const std::string& b) {
return std::move(a) + std::string(separator) + b;
};

auto result = std::accumulate(
std::next(lines.begin()), lines.end(), lines[0], concatWithSeparator);
return result;
}

} // namespace io::substrait::textplan
7 changes: 7 additions & 0 deletions src/substrait/textplan/StringManipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

#pragma once

#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

Expand All @@ -12,4 +14,9 @@ bool startsWith(std::string_view haystack, std::string_view needle);
// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle);

// Joins a vector of strings into a single string separated by separator.
std::string joinLines(
std::vector<std::string> lines,
std::string_view separator = "\n");

} // namespace io::substrait::textplan
25 changes: 21 additions & 4 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) {
auto subtype = ANY_CAST(SourceType, info.subtype);
switch (subtype) {
case SourceType::kNamedTable: {
auto table =
ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob);
if (info.blob.has_value()) {
// We are using the proto as is in lieu of a disciplined structure.
auto table = ANY_CAST(
const ::substrait::proto::ReadRel_NamedTable*, info.blob);
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;
break;
}
// We are using the new style data structure.
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
for (const auto& sym :
symbolTable.lookupSymbolsByLocation(info.location)) {
if (sym->type == SymbolType::kSourceDetail) {
text << " \"" << sym->name << "\",\n";
}
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;

break;
}
case SourceType::kLocalFiles: {
Expand Down
8 changes: 7 additions & 1 deletion src/substrait/textplan/converter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ set(TEXTPLAN_SRCS
PlanPrinterVisitor.h
LoadBinary.cpp
LoadBinary.h
SaveBinary.cpp
SaveBinary.h
ParseBinary.cpp
ParseBinary.h)

Expand All @@ -20,10 +22,14 @@ target_link_libraries(
substrait_textplan_converter
substrait_common
substrait_expression
substrait_io
substrait_proto
symbol_table
error_listener
date::date)
date::date
fmt::fmt-header-only
absl::status
absl::statusor)

if(${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
Expand Down
Loading