Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add root names to the textplan #76

Merged
merged 7 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/substrait/textplan/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ enum class SymbolType {
kSource = 7,
kSourceDetail = 8,
kField = 9,
kRoot = 10,

kUnknown = -1,
};
Expand Down
76 changes: 68 additions & 8 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,44 @@ std::string outputRelationsSection(const SymbolTable& symbolTable) {
return text.str();
}

std::string outputRootSection(const SymbolTable& symbolTable) {
std::stringstream text;
bool hasPreviousText = false;
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kRoot) {
continue;
}
auto names = ANY_CAST(std::vector<std::string>, info.blob);
if (names.empty()) {
// No point in printing an empty section.
continue;
}
if (hasPreviousText) {
text << "\n";
}
text << "root {"
<< "\n";
text << " names = [";
bool hadName = false;
for (const auto& name : names) {
if (hadName) {
text << ",\n";
} else {
text << "\n";
}
text << " " << name;
hadName = true;
}
if (hadName) {
text << "\n";
}
text << " ]\n";
text << "}\n";
hasPreviousText = true;
}
return text.str();
}

std::string outputSchemaSection(const SymbolTable& symbolTable) {
std::stringstream text;
bool hasPreviousText = false;
Expand Down Expand Up @@ -427,6 +465,15 @@ std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) {
hasPreviousText = true;
}

newText = outputRootSection(symbolTable);
if (!newText.empty()) {
if (hasPreviousText) {
text << "\n";
}
text << newText;
hasPreviousText = true;
}

