Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add api for loading plans of all types #80

Merged
merged 11 commits into from
Aug 26, 2023
Merged
60 changes: 60 additions & 0 deletions include/substrait/common/Io.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <string_view>

#include "absl/status/statusor.h"
#include "substrait/proto/plan.pb.h"

namespace io::substrait {

/*
* \brief The four different ways plans can be represented on disk.
*/
enum class PlanFileFormat {
kBinary = 0,
kJson = 1,
kProtoText = 2,
kText = 3,
};

/*
* \\brief Loads a Substrait plan of any format from the given file.
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
*
* loadPlan determines which file type the specified file is and then calls
* the appropriate load/parse method to consume it preserving any error
* messages.
*
* This will load the plan into memory and then convert it consuming twice the
* amount of memory that it consumed on disk.
*
* \\param input_filename The filename containing the plan to convert.
* \\return If loading was successful, returns a plan. If loading was not
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
* successful this is a status containing a list of parse errors in the status's
* message.
*/
absl::StatusOr<::substrait::proto::Plan> loadPlan(
std::string_view input_filename);

/*
* \\brief Writes the provided plan to disk.
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
*
* savePlan writes the provided plan in the specified format to the specified
* location.
*
* This routine will consume more memory during the conversion to the text
* format as the original plan as well as the annotated parse tree will need to
* reside in memory during the process.
*
* \\param plan
* \\param output_filename
* \\param format
* \\return
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
*/
absl::Status savePlan(
const ::substrait::proto::Plan& plan,
std::string_view output_filename,
PlanFileFormat format);

} // namespace io::substrait
13 changes: 12 additions & 1 deletion src/substrait/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# SPDX-License-Identifier: Apache-2.0

add_library(substrait_common Exceptions.cpp)

target_link_libraries(substrait_common fmt::fmt-header-only)

add_library(substrait_io Io.cpp)
add_dependencies(
substrait_io
substrait_proto
substrait_textplan_converter
substrait_textplan_loader
fmt::fmt-header-only
absl::status
absl::statusor)
target_link_libraries(substrait_io substrait_proto substrait_textplan_converter
substrait_textplan_loader absl::status absl::statusor)

