From 8a62ae600cf130bd64f8cd743aa13ed9708e443d Mon Sep 17 00:00:00 2001
From: David Sisson
Date: Fri, 6 Sep 2024 20:02:53 -0700
Subject: [PATCH] feat: implement emit for aggregate relations

---
Review notes (this section is ignored by `git am`):

* PlanPrinterVisitor::lookupFieldReferenceForEmit resolves an emit index to a
  printable name.  When the current scope is an aggregate relation and the
  index is within generatedFieldReferences, that list is indexed directly;
  otherwise fieldReferences is indexed first and generatedFieldReferences
  second.  Out-of-range indexes are reported through errorListener_ and
  rendered as "field#N".
* The scope guard in lookupFieldReferenceForEmit now tests the currentScope
  argument that it just null-checked, rather than the currentScope_ member.
* InitialPlanProtoVisitor::updateLocalSchema and both parser visitors'
  applyOutputMappingToSchema rebuild outputFieldReferences from the emit
  mapping for aggregate relations instead of reporting "not yet supported".
  Emitting grouping fields is still reported as an error by the converter.
* processingEmit_ is set only while visitRelationEmit visits its children;
  findFieldReferenceByName treats fieldReferencesSize as 0 for aggregate
  relations while it is set.
* The "Asked to emit a field (N) beyond ..." message now closes its
  parenthesis.
* The text plan grammar accepts `parquet: {}` in a file_detail.

 src/substrait/textplan/PlanPrinterVisitor.cpp |  75 +++++++++++-
 src/substrait/textplan/PlanPrinterVisitor.h   |   7 ++
 src/substrait/textplan/SymbolTablePrinter.cpp |   2 +-
 .../converter/InitialPlanProtoVisitor.cpp     |  20 ++-
 .../tests/BinaryToTextPlanConversionTest.cpp  |  54 +++++++++
 .../parser/SubstraitPlanRelationVisitor.cpp   |  27 ++++-
 .../parser/SubstraitPlanRelationVisitor.h     |   1 +
 .../SubstraitPlanSubqueryRelationVisitor.cpp  |  29 ++++-
 .../SubstraitPlanSubqueryRelationVisitor.h    |   1 +
 .../textplan/parser/SubstraitPlanVisitor.cpp  |   2 +
 .../parser/grammar/SubstraitPlanLexer.g4      |   1 +
 .../parser/grammar/SubstraitPlanParser.g4     |   1 +
 .../parser/tests/TextPlanParserTest.cpp       | 114 ++++++++++++++++++
 13 files changed, 320 insertions(+), 14 deletions(-)

diff --git a/src/substrait/textplan/PlanPrinterVisitor.cpp b/src/substrait/textplan/PlanPrinterVisitor.cpp
index 91585517..68a1eb8d 100644
--- a/src/substrait/textplan/PlanPrinterVisitor.cpp
+++ b/src/substrait/textplan/PlanPrinterVisitor.cpp
@@ -71,6 +71,19 @@ std::string visitEnumArgument(const std::string& str) {
   return text.str();
 }
 