newText = outputSchemaSection(symbolTable);
if (!newText.empty()) {
if (hasPreviousText) {
Expand Down Expand Up @@ -668,14 +715,27 @@ ::substrait::proto::Plan SymbolTablePrinter::outputToBinaryPlan(
if (relationData->newPipelines.empty()) {
*relation->mutable_root()->mutable_input() = relationData->relation;
} else {
// This is a root node, copy the first node in before iterating.
auto inputRelationData = ANY_CAST(
std::shared_ptr<RelationData>, relationData->newPipelines[0]->blob);
*relation->mutable_root()->mutable_input() = inputRelationData->relation;

addInputsToRelation(
*relationData->newPipelines[0],
relation->mutable_root()->mutable_input());
if (relationData->newPipelines[0]->type != SymbolType::kRoot) {
// This is a root node, copy the first node in before iterating.
auto inputRelationData = ANY_CAST(
std::shared_ptr<RelationData>, relationData->newPipelines[0]->blob);
*relation->mutable_root()->mutable_input() =
inputRelationData->relation;

addInputsToRelation(
*relationData->newPipelines[0],
relation->mutable_root()->mutable_input());
}

const auto& rootSymbol =
symbolTable.nthSymbolByType(0, SymbolType::kRoot);
if (rootSymbol != SymbolInfo::kUnknown) {
const auto& rootNames =
ANY_CAST(std::vector<std::string>, rootSymbol.blob);
for (const auto& name : rootNames) {
relation->mutable_root()->add_names(name);
}
}
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ namespace io::substrait::textplan {

namespace {

const std::string kRootNames{"root.names"};

std::string shortName(std::string str) {
auto loc = str.find(':');
if (loc != std::string::npos) {
Expand Down Expand Up @@ -167,6 +169,16 @@ std::any InitialPlanProtoVisitor::visitRelation(

std::any InitialPlanProtoVisitor::visitRelationRoot(
const ::substrait::proto::RelRoot& relation) {
std::vector<std::string> names;
names.insert(names.end(), relation.names().begin(), relation.names().end());
auto uniqueName = symbolTable_->getUniqueName(kRootNames);
symbolTable_->defineSymbol(
uniqueName,
PROTO_LOCATION(relation),
SymbolType::kRoot,
SourceType::kUnknown,
names);

BasePlanProtoVisitor::visitRelationRoot(relation);
return std::nullopt;
}
Expand Down
1 change: 0 additions & 1 deletion src/substrait/textplan/converter/PipelineVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ std::shared_ptr<RelationData> PipelineVisitor::getRelationData(
std::any PipelineVisitor::visitRelation(
const ::substrait::proto::Rel& relation) {
auto relationData = getRelationData(relation);

switch (relation.rel_type_case()) {
case ::substrait::proto::Rel::RelTypeCase::kRead:
// No relations beyond this one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ std::vector<TestCase> getTestCases() {
}
})",
AllOf(
HasSymbols({"local", "read", "root"}),
HasSymbols({"root.names", "local", "read", "root"}),
WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
read -> root;
Expand Down Expand Up @@ -190,7 +190,14 @@ std::vector<TestCase> getTestCases() {
})",
AllOf(
HasSymbols(
{"schema", "cost", "count", "named", "#2", "read", "root"}),
{"root.names",
"schema",
"cost",
"count",
"named",
"#2",
"read",
"root"}),
WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
read -> root;
Expand Down Expand Up @@ -401,7 +408,7 @@ std::vector<TestCase> getTestCases() {
}
})",
AllOf(
HasSymbols({"filter", "root"}),
HasSymbols({"root.names", "filter", "root"}),
WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
filter -> root;
Expand Down Expand Up @@ -454,7 +461,7 @@ std::vector<TestCase> getTestCases() {
}
})",
AllOf(
HasSymbols({"filter", "root"}),
HasSymbols({"root.names", "filter", "root"}),
WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
filter -> root;
Expand Down Expand Up @@ -526,7 +533,7 @@ std::vector<TestCase> getTestCases() {
}
})",
AllOf(
HasSymbols({"filter", "root"}),
HasSymbols({"root.names", "filter", "root"}),
WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
filter -> root;
Expand All @@ -539,18 +546,20 @@ std::vector<TestCase> getTestCases() {
{
"single three node pipeline",
"relations: { root: { input: { project: { input { read: { local_files {} } } } } } }",
HasSymbols({"local", "read", "project", "root"}),
HasSymbols({"root.names", "local", "read", "project", "root"}),
},
{
"two identical three node pipelines",
"relations: { root: { input: { project: { input { read: { local_files {} } } } } } }"
"relations: { root: { input: { project: { input { read: { local_files {} } } } } } }",
AllOf(
HasSymbols(
{"local",
{"root.names",
"local",
"read",
"project",
"root",
"root.names2",
"local2",
"read2",
"project2",
Expand All @@ -566,7 +575,13 @@ std::vector<TestCase> getTestCases() {
"relations: { root: { input: { hash_join: { left { read: { local_files {} } } right { read: { local_files {} } } } } } }",
AllOf(
HasSymbols(
{"local", "read", "local2", "read2", "hashjoin", "root"}),
{"root.names",
"local",
"read",
"local2",
"read2",
"hashjoin",
"root"}),
WhenSerialized(::testing::HasSubstr("pipelines {\n"
" read -> hashjoin;\n"
" read2 -> hashjoin;\n"
Expand Down
9 changes: 9 additions & 0 deletions src/substrait/textplan/parser/ParseText.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ ParseResult parseStream(antlr4::ANTLRInputStream stream) {
*visitor->getSymbolTable(), visitor->getErrorListener());
try {
pipelineVisitor->visitPlan(tree);
} catch (std::invalid_argument ex) {
// Catches the any_cast exception and logs a useful error message.
errorListener.syntaxError(
&parser,
nullptr,
/*line=*/1,
/*charPositionInLine=*/1,
ex.what(),
std::current_exception());
} catch (...) {
errorListener.syntaxError(
&parser,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline(

// Refetch our symbol table entry to make sure we have the latest version.
auto* symbol = symbolTable_->lookupSymbolByName(relationName);
if (symbol->blob.type() != typeid(std::shared_ptr<RelationData>)) {
return defaultResult();
}
auto relationData = ANY_CAST(std::shared_ptr<RelationData>, symbol->blob);

// Check for accidental cross-pipeline use.
Expand All @@ -103,6 +106,10 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline(
}
const SymbolInfo* rightmostSymbol = rightSymbol;
if (*rightSymbol != SymbolInfo::kUnknown) {
if (rightSymbol->blob.type() != typeid(std::shared_ptr<RelationData>)) {
errorListener_->addError(
ctx->getStart(), "No relation definition present for this symbol.");
}
auto rightRelationData =
ANY_CAST(std::shared_ptr<RelationData>, rightSymbol->blob);
if (rightRelationData->pipelineStart != nullptr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ std::any SubstraitPlanRelationVisitor::visitRelation(
// This error has been previously dealt with thus we can safely skip it.
return defaultResult();
}
if (symbol->type == SymbolType::kRoot) {
return defaultResult();
}
auto relationData = ANY_CAST(std::shared_ptr<RelationData>, symbol->blob);
::substrait::proto::Rel relation;

Expand Down
25 changes: 25 additions & 0 deletions src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
#include "substrait/textplan/Finally.h"
#include "substrait/textplan/Location.h"
#include "substrait/textplan/StructuredSymbolData.h"
#include "substrait/textplan/SymbolTable.h"
#include "substrait/type/Type.h"

namespace io::substrait::textplan {

const std::string kRootName{"root"};

// Removes leading and trailing quotation marks.
std::string extractFromString(std::string s) {
if (s.size() < 2) {
Expand Down Expand Up @@ -158,6 +161,28 @@ std::any SubstraitPlanVisitor::visitSchema_item(
visitLiteral_complex_type(ctx->literal_complex_type()));
}

std::any SubstraitPlanVisitor::visitRoot_relation(
SubstraitPlanParser::Root_relationContext* ctx) {
auto prevRoot = symbolTable_->lookupSymbolByName(kRootName);
if (prevRoot != nullptr) {
if (prevRoot->type == SymbolType::kRoot) {
errorListener_->addError(
ctx->getStart(), "A root relation was already defined.");
} else {
errorListener_->addError(
ctx->getStart(), "A relation named root was already defined.");
}
return nullptr;
}
std::vector<std::string> names;
for (const auto& id : ctx->id()) {
names.push_back(id->getText());
}
symbolTable_->defineSymbol(
kRootName, Location(ctx), SymbolType::kRoot, SourceType::kUnknown, names);
return nullptr;
}

std::any SubstraitPlanVisitor::visitRelation(
SubstraitPlanParser::RelationContext* ctx) {
auto relType =
Expand Down
2 changes: 2 additions & 0 deletions src/substrait/textplan/parser/SubstraitPlanVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class SubstraitPlanVisitor : public SubstraitPlanTypeVisitor {
std::any visitSchema_item(
SubstraitPlanParser::Schema_itemContext* ctx) override;
std::any visitRelation(SubstraitPlanParser::RelationContext* ctx) override;
std::any visitRoot_relation(
SubstraitPlanParser::Root_relationContext* ctx) override;
std::any visitRelation_type(
SubstraitPlanParser::Relation_typeContext* ctx) override;
std::any visitSource_definition(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ NAMED_TABLE: 'NAMED_TABLE';
EXTENSION_TABLE: 'EXTENSION_TABLE';

SOURCE: 'SOURCE';
ROOT: 'ROOT';
ITEMS: 'ITEMS';
NAMES: 'NAMES';
URI_FILE: 'URI_FILE';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ plan
plan_detail
: pipelines
| relation
| root_relation
| schema_definition
| source_definition
| extensionspace
Expand All @@ -42,6 +43,10 @@ relation
: relation_type RELATION relation_ref LEFTBRACE relation_detail* RIGHTBRACE
;

root_relation
: ROOT LEFTBRACE NAMES EQUAL LEFTBRACKET id (COMMA id)* COMMA? RIGHTBRACKET RIGHTBRACE
;

relation_type
: id
;
Expand Down Expand Up @@ -209,6 +214,8 @@ id
simple_id
: IDENTIFIER
| FILTER
| ROOT
| SOURCE
| SCHEMA
| NULLVAL
| SORT
Expand Down
Loading