if(${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
endif()
78 changes: 78 additions & 0 deletions src/substrait/common/Io.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <regex>
#include <string_view>

#include "substrait/proto/plan.pb.h"
#include "substrait/textplan/converter/LoadBinary.h"
#include "substrait/textplan/converter/SaveBinary.h"
#include "substrait/textplan/parser/LoadText.h"

namespace io::substrait {

namespace {

const std::regex kIsJson(
R"(("extensionUris"|"extension_uris"|"extensions"|"relations"))");
const std::regex kIsProtoText(
R"((^|\n)((relations|extensions|extension_uris|expected_type_urls) \{))");
const std::regex kIsText(
R"((^|\n) *(pipelines|[a-z]+ *relation|schema|source|extension_space) *)");

PlanFileFormat detectFormat(std::string_view content) {
if (std::regex_search(content.begin(), content.end(), kIsJson)) {
return PlanFileFormat::kJson;
}
if (std::regex_search(content.begin(), content.end(), kIsProtoText)) {
return PlanFileFormat::kProtoText;
}
if (std::regex_search(content.begin(), content.end(), kIsText)) {
return PlanFileFormat::kText;
}
return PlanFileFormat::kBinary;
}

} // namespace

absl::StatusOr<::substrait::proto::Plan> loadPlan(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the load and save methods return StatusOr then I, as a user, would not expect to get an exception. I think we still rely on exceptions in some places (maybe this isn't true?) If so, we should probably wrap these methods with a try/catch and wrap the exception in an invalid status.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed most of the cases where we use SUBSTRAIT_FAIL. There are about 5 remaining and most of those are for nearly impossible cases. I'd prefer to finish removing them over wrapping this code in an exception. If it turns out we're getting exceptions from other libraries then we will be forced to add it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(And the fuzz testing work will help illuminate if we need to wrap the code.)

std::string_view input_filename) {
auto contentOrError = textplan::readFromFile(input_filename.data());
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
if (!contentOrError.ok()) {
return contentOrError.status();
}

auto encoding = detectFormat(*contentOrError);
absl::StatusOr<::substrait::proto::Plan> planOrError;
switch (encoding) {
case PlanFileFormat::kBinary:
return textplan::loadFromBinary(*contentOrError);
case PlanFileFormat::kJson:
return textplan::loadFromJson(*contentOrError);
case PlanFileFormat::kProtoText:
return textplan::loadFromProtoText(*contentOrError);
case PlanFileFormat::kText:
return textplan::loadFromText(*contentOrError);
}
return absl::UnimplementedError("Unexpected encoding requested.");
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
}

absl::Status savePlan(
const ::substrait::proto::Plan& plan,
std::string_view output_filename,
PlanFileFormat format) {
switch (format) {
case PlanFileFormat::kBinary:
return textplan::savePlanToBinary(plan, output_filename);
case PlanFileFormat::kJson:
return textplan::savePlanToJson(plan, output_filename);
case PlanFileFormat::kProtoText:
return textplan::savePlanToProtoText(plan, output_filename);
case PlanFileFormat::kText:
return textplan::savePlanToText(plan, output_filename);
}
return absl::UnimplementedError("Unexpected format requested.");
}

} // namespace io::substrait
10 changes: 10 additions & 0 deletions src/substrait/common/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ add_test_case(
substrait_common
gtest
gtest_main)

add_test_case(
substrait_io_test
SOURCES
IoTest.cpp
EXTRA_LINK_LIBS
substrait_io
protobuf-matchers
gtest
gtest_main)
106 changes: 106 additions & 0 deletions src/substrait/common/tests/IoTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/common/Io.h"

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <protobuf-matchers/protocol-buffer-matchers.h>
#include <unistd.h>

using ::protobuf_matchers::EqualsProto;
using ::protobuf_matchers::Partially;

namespace io::substrait {

namespace {

constexpr const char* planFileEncodingToString(PlanFileFormat e) noexcept {
switch (e) {
case PlanFileFormat::kBinary:
return "kBinary";
case PlanFileFormat::kJson:
return "kJson";
case PlanFileFormat::kProtoText:
return "kProtoText";
case PlanFileFormat::kText:
return "kText";
}
return "IMPOSSIBLE";
}

} // namespace

class IoTest : public ::testing::Test {};

TEST_F(IoTest, LoadMissingFile) {
auto result = ::io::substrait::loadPlan("non-existent-file");
ASSERT_FALSE(result.ok());
ASSERT_THAT(
result.status().message(),
::testing::ContainsRegex("Failed to open file non-existent-file"));
}

class SaveAndLoadTestFixture : public ::testing::TestWithParam<PlanFileFormat> {
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
public:
~SaveAndLoadTestFixture() override {
for (const auto& filename : testFiles_) {
unlink(filename.c_str());
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
}
}

void registerCleanup(const char* filename) {
testFiles_.emplace_back(filename);
}

private:
std::vector<std::string> testFiles_;
};

TEST_P(SaveAndLoadTestFixture, SaveAndLoad) {
auto tempFilename = std::tmpnam(nullptr);
registerCleanup(tempFilename);
PlanFileFormat encoding = GetParam();

::substrait::proto::Plan plan;
auto root = plan.add_relations()->mutable_root();
auto read = root->mutable_input()->mutable_read();
read->mutable_common()->mutable_direct();
read->mutable_named_table()->add_names("table_name");
auto status = ::io::substrait::savePlan(plan, tempFilename, encoding);
ASSERT_TRUE(status.ok()) << "Save failed.\n" << status;

auto result = ::io::substrait::loadPlan(tempFilename);
ASSERT_TRUE(result.ok()) << "Load failed.\n" << result.status();
ASSERT_THAT(
*result,
Partially(EqualsProto<::substrait::proto::Plan>(
R"(relations {
root {
input {
read {
common {
direct {
}
}
named_table {
names: "table_name"
}
}
}
}
})")));
}

INSTANTIATE_TEST_SUITE_P(
SaveAndLoadTests,
SaveAndLoadTestFixture,
testing::Values(
PlanFileFormat::kBinary,
PlanFileFormat::kJson,
PlanFileFormat::kProtoText,
PlanFileFormat::kText),
[](const testing::TestParamInfo<SaveAndLoadTestFixture::ParamType>& info) {
return planFileEncodingToString(info.param);
});

} // namespace io::substrait
19 changes: 17 additions & 2 deletions src/substrait/textplan/StringManipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,33 @@

#include "StringManipulation.h"

#include <numeric>
#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

// Yields true if the string 'haystack' starts with the string 'needle'.
bool startsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(0, needle.size()) == needle;
}

// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(haystack.size() - needle.size(), needle.size()) == needle;
}

std::string joinLines(
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::string> lines,
std::string_view separator) {
auto concatWithSeparator = [separator](std::string a, const std::string& b) {
return std::move(a) + std::string(separator) + b;
};

auto result = std::accumulate(
std::next(lines.begin()), lines.end(), lines[0], concatWithSeparator);
return result;
}

} // namespace io::substrait::textplan
7 changes: 7 additions & 0 deletions src/substrait/textplan/StringManipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

#pragma once

#include <string>
#include <string_view>
#include <vector>

namespace io::substrait::textplan {

Expand All @@ -12,4 +14,9 @@ bool startsWith(std::string_view haystack, std::string_view needle);
// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle);

// Joins a vector of strings into a single string separated by separator.
std::string joinLines(
std::vector<std::string> lines,
std::string_view separator = "\n");

} // namespace io::substrait::textplan
25 changes: 21 additions & 4 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,33 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) {
auto subtype = ANY_CAST(SourceType, info.subtype);
switch (subtype) {
case SourceType::kNamedTable: {
auto table =
ANY_CAST(const ::substrait::proto::ReadRel_NamedTable*, info.blob);
if (info.blob.has_value()) {
// We are using the proto as is in lieu of a disciplined structure.
auto table = ANY_CAST(
const ::substrait::proto::ReadRel_NamedTable*, info.blob);
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;
break;
}
// We are using the new style data structure.
text << "source named_table " << info.name << " {\n";
text << " names = [\n";
for (const auto& name : table->names()) {
text << " \"" << name << "\",\n";
for (const auto& sym :
symbolTable.lookupSymbolsByLocation(info.location)) {
if (sym->type == SymbolType::kSourceDetail) {
text << " \"" << sym->name << "\",\n";
}
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;

break;
}
case SourceType::kLocalFiles: {
Expand Down
Loading