Skip to content

Commit

Permalink
feat: add api for loading plans of all types (substrait-io#80)
Browse files Browse the repository at this point in the history
features:
    * supports reading/writing all of the major formats

caveats:
* only reads from/writes to named files, so in-memory use should go through
other interfaces
    * does not support compression
    * does not have a zero copy interface
  • Loading branch information
EpsilonPrime committed Aug 26, 2023
1 parent cf6052b commit ecfd5b8
Show file tree
Hide file tree
Showing 23 changed files with 574 additions and 91 deletions.
60 changes: 60 additions & 0 deletions include/substrait/common/Io.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <string_view>

#include "absl/status/statusor.h"
#include "substrait/proto/plan.pb.h"

namespace io::substrait {

/*
* \brief The four different ways plans can be represented on disk.
*
* Passed to savePlan() to choose the representation written to the output
* file.
*/
enum class PlanFileFormat {
// Binary representation (presumably the serialized protobuf -- confirm).
kBinary = 0,
// JSON representation.
kJson = 1,
// Protobuf text representation.
kProtoText = 2,
// Text plan representation.
kText = 3,
};

/*
* \brief Loads a Substrait plan of any format from the given file.
*
* loadPlan determines which file type the specified file is and then calls
* the appropriate load/parse method to consume it preserving any error
* messages.
*
* The file type is determined by inspecting the file's contents; content that
* is not recognized as JSON, protobuf text, or the text plan format is treated
* as binary.
*
* This will load the plan into memory and then convert it consuming twice the
* amount of memory that it consumed on disk.
*
* \param input_filename The filename containing the plan to convert.
* \return If loading was successful, returns a plan. If loading was not
* successful this is a status containing a list of parse errors in the status's
* message.
*/
absl::StatusOr<::substrait::proto::Plan> loadPlan(
std::string_view input_filename);

/*
* \brief Writes the provided plan to disk.
*
* savePlan writes the provided plan in the specified format to the specified
* location.
*
* This routine will consume more memory during the conversion to the text
* format as the original plan as well as the annotated parse tree will need to
* reside in memory during the process.
*
* \param plan The plan to write.
* \param output_filename The name of the file to write the plan to.
* \param format The on-disk representation to write the plan in.
* \return The status reported by the writer for the requested format.
*/
absl::Status savePlan(
const ::substrait::proto::Plan& plan,
std::string_view output_filename,
PlanFileFormat format);

} // namespace io::substrait
13 changes: 12 additions & 1 deletion src/substrait/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# SPDX-License-Identifier: Apache-2.0

add_library(substrait_common Exceptions.cpp)

target_link_libraries(substrait_common fmt::fmt-header-only)

# Library providing the plan file I/O entry points implemented in Io.cpp.
add_library(substrait_io Io.cpp)
# NOTE(review): the target_link_libraries call below names most of these
# targets as well; confirm whether the explicit add_dependencies is still
# required for build ordering.
add_dependencies(
substrait_io
substrait_proto
substrait_textplan_converter
substrait_textplan_loader
fmt::fmt-header-only
absl::status
absl::statusor)
# Libraries Io.cpp links against: the generated protos, the textplan
# converter/loader, and the Abseil status types.
target_link_libraries(substrait_io substrait_proto substrait_textplan_converter
substrait_textplan_loader absl::status absl::statusor)

if(${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
endif()
77 changes: 77 additions & 0 deletions src/substrait/common/Io.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <regex>
#include <string>
#include <string_view>

#include "substrait/proto/plan.pb.h"
#include "substrait/textplan/converter/LoadBinary.h"
#include "substrait/textplan/converter/SaveBinary.h"
#include "substrait/textplan/parser/LoadText.h"

namespace io::substrait {

namespace {

// Content-sniffing patterns used to guess a plan file's format.

// Matches any of the double-quoted keys "extensionUris", "extension_uris",
// "extensions", or "relations" anywhere in the content.
const std::regex kIsJson(
R"(("extensionUris"|"extension_uris"|"extensions"|"relations"))");
// Matches relations, extensions, extension_uris, or expected_type_urls
// followed by " {" at the very start of the content or right after a newline.
const std::regex kIsProtoText(
R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))");
// Matches one of pipelines, "<lowercase word> relation", schema, source, or
// extension_space at the start of the content or after a newline, optionally
// preceded by spaces.
const std::regex kIsText(
R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)");

// Guesses which on-disk format the given file content is in.
//
// The patterns are tried in a fixed order -- JSON, then protobuf text, then
// the text plan format -- and the first one found anywhere in the content
// decides the result.  Content matching none of them is reported as binary.
PlanFileFormat detectFormat(std::string_view content) {
  const auto contains = [&content](const std::regex& pattern) {
    return std::regex_search(content.begin(), content.end(), pattern);
  };

  if (contains(kIsJson)) {
    return PlanFileFormat::kJson;
  }
  if (contains(kIsProtoText)) {
    return PlanFileFormat::kProtoText;
  }
  return contains(kIsText) ? PlanFileFormat::kText : PlanFileFormat::kBinary;
}

} // namespace

// Reads the named file, detects which plan format it holds, and parses it
// with the loader for that format.
//
// \param input_filename The file to read the plan from.
// \return The parsed plan, or the error status produced either by reading
// the file or by the selected loader.
absl::StatusOr<::substrait::proto::Plan> loadPlan(
    std::string_view input_filename) {
  // A string_view is not guaranteed to be null-terminated, so pass a pointer
  // into an owned, terminated copy instead of input_filename.data().
  const std::string filename(input_filename);
  auto contentOrError = textplan::readFromFile(filename.c_str());
  if (!contentOrError.ok()) {
    return contentOrError.status();
  }

  switch (detectFormat(*contentOrError)) {
    case PlanFileFormat::kBinary:
      return textplan::loadFromBinary(*contentOrError);
    case PlanFileFormat::kJson:
      return textplan::loadFromJson(*contentOrError);
    case PlanFileFormat::kProtoText:
      return textplan::loadFromProtoText(*contentOrError);
    case PlanFileFormat::kText:
      return textplan::loadFromText(*contentOrError);
  }
  // Every enumerator returns above; this keeps the function from falling off
  // its end should an out-of-range value ever be produced.
  return absl::UnimplementedError("Unexpected format detected.");
}

// Writes the plan to the named file using the writer that corresponds to the
// requested format.
//
// \param plan The plan to write.
// \param output_filename The file to write to.
// \param format The on-disk representation to produce.
// \return The status reported by the selected writer, or an Unimplemented
// status when the format is not one of the known enumerators.
absl::Status savePlan(
    const ::substrait::proto::Plan& plan,
    std::string_view output_filename,
    PlanFileFormat format) {
  if (format == PlanFileFormat::kBinary) {
    return textplan::savePlanToBinary(plan, output_filename);
  }
  if (format == PlanFileFormat::kJson) {
    return textplan::savePlanToJson(plan, output_filename);
  }
  if (format == PlanFileFormat::kProtoText) {
    return textplan::savePlanToProtoText(plan, output_filename);
  }
  if (format == PlanFileFormat::kText) {
    return textplan::savePlanToText(plan, output_filename);
  }
  return absl::UnimplementedError("Unexpected format requested.");
}

} // namespace io::substrait
10 changes: 10 additions & 0 deletions src/substrait/common/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ add_test_case(
substrait_common
gtest
gtest_main)

# Test for the plan file I/O library (IoTest.cpp), linked against
# substrait_io plus the protobuf matchers and gtest.
add_test_case(
substrait_io_test
SOURCES
IoTest.cpp
EXTRA_LINK_LIBS
substrait_io
protobuf-matchers
gtest
gtest_main)
120 changes: 120 additions & 0 deletions src/substrait/common/tests/IoTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <filesystem>

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <protobuf-matchers/protocol-buffer-matchers.h>
#include <unistd.h>

using ::protobuf_matchers::EqualsProto;
using ::protobuf_matchers::Partially;

namespace io::substrait {

namespace {

// Returns the enumerator's name as a string literal, or "IMPOSSIBLE" for a
// value that is not one of the known enumerators.
constexpr const char* planFileEncodingToString(PlanFileFormat e) noexcept {
  if (e == PlanFileFormat::kBinary) {
    return "kBinary";
  }
  if (e == PlanFileFormat::kJson) {
    return "kJson";
  }
  if (e == PlanFileFormat::kProtoText) {
    return "kProtoText";
  }
  if (e == PlanFileFormat::kText) {
    return "kText";
  }
  return "IMPOSSIBLE";
}

} // namespace

class IoTest : public ::testing::Test {};

// Loading a file that does not exist must fail, and the failure message must
// name the file that could not be opened.
TEST_F(IoTest, LoadMissingFile) {
  const auto loaded = ::io::substrait::loadPlan("non-existent-file");
  ASSERT_FALSE(loaded.ok());

  const auto failureText = loaded.status().message();
  ASSERT_THAT(
      failureText,
      ::testing::ContainsRegex("Failed to open file non-existent-file"));
}

// Fixture for the save/load round-trip tests, parameterized on the file
// format.  It provides a scratch directory under the system temporary
// directory that is created before each test and removed afterwards.
class SaveAndLoadTestFixture : public ::testing::TestWithParam<PlanFileFormat> {
 public:
  // Creates the scratch directory.  A directory left behind by an earlier
  // run is reused rather than treated as a failure.
  void SetUp() override {
    const std::filesystem::path directory =
        std::filesystem::temp_directory_path() /
        std::filesystem::path("my_temp_dir");

    // The error_code overload reports only genuine failures; an already
    // existing directory is not an error.
    std::error_code err;
    std::filesystem::create_directories(directory, err);
    if (err) {
      // Clear the path before the fatal failure (which returns from this
      // function) so TearDown does not try to remove anything.
      testFileDirectory_.clear();
      FAIL() << "Failed to create temporary directory: " << err.message();
    }
    testFileDirectory_ = directory.string();
  }

  // Removes the scratch directory and everything in it, if one was set up.
  void TearDown() override {
    if (!testFileDirectory_.empty()) {
      std::error_code err;
      std::filesystem::remove_all(testFileDirectory_, err);
      ASSERT_FALSE(err) << err.message();
    }
  }

  // Returns a file name that differs on every call within this process.
  static std::string makeTempFileName() {
    static int tempFileNum = 0;
    return "testfile" + std::to_string(++tempFileNum);
  }

 protected:
  // Scratch directory for the current test; empty when creation failed.
  std::string testFileDirectory_;
};

// Saves a small plan in the parameterized format, loads it back through
// loadPlan (which has to detect the format on its own), and checks that the
// loaded plan still contains what was written.
TEST_P(SaveAndLoadTestFixture, SaveAndLoad) {
auto tempFilename = testFileDirectory_ + "/" + makeTempFileName();
PlanFileFormat encoding = GetParam();

// Build a plan with a single root relation that reads from a named table.
::substrait::proto::Plan plan;
auto root = plan.add_relations()->mutable_root();
auto read = root->mutable_input()->mutable_read();
read->mutable_common()->mutable_direct();
read->mutable_named_table()->add_names("table_name");
auto status = ::io::substrait::savePlan(plan, tempFilename, encoding);
ASSERT_TRUE(status.ok()) << "Save failed.\n" << status;

auto result = ::io::substrait::loadPlan(tempFilename);
ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status();
// Partial comparison: only the fields listed below are checked.
ASSERT_THAT(
*result,
Partially(EqualsProto<::substrait::proto::Plan>(
R"(relations {
root {
input {
read {
common {
direct {
}
}
named_table {
names: "table_name"
}
}
}
}
})")));
}

// Runs the save/load round trip once for every plan file format; each
// instance is named after its format enumerator.
INSTANTIATE_TEST_SUITE_P(
    SaveAndLoadTests,
    SaveAndLoadTestFixture,
    testing::Values(
        PlanFileFormat::kBinary,
        PlanFileFormat::kJson,
        PlanFileFormat::kProtoText,
        PlanFileFormat::kText),
    [](const testing::TestParamInfo<SaveAndLoadTestFixture::ParamType>&
           paramInfo) { return planFileEncodingToString(paramInfo.param); });

} // namespace io::substrait
4 changes: 2 additions & 2 deletions src/substrait/textplan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ add_library(error_listener SubstraitErrorListener.cpp SubstraitErrorListener.h)

add_library(parse_result ParseResult.cpp ParseResult.h)

add_dependencies(symbol_table substrait_proto substrait_common
add_dependencies(symbol_table substrait_proto substrait_common absl::strings
fmt::fmt-header-only)

target_link_libraries(symbol_table fmt::fmt-header-only
target_link_libraries(symbol_table fmt::fmt-header-only absl::strings
substrait_textplan_converter)

# Provide access to the generated protobuffer headers hierarchy.
Expand Down
7 changes: 5 additions & 2 deletions src/substrait/textplan/StringManipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

#include "StringManipulation.h"

#include <numeric>
#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

// Yields true if the string 'haystack' starts with the string 'needle'.
//
// A string starts with itself, and every string (including the empty one)
// starts with the empty string.
bool startsWith(std::string_view haystack, std::string_view needle) {
  // Use >= rather than >: a strict comparison rejects the case where the two
  // strings are equal.
  return haystack.size() >= needle.size() &&
      haystack.substr(0, needle.size()) == needle;
}

// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(haystack.size() - needle.size(), needle.size()) == needle;
Expand Down
2 changes: 2 additions & 0 deletions src/substrait/textplan/StringManipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

#pragma once

#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

Expand Down
25 changes: 21 additions & 4 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) {
auto subtype = ANY_CAST(SourceType, info.subtype);
switch (subtype) {
case SourceType::kNamedTable: {
auto table =
ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob);
if (info.blob.has_value()) {
// We are using the proto as is in lieu of a disciplined structure.
auto table = ANY_CAST(
const ::substrait::proto::ReadRel_NamedTable*, info.blob);
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;
break;
}
// We are using the new style data structure.
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
for (const auto& sym :
symbolTable.lookupSymbolsByLocation(info.location)) {
if (sym->type == SymbolType::kSourceDetail) {
text << " \"" << sym->name << "\",\n";
}
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;

break;
}
case SourceType::kLocalFiles: {
Expand Down
Loading

0 comments on commit ecfd5b8

Please sign in to comment.