+bool isAggregate(const SymbolInfo* symbol) {
+  if (symbol->subtype.type() == typeid(::substrait::proto::Rel::RelTypeCase) &&
+      ANY_CAST(::substrait::proto::Rel::RelTypeCase, symbol->subtype) ==
+          ::substrait::proto::Rel::kAggregate) {
+    return true;
+  }
+  if (symbol->subtype.type() == typeid(RelationType) &&
+      ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate) {
+    return true;
+  }
+  return false;
+}
+
 } // namespace
 
 std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) {
@@ -177,6 +190,65 @@ std::string PlanPrinterVisitor::lookupFieldReference(
   return symbol->name;
 }
 
+std::string PlanPrinterVisitor::lookupFieldReferenceForEmit(
+    uint32_t fieldReference,
+    const SymbolInfo* currentScope,
+    uint32_t stepsOut,
+    bool needFullyQualified) {
+  if (currentScope == nullptr || *currentScope == SymbolInfo::kUnknown) {
+    errorListener_->addError(
+        "Field number " + std::to_string(fieldReference) +
+        " mysteriously requested outside of a relation.");
+    return "field#" + std::to_string(fieldReference);
+  }
+  auto actualScope = currentScope;
+  if (stepsOut > 0) {
+    for (auto stepsLeft = stepsOut; stepsLeft > 0; stepsLeft--) {
+      auto actualParentQueryLocation = getParentQueryLocation(actualScope);
+      if (actualParentQueryLocation == Location::kUnknownLocation) {
+        errorListener_->addError(
+            "Requested steps out of " + std::to_string(stepsOut) +
+            " but not within subquery depth that high.");
+        return "field#" + std::to_string(fieldReference);
+      }
+      actualScope = symbolTable_->lookupSymbolByLocationAndType(
+          actualParentQueryLocation, SymbolType::kRelation);
+      if (actualScope == nullptr) {
+        errorListener_->addError(
+            "Internal error: Missing previously encountered parent query symbol.");
+        return "field#" + std::to_string(fieldReference);
+      }
+    }
+  }
+  auto relationData =
+      ANY_CAST(std::shared_ptr<RelationData>, actualScope->blob);
+  const SymbolInfo* symbol{nullptr};
+  auto fieldReferencesSize = relationData->fieldReferences.size();
+  if (isAggregate(currentScope) &&
+      fieldReference < relationData->generatedFieldReferences.size()) {
+    symbol = relationData->generatedFieldReferences[fieldReference];
+  } else if (fieldReference < fieldReferencesSize) {
+    symbol = relationData->fieldReferences[fieldReference];
+  } else if (
+      fieldReference <
+      fieldReferencesSize + relationData->generatedFieldReferences.size()) {
+    symbol =
+        relationData
+            ->generatedFieldReferences[fieldReference - fieldReferencesSize];
+  } else {
+    errorListener_->addError(
+        "Encountered field reference out of range: " +
+        std::to_string(fieldReference));
+    return "field#" + std::to_string(fieldReference);
+  }
+  if (!symbol->alias.empty()) {
+    return symbol->alias;
+  } else if (needFullyQualified && symbol->schema != nullptr) {
+    return symbol->schema->name + "." + symbol->name;
+  }
+  return symbol->name;
+}
+
 std::string PlanPrinterVisitor::lookupFunctionReference(
     uint32_t function_reference) {
   for (const auto& symbol : symbolTable_->getSymbols()) {
@@ -813,7 +885,7 @@ std::any PlanPrinterVisitor::visitRelationCommon(
   }
   for (const auto& mapping : common.emit().output_mapping()) {
     text << "  emit "
-         << lookupFieldReference(
+         << lookupFieldReferenceForEmit(
                 mapping, currentScope_, /* stepsOut= */ 0, true)
          << ";\n";
   }
@@ -1028,6 +1100,7 @@ std::any PlanPrinterVisitor::visitAggregateRelation(
     }
     text << "  }\n";
   }
+  text << ANY_CAST(std::string, visitRelationCommon(relation.common()));
   return text.str();
 }
 
diff --git a/src/substrait/textplan/PlanPrinterVisitor.h b/src/substrait/textplan/PlanPrinterVisitor.h
index 2b598b5b..f408bdac 100644
--- a/src/substrait/textplan/PlanPrinterVisitor.h
+++ b/src/substrait/textplan/PlanPrinterVisitor.h
@@ -39,6 +39,13 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor {
       const SymbolInfo* currentScope,
       uint32_t stepsOut,
       bool needFullyQualified);
+
+  std::string lookupFieldReferenceForEmit(
+      uint32_t fieldReference,
+      const SymbolInfo* currentScope,
+      uint32_t stepsOut,
+      bool needFullyQualified);
+
   std::string lookupFunctionReference(uint32_t function_reference);
 
   std::any visitSubqueryScalar(
diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp
index cdd3e8c9..13bf5341 100644
--- a/src/substrait/textplan/SymbolTablePrinter.cpp
+++ b/src/substrait/textplan/SymbolTablePrinter.cpp
@@ -394,7 +394,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
     functionsToOutput.emplace_back(info.name, functionData->name);
   }
   std::sort(functionsToOutput.begin(), functionsToOutput.end());
-  for (auto [shortName, canonicalName] : functionsToOutput) {
+  for (const auto& [shortName, canonicalName] : functionsToOutput) {
     text << "  function " << canonicalName << " as " << shortName << ";\n";
   }
   text << "}\n";
diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
index d044acc6..4ee3a735 100644
--- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
+++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
@@ -499,6 +499,7 @@ void InitialPlanProtoVisitor::updateLocalSchema(
             std::nullopt);
         relationData->generatedFieldReferences.emplace_back(symbol);
       }
+      // TODO -- If there are multiple groupings add the additional output.
       // Aggregate relations are different in that they alter the emitted fields
       // by default.
       relationData->outputFieldReferences.insert(
@@ -629,9 +630,24 @@ void InitialPlanProtoVisitor::updateLocalSchema(
   // Revamp the output based on the output mapping if present.
   auto mapping = getOutputMapping(relation);
   if (!mapping.empty()) {
+    // TODO -- Use a more explicit check.
     if (!relationData->outputFieldReferences.empty()) {
-      errorListener_->addError(
-          "Aggregate relations do not yet support output mapping changes.");
+      // We are processing an aggregate node which is the only relation with
+      // output field references.
+      auto generatedFieldReferenceSize =
+          relationData->generatedFieldReferences.size();
+      relationData->outputFieldReferences.clear(); // Start over.
+      for (auto item : mapping) {
+        if (item < generatedFieldReferenceSize) {
+          relationData->outputFieldReferences.push_back(
+              relationData->generatedFieldReferences[item]);
+        } else {
+          // TODO -- Add support for grouping fields (needs text syntax).
+          errorListener_->addError(
+              "Asked to emit a field (" + std::to_string(item) +
+              ") beyond what the aggregate produced.");
+        }
+      }
       return;
     }
     for (auto item : mapping) {
diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp
index 3b0ae97f..de45c0e5 100644
--- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp
+++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp
@@ -590,6 +590,60 @@ std::vector getTestCases() {
               "  hashjoin -> root;\n"
               "}\n"))),
       },
+      {
+          "aggregate with emits",
+          R"(extensions {
+            extension_function {
+              function_anchor: 0
+              name: "sum:i32"
+            }
+          }
+          relations: { root: { input: {
+            aggregate: { common { emit { output_mapping: 1 } } input {
+              read: { base_schema { names: 'a' names: 'b'
+                struct { types { string {} } types { i32 {} } } }
+                local_files { items { uri_file: 'x.parquet' parquet { } } }
+              } }
+              measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } } }
+              measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 0 } } } } } } }
+            } } } })",
+          AllOf(
+              WhenSerialized(EqSquashingWhitespace(
+                  "pipelines {\n"
+                  "  read -> aggregate -> root;\n"
+                  "}\n"
+                  "\n"
+                  "read relation read {\n"
+                  "  source local;\n"
+                  "  base_schema schema;\n"
+                  "}\n"
+                  "\n"
+                  "aggregate relation aggregate {\n"
+                  "  measure {\n"
+                  "    measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename;\n"
+                  "  }\n"
+                  "  measure {\n"
+                  "    measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2;\n"
+                  "  }\n"
+                  "\n"
+                  "  emit measurename2;\n"
+                  "}\n"
+                  "\n"
+                  "schema schema {\n"
+                  "  a string;\n"
+                  "  b i32;\n"
+                  "}\n"
+                  "\n"
+                  "source local_files local {\n"
+                  "  items = [\n"
+                  "    {uri_file: \"x.parquet\" parquet: {}}\n"
+                  "  ]\n"
+                  "}\n"
+                  "\n"
+                  "extension_space {\n"
+                  "  function sum:i32 as sum;\n"
+                  "}\n"))),
+      },
   };
   return cases;
 }
diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp
index 51525404..dd080923 100644
--- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp
+++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp
@@ -316,6 +316,11 @@ bool isRelationEmitDetail(SubstraitPlanParser::Relation_detailContext* ctx) {
       nullptr;
 }
 
+bool isAggregate(const SymbolInfo* symbol) {
+  return symbol->subtype.type() == typeid(RelationType) &&
+      ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate;
+}
+
 } // namespace
 
 std::any SubstraitPlanRelationVisitor::aggregateResult(
@@ -819,7 +824,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationEmit(
       SymbolType::kRelation);
   auto parentRelationData =
       ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
+  this->processingEmit_ = true;
   auto result = visitChildren(ctx);
+  this->processingEmit_ = false;
   auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
   auto common =
       findCommonRelation(parentRelationType, &parentRelationData->relation);
@@ -2023,6 +2030,9 @@ std::pair SubstraitPlanRelationVisitor::findFieldReferenceByName(
     std::shared_ptr<RelationData>& relationData,
     const std::string& name) {
   auto fieldReferencesSize = relationData->fieldReferences.size();
+  if (isAggregate(symbol) && this->processingEmit_) {
+    fieldReferencesSize = 0;
+  }
 
   auto generatedField = std::find_if(
       relationData->generatedFieldReferences.rbegin(),
@@ -2075,10 +2085,19 @@ void SubstraitPlanRelationVisitor::applyOutputMappingToSchema(
   if (common->emit().output_mapping_size() == 0) {
     common->mutable_direct();
   } else {
-    if (!relationData->outputFieldReferences.empty()) {
-      // TODO -- Add support for aggregate relations.
-      errorListener_->addError(
-          token, "Aggregate relations do not yet support emit sections.");
+    if (relationData->relation.has_aggregate()) {
+      auto oldReferences = relationData->outputFieldReferences;
+      relationData->outputFieldReferences.clear();
+      for (auto mapping : common->emit().output_mapping()) {
+        if (mapping < oldReferences.size()) {
+          relationData->outputFieldReferences.push_back(oldReferences[mapping]);
+        } else {
+          errorListener_->addError(
+              token,
+              "Field #" + std::to_string(mapping) + " requested but only " +
+                  std::to_string(oldReferences.size()) + " are available.");
+        }
+      }
       return;
     }
     for (auto mapping : common->emit().output_mapping()) {
diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h
index b7ff411d..c2ed07d9 100644
--- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h
+++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h
@@ -205,6 +205,7 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor {
   bool hasSubquery(SubstraitPlanParser::ExpressionContext* ctx);
 
   const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
+  bool processingEmit_{false};
 };
 
 } // namespace io::substrait::textplan
diff --git a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp
index 29e91976..975b3f8d 100644
--- a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp
+++ b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp
@@ -33,8 +33,6 @@ const std::string kAggregationInvocationPrefix = "aggregationinvocation";
 const std::string kJoinTypePrefix = "jointype";
 const std::string kSortDirectionPrefix = "sortdirection";
 
-const std::string kIntermediateNodeName = "intermediate";
-
 enum RelationFilterBehavior {
   kDefault = 0,
   kBestEffort = 1,
@@ -374,6 +372,11 @@ comparisonToProto(const std::string& text) {
       Expression_Subquery_SetComparison_ComparisonOp_COMPARISON_OP_UNSPECIFIED;
 }
 
+bool isAggregate(const SymbolInfo* symbol) {
+  return symbol->subtype.type() == typeid(RelationType) &&
+      ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate;
+}
+
 } // namespace
 
 std::any SubstraitPlanSubqueryRelationVisitor::aggregateResult(
@@ -871,7 +874,9 @@ std::any SubstraitPlanSubqueryRelationVisitor::visitRelationEmit(
       SymbolType::kRelation);
   auto parentRelationData =
       ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
+  this->processingEmit_ = true;
   auto result = visitChildren(ctx);
+  this->processingEmit_ = false;
   auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
   auto common =
       findCommonRelation(parentRelationType, &parentRelationData->relation);
@@ -2163,6 +2168,9 @@ SubstraitPlanSubqueryRelationVisitor::findFieldReferenceByName(
     std::shared_ptr<RelationData>& relationData,
     const std::string& name) {
   auto fieldReferencesSize = relationData->fieldReferences.size();
+  if (isAggregate(symbol) && this->processingEmit_) {
+    fieldReferencesSize = 0;
+  }
 
   auto generatedField = std::find_if(
       relationData->generatedFieldReferences.rbegin(),
@@ -2234,10 +2242,19 @@ void SubstraitPlanSubqueryRelationVisitor::applyOutputMappingToSchema(
   if (common->emit().output_mapping_size() == 0) {
     common->mutable_direct();
   } else {
-    if (!relationData->outputFieldReferences.empty()) {
-      // TODO -- Add support for aggregate relations.
-      errorListener_->addError(
-          token, "Aggregate relations do not yet support emit sections.");
+    if (relationData->relation.has_aggregate()) {
+      auto oldReferences = relationData->outputFieldReferences;
+      relationData->outputFieldReferences.clear();
+      for (auto mapping : common->emit().output_mapping()) {
+        if (mapping < oldReferences.size()) {
+          relationData->outputFieldReferences.push_back(oldReferences[mapping]);
+        } else {
+          errorListener_->addError(
+              token,
+              "Field #" + std::to_string(mapping) + " requested but only " +
+                  std::to_string(oldReferences.size()) + " are available.");
+        }
+      }
       return;
     }
     for (auto mapping : common->emit().output_mapping()) {
diff --git a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h
index 40971769..c21665b0 100644
--- a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h
+++ b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h
@@ -217,6 +217,7 @@ class SubstraitPlanSubqueryRelationVisitor : public SubstraitPlanTypeVisitor {
   bool isWithinSubquery(SubstraitPlanParser::RelationContext* ctx);
 
   const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
+  bool processingEmit_{false};
 };
 
 } // namespace io::substrait::textplan
diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
index 9c14f828..2cf57933 100644
--- a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
+++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
@@ -502,6 +502,8 @@ std::any SubstraitPlanVisitor::visitFile_detail(
     item->set_length(parseUnsignedInteger(ctx->NUMBER()->getText()));
   } else if (ctx->ORC() != nullptr) {
     item->mutable_orc();
+  } else if (ctx->PARQUET() != nullptr) {
+    item->mutable_parquet();
   } else {
     return visitChildren(ctx);
   }
diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4
index 04371ab0..ccd5dbd0 100644
--- a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4
+++ b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4
@@ -71,6 +71,7 @@ PARTITION_INDEX: 'PARTITION_INDEX';
 START: 'START';
 LENGTH: 'LENGTH';
 ORC: 'ORC';
+PARQUET: 'PARQUET';
 NULLVAL: 'NULL';
 TRUEVAL: 'TRUE';
 FALSEVAL: 'FALSE';
diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4
index 791c5d6d..f76d8afb 100644
--- a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4
+++ b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4
@@ -159,6 +159,7 @@ file_detail
    | START COLON NUMBER
    | LENGTH COLON NUMBER
    | ORC COLON LEFTBRACE RIGHTBRACE
+   | PARQUET COLON LEFTBRACE RIGHTBRACE
    | file_location
    ;
 
diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp
index ddfe8094..461dbd64 100644
--- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp
+++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp
@@ -1183,6 +1183,120 @@ std::vector getTestCases() {
             }
           })"))),
       },
+      {
+          "test20-aggregation-emit",
+          R"(pipelines {
+            read -> aggregate -> root;
+          }
+
+          read relation read {
+            source local;
+            base_schema schema;
+          }
+
+          aggregate relation aggregate {
+            measure {
+              measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename;
+            }
+            measure {
+              measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2;
+            }
+
+            emit measurename2;
+          }
+
+          schema schema {
+            a string;
+            b i32;
+          }
+
+          source local_files local {
+            items = [
+              { uri_file: "x.parquet" parquet: {} }
+            ]
+          }
+
+          extension_space {
+            function sum:i32 as sum;
+          })",
+          AllOf(
+              HasSymbolsWithTypes(
+                  {"read", "aggregate", "root"}, {SymbolType::kRelation}),
+              HasErrors({}),
+              AsBinaryPlan(EqualsProto<::substrait::proto::Plan>(R"(
+                extensions {
+                  extension_function {
+                    name: "sum:i32"
+                  }
+                }
+                relations {
+                  root {
+                    input {
+                      aggregate {
+                        common {
+                          emit {
+                            output_mapping: 1
+                          }
+                        }
+                        input {
+                          read {
+                            common {
+                              direct { }
+                            }
+                            base_schema {
+                              names: 'a' names: 'b'
+                              struct {
+                                types {
+                                  string { nullability: NULLABILITY_REQUIRED }
+                                }
+                                types {
+                                  i32 { nullability: NULLABILITY_REQUIRED }
+                                }
+                                nullability: NULLABILITY_REQUIRED
+                              }
+                            }
+                            local_files {
+                              items { uri_file: 'x.parquet' parquet { } }
+                            }
+                          }
+                        }
+                        groupings {
+                        }
+                        measures {
+                          measure {
+                            output_type {
+                              i32 { nullability: NULLABILITY_REQUIRED }
+                            }
+                            arguments {
+                              value {
+                                selection {
+                                  direct_reference { struct_field { field: 1 } }
+                                  root_reference: { }
+                                }
+                              }
+                            }
+                          }
+                        }
+                        measures {
+                          measure {
+                            output_type {
+                              i32 { nullability: NULLABILITY_REQUIRED }
+                            }
+                            arguments {
+                              value {
+                                selection {
+                                  direct_reference { struct_field { field: 0 } }
+                                  root_reference: { }
+                                }
+                              }
+                            }
+                          }
+                        }
+                      }
+                    }
+                  }
+                })"))),
+      },
   };
   return cases;
 }