Skip to content

Commit

Permalink
feat: add api for loading plans of all types
Browse files Browse the repository at this point in the history
features:
    * supports reading/writing all of the major formats

caveats:
    * only read/writes to filenames so in-memory use should use other interfaces
    * does not support compression
    * does not have a zero copy interface
  • Loading branch information
EpsilonPrime committed Jul 29, 2023
1 parent 0a71a2f commit 2dfee3b
Show file tree
Hide file tree
Showing 22 changed files with 536 additions and 89 deletions.
30 changes: 30 additions & 0 deletions include/substrait/common/Io.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <string_view>

#include "absl/status/statusor.h"
#include "substrait/proto/plan.pb.h"

namespace io::substrait {

enum PlanFileEncoding {
kBinary = 0,
kJson = 1,
kProtoText = 2,
kText = 3,
};

// Loads a Substrait plan consisting of any encoding type from the given file.
absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding(
std::string_view input_filename);

// Writes the provided plan to the specified location with the specified
// encoding type.
[[maybe_unused]] absl::Status savePlan(
const ::substrait::proto::Plan& plan,
std::string_view output_filename,
PlanFileEncoding encoding);

} // namespace io::substrait
13 changes: 12 additions & 1 deletion src/substrait/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# SPDX-License-Identifier: Apache-2.0

add_library(substrait_common Exceptions.cpp)

target_link_libraries(substrait_common fmt::fmt-header-only)

add_library(substrait_io Io.cpp)
add_dependencies(
substrait_io
substrait_proto
substrait_textplan_converter
substrait_textplan_loader
fmt::fmt-header-only
absl::status
absl::statusor)
target_link_libraries(substrait_io substrait_proto substrait_textplan_converter
substrait_textplan_loader absl::status absl::statusor)

if(${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
endif()
77 changes: 77 additions & 0 deletions src/substrait/common/Io.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <regex>
#include <string_view>

#include "substrait/proto/plan.pb.h"
#include "substrait/textplan/converter/LoadBinary.h"
#include "substrait/textplan/converter/SaveBinary.h"
#include "substrait/textplan/parser/LoadText.h"

namespace io::substrait {

namespace {

const std::regex kIsJson(R"(("extensionUris"|"extensions"|"relations"))");
const std::regex kIsProtoText(
R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))");
const std::regex kIsText(
R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)");

PlanFileEncoding detectEncoding(std::string_view content) {
if (std::regex_search(content.begin(), content.end(), kIsJson)) {
return kJson;
}
if (std::regex_search(content.begin(), content.end(), kIsProtoText)) {
return kProtoText;
}
if (std::regex_search(content.begin(), content.end(), kIsText)) {
return kText;
}
return kBinary;
}

} // namespace

absl::StatusOr<::substrait::proto::Plan> loadPlanWithUnknownEncoding(
std::string_view input_filename) {
auto contentOrError = textplan::readFromFile(input_filename.data());
if (!contentOrError.ok()) {
return contentOrError.status();
}

auto encoding = detectEncoding(*contentOrError);
absl::StatusOr<::substrait::proto::Plan> planOrError;
switch (encoding) {
case kBinary:
return textplan::loadFromBinary(*contentOrError);
case kJson:
return textplan::loadFromJson(*contentOrError);
case kProtoText:
return textplan::loadFromProtoText(*contentOrError);
case kText:
return textplan::loadFromText(*contentOrError);
}
return absl::UnimplementedError("Unexpected encoding requested.");
}

absl::Status savePlan(
const ::substrait::proto::Plan& plan,
std::string_view output_filename,
PlanFileEncoding encoding) {
switch (encoding) {
case kBinary:
return textplan::savePlanToBinary(plan, output_filename);
case kJson:
return textplan::savePlanToJson(plan, output_filename);
case kProtoText:
return textplan::savePlanToProtoText(plan, output_filename);
case kText:
return textplan::savePlanToText(plan, output_filename);
}
return absl::UnimplementedError("Unexpected encoding requested.");
}

} // namespace io::substrait
10 changes: 10 additions & 0 deletions src/substrait/common/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ add_test_case(
substrait_common
gtest
gtest_main)

add_test_case(
substrait_io_test
SOURCES
IoTest.cpp
EXTRA_LINK_LIBS
substrait_io
protobuf-matchers
gtest
gtest_main)
104 changes: 104 additions & 0 deletions src/substrait/common/tests/IoTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <protobuf-matchers/protocol-buffer-matchers.h>
#include <unistd.h>

using ::protobuf_matchers::EqualsProto;
using ::protobuf_matchers::Partially;

namespace io::substrait {

namespace {

constexpr const char* planFileEncodingToString(PlanFileEncoding e) noexcept {
switch (e) {
case PlanFileEncoding::kBinary:
return "kBinary";
case PlanFileEncoding::kJson:
return "kJson";
case PlanFileEncoding::kProtoText:
return "kProtoText";
case PlanFileEncoding::kText:
return "kText";
}
return "IMPOSSIBLE";
}

} // namespace

class IoTest : public ::testing::Test {};

TEST_F(IoTest, LoadMissingFile) {
auto result =
::io::substrait::loadPlanWithUnknownEncoding("non-existent-file");
ASSERT_FALSE(result.ok());
ASSERT_THAT(
result.status().message(),
::testing::ContainsRegex("Failed to open file non-existent-file"));
}

class SaveAndLoadTestFixture
: public ::testing::TestWithParam<PlanFileEncoding> {
public:
~SaveAndLoadTestFixture() override {
for (const auto& filename : testFiles_) {
unlink(filename.c_str());
}
}

void registerCleanup(const char* filename) {
testFiles_.emplace_back(filename);
}

private:
std::vector<std::string> testFiles_;
};

TEST_P(SaveAndLoadTestFixture, SaveAndLoad) {
auto tempFilename = std::tmpnam(nullptr);
registerCleanup(tempFilename);
PlanFileEncoding encoding = GetParam();

::substrait::proto::Plan plan;
auto root = plan.add_relations()->mutable_root();
auto read = root->mutable_input()->mutable_read();
read->mutable_common()->mutable_direct();
read->mutable_named_table()->add_names("table_name");
auto status = ::io::substrait::savePlan(plan, tempFilename, encoding);
ASSERT_TRUE(status.ok()) << "Save failed.\n" << status;

auto result = ::io::substrait::loadPlanWithUnknownEncoding(tempFilename);
ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status();
ASSERT_THAT(
*result,
Partially(EqualsProto<::substrait::proto::Plan>(
R"(relations {
root {
input {
read {
common {
direct {
}
}
named_table {
names: "table_name"
}
}
}
}
})")));
}

INSTANTIATE_TEST_SUITE_P(
SaveAndLoadTests,
SaveAndLoadTestFixture,
testing::Values(kBinary, kJson, kProtoText, kText),
[](const testing::TestParamInfo<SaveAndLoadTestFixture::ParamType>& info) {
return planFileEncodingToString(info.param);
});

} // namespace io::substrait
19 changes: 17 additions & 2 deletions src/substrait/textplan/StringManipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,33 @@

#include "StringManipulation.h"

#include <numeric>
#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

// Yields true if the string 'haystack' starts with the string 'needle'.
bool startsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(0, needle.size()) == needle;
}

// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(haystack.size() - needle.size(), needle.size()) == needle;
}

