From f7a45a649e87537c0ecf48bc4dae8a2a29fb9f3e Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 28 Jun 2023 15:07:09 -0700 Subject: [PATCH] feat: add root names to the textplan (#76) Root names are added to both the parser and converter in this PR. The root relation is not being treated as a relation internally and instead is treated merely as an annotation containing a list of names. This is primarily because the relation-related codepaths make assumptions that wouldn't apply to the root (such as having a valid Relation proto as its data type). --- src/substrait/textplan/SymbolTable.h | 1 + src/substrait/textplan/SymbolTablePrinter.cpp | 76 +++++++++++++++++-- .../converter/InitialPlanProtoVisitor.cpp | 12 +++ .../textplan/converter/PipelineVisitor.cpp | 1 - .../tests/BinaryToTextPlanConversionTest.cpp | 31 ++++++-- src/substrait/textplan/parser/ParseText.cpp | 9 +++ .../parser/SubstraitPlanPipelineVisitor.cpp | 7 ++ .../parser/SubstraitPlanRelationVisitor.cpp | 3 + .../textplan/parser/SubstraitPlanVisitor.cpp | 25 ++++++ .../textplan/parser/SubstraitPlanVisitor.h | 2 + .../parser/grammar/SubstraitPlanLexer.g4 | 1 + .../parser/grammar/SubstraitPlanParser.g4 | 7 ++ .../parser/tests/TextPlanParserTest.cpp | 24 +++++- .../textplan/tests/ParseResultMatchers.cpp | 6 +- 14 files changed, 184 insertions(+), 21 deletions(-) diff --git a/src/substrait/textplan/SymbolTable.h b/src/substrait/textplan/SymbolTable.h index 19e6d419..a0b4ad58 100644 --- a/src/substrait/textplan/SymbolTable.h +++ b/src/substrait/textplan/SymbolTable.h @@ -25,6 +25,7 @@ enum class SymbolType { kSource = 7, kSourceDetail = 8, kField = 9, + kRoot = 10, kUnknown = -1, }; diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 121274a6..bfa68709 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -158,6 +158,44 @@ std::string outputRelationsSection(const SymbolTable& symbolTable) { 
return text.str(); } +std::string outputRootSection(const SymbolTable& symbolTable) { + std::stringstream text; + bool hasPreviousText = false; + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kRoot) { + continue; + } + auto names = ANY_CAST(std::vector, info.blob); + if (names.empty()) { + // No point in printing an empty section. + continue; + } + if (hasPreviousText) { + text << "\n"; + } + text << "root {" + << "\n"; + text << " names = ["; + bool hadName = false; + for (const auto& name : names) { + if (hadName) { + text << ",\n"; + } else { + text << "\n"; + } + text << " " << name; + hadName = true; + } + if (hadName) { + text << "\n"; + } + text << " ]\n"; + text << "}\n"; + hasPreviousText = true; + } + return text.str(); +} + std::string outputSchemaSection(const SymbolTable& symbolTable) { std::stringstream text; bool hasPreviousText = false; @@ -427,6 +465,15 @@ std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { hasPreviousText = true; } + newText = outputRootSection(symbolTable); + if (!newText.empty()) { + if (hasPreviousText) { + text << "\n"; + } + text << newText; + hasPreviousText = true; + } + newText = outputSchemaSection(symbolTable); if (!newText.empty()) { if (hasPreviousText) { @@ -668,14 +715,27 @@ ::substrait::proto::Plan SymbolTablePrinter::outputToBinaryPlan( if (relationData->newPipelines.empty()) { *relation->mutable_root()->mutable_input() = relationData->relation; } else { - // This is a root node, copy the first node in before iterating. - auto inputRelationData = ANY_CAST( - std::shared_ptr, relationData->newPipelines[0]->blob); - *relation->mutable_root()->mutable_input() = inputRelationData->relation; - - addInputsToRelation( - *relationData->newPipelines[0], - relation->mutable_root()->mutable_input()); + if (relationData->newPipelines[0]->type != SymbolType::kRoot) { + // This is a root node, copy the first node in before iterating. 
+ auto inputRelationData = ANY_CAST( + std::shared_ptr, relationData->newPipelines[0]->blob); + *relation->mutable_root()->mutable_input() = + inputRelationData->relation; + + addInputsToRelation( + *relationData->newPipelines[0], + relation->mutable_root()->mutable_input()); + } + + const auto& rootSymbol = + symbolTable.nthSymbolByType(0, SymbolType::kRoot); + if (rootSymbol != SymbolInfo::kUnknown) { + const auto& rootNames = + ANY_CAST(std::vector, rootSymbol.blob); + for (const auto& name : rootNames) { + relation->mutable_root()->add_names(name); + } + } } } diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index 53a5ce24..fc0f493b 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -19,6 +19,8 @@ namespace io::substrait::textplan { namespace { +const std::string kRootNames{"root.names"}; + std::string shortName(std::string str) { auto loc = str.find(':'); if (loc != std::string::npos) { @@ -167,6 +169,16 @@ std::any InitialPlanProtoVisitor::visitRelation( std::any InitialPlanProtoVisitor::visitRelationRoot( const ::substrait::proto::RelRoot& relation) { + std::vector names; + names.insert(names.end(), relation.names().begin(), relation.names().end()); + auto uniqueName = symbolTable_->getUniqueName(kRootNames); + symbolTable_->defineSymbol( + uniqueName, + PROTO_LOCATION(relation), + SymbolType::kRoot, + SourceType::kUnknown, + names); + BasePlanProtoVisitor::visitRelationRoot(relation); return std::nullopt; } diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index da5cd87f..69909f06 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -19,7 +19,6 @@ std::shared_ptr PipelineVisitor::getRelationData( std::any PipelineVisitor::visitRelation( 
const ::substrait::proto::Rel& relation) { auto relationData = getRelationData(relation); - switch (relation.rel_type_case()) { case ::substrait::proto::Rel::RelTypeCase::kRead: // No relations beyond this one. diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index c0ebbdcc..f15796d1 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -148,7 +148,7 @@ std::vector getTestCases() { } })", AllOf( - HasSymbols({"local", "read", "root"}), + HasSymbols({"root.names", "local", "read", "root"}), WhenSerialized(EqSquashingWhitespace( R"(pipelines { read -> root; @@ -190,7 +190,14 @@ std::vector getTestCases() { })", AllOf( HasSymbols( - {"schema", "cost", "count", "named", "#2", "read", "root"}), + {"root.names", + "schema", + "cost", + "count", + "named", + "#2", + "read", + "root"}), WhenSerialized(EqSquashingWhitespace( R"(pipelines { read -> root; @@ -401,7 +408,7 @@ std::vector getTestCases() { } })", AllOf( - HasSymbols({"filter", "root"}), + HasSymbols({"root.names", "filter", "root"}), WhenSerialized(EqSquashingWhitespace( R"(pipelines { filter -> root; @@ -454,7 +461,7 @@ std::vector getTestCases() { } })", AllOf( - HasSymbols({"filter", "root"}), + HasSymbols({"root.names", "filter", "root"}), WhenSerialized(EqSquashingWhitespace( R"(pipelines { filter -> root; @@ -526,7 +533,7 @@ std::vector getTestCases() { } })", AllOf( - HasSymbols({"filter", "root"}), + HasSymbols({"root.names", "filter", "root"}), WhenSerialized(EqSquashingWhitespace( R"(pipelines { filter -> root; @@ -539,7 +546,7 @@ std::vector getTestCases() { { "single three node pipeline", "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }", - HasSymbols({"local", "read", "project", "root"}), + HasSymbols({"root.names", "local", "read", 
"project", "root"}), }, { "two identical three node pipelines", @@ -547,10 +554,12 @@ std::vector getTestCases() { "relations: { root: { input: { project: { input { read: { local_files {} } } } } } }", AllOf( HasSymbols( - {"local", + {"root.names", + "local", "read", "project", "root", + "root.names2", "local2", "read2", "project2", @@ -566,7 +575,13 @@ std::vector getTestCases() { "relations: { root: { input: { hash_join: { left { read: { local_files {} } } right { read: { local_files {} } } } } } }", AllOf( HasSymbols( - {"local", "read", "local2", "read2", "hashjoin", "root"}), + {"root.names", + "local", + "read", + "local2", + "read2", + "hashjoin", + "root"}), WhenSerialized(::testing::HasSubstr("pipelines {\n" " read -> hashjoin;\n" " read2 -> hashjoin;\n" diff --git a/src/substrait/textplan/parser/ParseText.cpp b/src/substrait/textplan/parser/ParseText.cpp index 554e4a26..3a19d3cc 100644 --- a/src/substrait/textplan/parser/ParseText.cpp +++ b/src/substrait/textplan/parser/ParseText.cpp @@ -71,6 +71,15 @@ ParseResult parseStream(antlr4::ANTLRInputStream stream) { *visitor->getSymbolTable(), visitor->getErrorListener()); try { pipelineVisitor->visitPlan(tree); + } catch (std::invalid_argument ex) { + // Catches the any_cast exception and logs a useful error message. + errorListener.syntaxError( + &parser, + nullptr, + /*line=*/1, + /*charPositionInLine=*/1, + ex.what(), + std::current_exception()); } catch (...) { errorListener.syntaxError( &parser, diff --git a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp index a08c682e..806d23b1 100644 --- a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp @@ -77,6 +77,9 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline( // Refetch our symbol table entry to make sure we have the latest version. 
auto* symbol = symbolTable_->lookupSymbolByName(relationName); + if (symbol->blob.type() != typeid(std::shared_ptr)) { + return defaultResult(); + } auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); // Check for accidental cross-pipeline use. @@ -103,6 +106,10 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline( } const SymbolInfo* rightmostSymbol = rightSymbol; if (*rightSymbol != SymbolInfo::kUnknown) { + if (rightSymbol->blob.type() != typeid(std::shared_ptr)) { + errorListener_->addError( + ctx->getStart(), "No relation definition present for this symbol."); + } auto rightRelationData = ANY_CAST(std::shared_ptr, rightSymbol->blob); if (rightRelationData->pipelineStart != nullptr) { diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index 471c1b29..9ed04f84 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -258,6 +258,9 @@ std::any SubstraitPlanRelationVisitor::visitRelation( // This error has been previously dealt with thus we can safely skip it. return defaultResult(); } + if (symbol->type == SymbolType::kRoot) { + return defaultResult(); + } auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); ::substrait::proto::Rel relation; diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp index 33657ad3..f57d52eb 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp @@ -9,10 +9,13 @@ #include "substrait/textplan/Finally.h" #include "substrait/textplan/Location.h" #include "substrait/textplan/StructuredSymbolData.h" +#include "substrait/textplan/SymbolTable.h" #include "substrait/type/Type.h" namespace io::substrait::textplan { +const std::string kRootName{"root"}; + // Removes leading and trailing quotation marks. 
std::string extractFromString(std::string s) { if (s.size() < 2) { @@ -158,6 +161,28 @@ std::any SubstraitPlanVisitor::visitSchema_item( visitLiteral_complex_type(ctx->literal_complex_type())); } +std::any SubstraitPlanVisitor::visitRoot_relation( + SubstraitPlanParser::Root_relationContext* ctx) { + auto prevRoot = symbolTable_->lookupSymbolByName(kRootName); + if (prevRoot != nullptr) { + if (prevRoot->type == SymbolType::kRoot) { + errorListener_->addError( + ctx->getStart(), "A root relation was already defined."); + } else { + errorListener_->addError( + ctx->getStart(), "A relation named root was already defined."); + } + return nullptr; + } + std::vector names; + for (const auto& id : ctx->id()) { + names.push_back(id->getText()); + } + symbolTable_->defineSymbol( + kRootName, Location(ctx), SymbolType::kRoot, SourceType::kUnknown, names); + return nullptr; +} + std::any SubstraitPlanVisitor::visitRelation( SubstraitPlanParser::RelationContext* ctx) { auto relType = diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.h b/src/substrait/textplan/parser/SubstraitPlanVisitor.h index 84f36b3e..94db76ee 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.h @@ -40,6 +40,8 @@ class SubstraitPlanVisitor : public SubstraitPlanTypeVisitor { std::any visitSchema_item( SubstraitPlanParser::Schema_itemContext* ctx) override; std::any visitRelation(SubstraitPlanParser::RelationContext* ctx) override; + std::any visitRoot_relation( + SubstraitPlanParser::Root_relationContext* ctx) override; std::any visitRelation_type( SubstraitPlanParser::Relation_typeContext* ctx) override; std::any visitSource_definition( diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 index e96edf5f..1ad91127 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 +++ 
b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 @@ -50,6 +50,7 @@ NAMED_TABLE: 'NAMED_TABLE'; EXTENSION_TABLE: 'EXTENSION_TABLE'; SOURCE: 'SOURCE'; +ROOT: 'ROOT'; ITEMS: 'ITEMS'; NAMES: 'NAMES'; URI_FILE: 'URI_FILE'; diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 index f5520241..145d7262 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 @@ -23,6 +23,7 @@ plan plan_detail : pipelines | relation + | root_relation | schema_definition | source_definition | extensionspace @@ -42,6 +43,10 @@ relation : relation_type RELATION relation_ref LEFTBRACE relation_detail* RIGHTBRACE ; +root_relation + : ROOT LEFTBRACE NAMES EQUAL LEFTBRACKET id (COMMA id)* COMMA? RIGHTBRACKET RIGHTBRACE + ; + relation_type : id ; @@ -209,6 +214,8 @@ id simple_id : IDENTIFIER | FILTER + | ROOT + | SOURCE | SCHEMA | NULLVAL | SORT diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp index d6c0f169..e3fc2f68 100644 --- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp +++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp @@ -1001,7 +1001,7 @@ std::vector getTestCases() { "1:0 → extraneous input 'relation' expecting {, " "'EXTENSION_SPACE', 'SCHEMA', 'PIPELINES', 'FILTER', " "'GROUPING', 'MEASURE', 'SORT', 'COUNT', 'TYPE', 'SOURCE', " - "'NULL', IDENTIFIER}", + "'ROOT', 'NULL', IDENTIFIER}", "1:24 → mismatched input '{' expecting 'RELATION'", "1:9 → Unrecognized relation type: notyperelation", }), @@ -1030,6 +1030,28 @@ std::vector getTestCases() { HasSymbolsWithTypes( {"read", "project", "root"}, {SymbolType::kRelation}), ParsesOk()), + + }, + { + "test18-root-and-read", + R"(pipelines { + root -> read; + } + + read relation read { + base_schema schemaone; + source mynamedtable; + } + + root { + names = [ + 
apple, + ] + })", + AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( + R"(relations: { + root { names: "apple" } + })")), }, }; return cases; diff --git a/src/substrait/textplan/tests/ParseResultMatchers.cpp b/src/substrait/textplan/tests/ParseResultMatchers.cpp index 69f034ff..b83baca1 100644 --- a/src/substrait/textplan/tests/ParseResultMatchers.cpp +++ b/src/substrait/textplan/tests/ParseResultMatchers.cpp @@ -160,7 +160,7 @@ class HasSymbolsMatcher { extraSymbols.begin()); extraSymbols.resize(end - extraSymbols.begin()); if (!extraSymbols.empty()) { - *listener << std::endl << " with missing symbols: "; + *listener << std::endl << " with extra symbols: "; for (const auto& symbol : extraSymbols) { *listener << " \"" << symbol << "\""; } @@ -176,9 +176,9 @@ class HasSymbolsMatcher { missingSymbols.resize(end - missingSymbols.begin()); if (!missingSymbols.empty()) { if (!extraSymbols.empty()) { - *listener << ", and extra symbols: "; + *listener << ", and missing symbols: "; } else { - *listener << " with extra symbols: "; + *listener << " with missing symbols: "; } for (const auto& symbol : missingSymbols) { *listener << " \"" << symbol << "\"";