diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 2faa1c85..bfa68709 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -715,14 +715,17 @@ ::substrait::proto::Plan SymbolTablePrinter::outputToBinaryPlan( if (relationData->newPipelines.empty()) { *relation->mutable_root()->mutable_input() = relationData->relation; } else { - // This is a root node, copy the first node in before iterating. - auto inputRelationData = ANY_CAST( - std::shared_ptr, relationData->newPipelines[0]->blob); - *relation->mutable_root()->mutable_input() = inputRelationData->relation; - - addInputsToRelation( - *relationData->newPipelines[0], - relation->mutable_root()->mutable_input()); + if (relationData->newPipelines[0]->type != SymbolType::kRoot) { + // This is a root node, copy the first node in before iterating. + auto inputRelationData = ANY_CAST( + std::shared_ptr, relationData->newPipelines[0]->blob); + *relation->mutable_root()->mutable_input() = + inputRelationData->relation; + + addInputsToRelation( + *relationData->newPipelines[0], + relation->mutable_root()->mutable_input()); + } const auto& rootSymbol = symbolTable.nthSymbolByType(0, SymbolType::kRoot); diff --git a/src/substrait/textplan/parser/ParseText.cpp b/src/substrait/textplan/parser/ParseText.cpp index 554e4a26..3a19d3cc 100644 --- a/src/substrait/textplan/parser/ParseText.cpp +++ b/src/substrait/textplan/parser/ParseText.cpp @@ -71,6 +71,15 @@ ParseResult parseStream(antlr4::ANTLRInputStream stream) { *visitor->getSymbolTable(), visitor->getErrorListener()); try { pipelineVisitor->visitPlan(tree); + } catch (std::invalid_argument ex) { + // Catches the any_cast exception and logs a useful error message. + errorListener.syntaxError( + &parser, + nullptr, + /*line=*/1, + /*charPositionInLine=*/1, + ex.what(), + std::current_exception()); } catch (...) { errorListener.syntaxError( &parser, diff --git a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp index a08c682e..eeea37bf 100644 --- a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp @@ -77,6 +77,9 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline( // Refetch our symbol table entry to make sure we have the latest version. auto* symbol = symbolTable_->lookupSymbolByName(relationName); + if (symbol->blob.type() != typeid(std::shared_ptr)) { + return defaultResult(); + } auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); // Check for accidental cross-pipeline use. @@ -103,6 +106,9 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline( } const SymbolInfo* rightmostSymbol = rightSymbol; if (*rightSymbol != SymbolInfo::kUnknown) { + if (rightSymbol->blob.type() != typeid(std::shared_ptr)) { +errorListener_->addError(ctx->getStart(), "blah"); + } auto rightRelationData = ANY_CAST(std::shared_ptr, rightSymbol->blob); if (rightRelationData->pipelineStart != nullptr) { diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp index 25e3baa6..e3fc2f68 100644 --- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp +++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp @@ -1030,6 +1030,28 @@ std::vector getTestCases() { HasSymbolsWithTypes( {"read", "project", "root"}, {SymbolType::kRelation}), ParsesOk()), + + }, + { + "test18-root-and-read", + R"(pipelines { + root -> read; + } + + read relation read { + base_schema schemaone; + source mynamedtable; + } + + root { + names = [ + apple, + ] + })", + AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( + R"(relations: { + root { names: "apple" } + })")), }, }; return cases;