std::string joinLines(
std::vector<std::string> lines,
std::string_view separator) {
auto concatWithSeparator = [separator](std::string a, const std::string& b) {
return std::move(a) + std::string(separator) + b;
};

auto result = std::accumulate(
std::next(lines.begin()), lines.end(), lines[0], concatWithSeparator);
return result;
}

} // namespace io::substrait::textplan
7 changes: 7 additions & 0 deletions src/substrait/textplan/StringManipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

#pragma once

#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

Expand All @@ -12,4 +14,9 @@ bool startsWith(std::string_view haystack, std::string_view needle);
// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle);

// Joins a vector of strings into a single string separated by separator.
std::string joinLines(
std::vector<std::string> lines,
std::string_view separator = "\n");

} // namespace io::substrait::textplan
25 changes: 21 additions & 4 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) {
auto subtype = ANY_CAST(SourceType, info.subtype);
switch (subtype) {
case SourceType::kNamedTable: {
auto table =
ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob);
if (info.blob.has_value()) {
// We are using the proto as is in lieu of a disciplined structure.
auto table = ANY_CAST(
const ::substrait::proto::ReadRel_NamedTable*, info.blob);
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;
break;
}
// We are using the new style data structure.
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
for (const auto& sym :
symbolTable.lookupSymbolsByLocation(info.location)) {
if (sym->type == SymbolType::kSourceDetail) {
text << " \"" << sym->name << "\",\n";
}
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;

break;
}
case SourceType::kLocalFiles: {
Expand Down
8 changes: 7 additions & 1 deletion src/substrait/textplan/converter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ set(TEXTPLAN_SRCS
PlanPrinterVisitor.h
LoadBinary.cpp
LoadBinary.h
SaveBinary.cpp
SaveBinary.h
ParseBinary.cpp
ParseBinary.h)

Expand All @@ -20,10 +22,14 @@ target_link_libraries(
substrait_textplan_converter
substrait_common
substrait_expression
substrait_io
substrait_proto
symbol_table
error_listener
date::date)
date::date
fmt::fmt-header-only
absl::status
absl::statusor)

if(${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
Expand Down
Loading

0 comments on commit 2dfee3b

Please sign in to comment.