diff --git a/src/substrait/textplan/Any.h b/src/substrait/textplan/Any.h index 792e178d..4ce97c94 100644 --- a/src/substrait/textplan/Any.h +++ b/src/substrait/textplan/Any.h @@ -23,4 +23,10 @@ any_cast(const std::any& value, const char* file, int line) { // NOLINT #define ANY_CAST(ValueType, Value) \ ::io::substrait::textplan::any_cast(Value, __FILE__, __LINE__) +// Casts the any if it matches the given type otherwise it returns nullopt. +#define ANY_CAST_IF(ValueType, value) \ + value.type() != typeid(ValueType) \ + ? ::std::nullopt \ + : ::std::make_optional(ANY_CAST(ValueType, value)) + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/PlanPrinterVisitor.cpp b/src/substrait/textplan/PlanPrinterVisitor.cpp index 91585517..f2b12206 100644 --- a/src/substrait/textplan/PlanPrinterVisitor.cpp +++ b/src/substrait/textplan/PlanPrinterVisitor.cpp @@ -71,6 +71,18 @@ std::string visitEnumArgument(const std::string& str) { return text.str(); } +bool isAggregate(const SymbolInfo* symbol) { + // TODO: Remove after the relation type is one type internally. + if (const auto typeCase = + ANY_CAST_IF(::substrait::proto::Rel::RelTypeCase, symbol->subtype)) { + return (typeCase == ::substrait::proto::Rel::kAggregate); + } + if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) { + return (typeCase == RelationType::kAggregate); + } + return false; +} + } // namespace std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) { @@ -177,6 +189,71 @@ std::string PlanPrinterVisitor::lookupFieldReference( return symbol->name; } +std::string PlanPrinterVisitor::lookupFieldReferenceForEmit( + uint32_t fieldReference, + const SymbolInfo* currentScope, + uint32_t stepsOut, + bool needFullyQualified) { + if (currentScope == nullptr || *currentScope_ == SymbolInfo::kUnknown) { + errorListener_->addError( + "Field number " + std::to_string(fieldReference) + + " mysteriously requested outside of a relation."); + return "field#" + std::to_string(fieldReference); + } + auto actualScope = currentScope; + if (stepsOut > 0) { + for (auto stepsLeft = stepsOut; stepsLeft > 0; stepsLeft--) { + auto actualParentQueryLocation = getParentQueryLocation(actualScope); + if (actualParentQueryLocation == Location::kUnknownLocation) { + errorListener_->addError( + "Requested field#" + std::to_string(fieldReference) + " at " + + std::to_string(stepsOut) + + " steps out but subquery depth is only " + + std::to_string(stepsLeft)); + return "field#" + std::to_string(fieldReference); + } + actualScope = symbolTable_->lookupSymbolByLocationAndType( + actualParentQueryLocation, SymbolType::kRelation); + if (actualScope == nullptr) { + errorListener_->addError( + "Internal error: Missing previously encountered parent query symbol."); + return "field#" + std::to_string(fieldReference); + } + } + } + auto relationData = + ANY_CAST(std::shared_ptr, actualScope->blob); + const SymbolInfo* symbol{nullptr}; + const char* relationType = "non-aggregate"; + if (isAggregate(currentScope)) { + relationType = "aggregate"; + if (fieldReference < relationData->generatedFieldReferences.size()) { + symbol = relationData->generatedFieldReferences[fieldReference]; + } + } else { + auto size = relationData->fieldReferences.size(); + if (fieldReference < size) { + symbol = relationData->fieldReferences[fieldReference]; + } else if ( + fieldReference < size + relationData->generatedFieldReferences.size()) { + symbol = relationData->generatedFieldReferences[fieldReference - size]; + } + } + if (symbol == nullptr) { + errorListener_->addError( + "Encountered field reference out of range in " + + std::string(relationType) + + " relation: " + std::to_string(fieldReference)); + return "field#" + std::to_string(fieldReference); + } + if (!symbol->alias.empty()) { + return symbol->alias; + } else if (needFullyQualified && symbol->schema != nullptr) { + return symbol->schema->name + "." + symbol->name; + } + return symbol->name; +} + std::string PlanPrinterVisitor::lookupFunctionReference( uint32_t function_reference) { for (const auto& symbol : symbolTable_->getSymbols()) { @@ -813,7 +890,7 @@ std::any PlanPrinterVisitor::visitRelationCommon( } for (const auto& mapping : common.emit().output_mapping()) { text << " emit " - << lookupFieldReference( + << lookupFieldReferenceForEmit( mapping, currentScope_, /* stepsOut= */ 0, true) << ";\n"; } @@ -1028,6 +1105,7 @@ std::any PlanPrinterVisitor::visitAggregateRelation( } text << " }\n"; } + text << ANY_CAST(std::string, visitRelationCommon(relation.common())); return text.str(); } diff --git a/src/substrait/textplan/PlanPrinterVisitor.h b/src/substrait/textplan/PlanPrinterVisitor.h index 2b598b5b..f408bdac 100644 --- a/src/substrait/textplan/PlanPrinterVisitor.h +++ b/src/substrait/textplan/PlanPrinterVisitor.h @@ -39,6 +39,13 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor { const SymbolInfo* currentScope, uint32_t stepsOut, bool needFullyQualified); + + std::string lookupFieldReferenceForEmit( + uint32_t fieldReference, + const SymbolInfo* currentScope, + uint32_t stepsOut, + bool needFullyQualified); + std::string lookupFunctionReference(uint32_t function_reference); std::any visitSubqueryScalar( diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index cdd3e8c9..13bf5341 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -394,7 +394,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { functionsToOutput.emplace_back(info.name, functionData->name); } std::sort(functionsToOutput.begin(), functionsToOutput.end()); - for (auto [shortName, canonicalName] : functionsToOutput) { + for (const auto& [shortName, canonicalName] : functionsToOutput) { text << " function " << canonicalName << " as " << shortName << ";\n"; } text << "}\n"; diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index d044acc6..16cb80f0 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -499,6 +499,7 @@ void InitialPlanProtoVisitor::updateLocalSchema( std::nullopt); relationData->generatedFieldReferences.emplace_back(symbol); } + // TODO -- If there are multiple groupings add the additional output. // Aggregate relations are different in that they alter the emitted fields // by default. relationData->outputFieldReferences.insert( @@ -629,9 +630,21 @@ void InitialPlanProtoVisitor::updateLocalSchema( // Revamp the output based on the output mapping if present. auto mapping = getOutputMapping(relation); if (!mapping.empty()) { - if (!relationData->outputFieldReferences.empty()) { - errorListener_->addError( - "Aggregate relations do not yet support output mapping changes."); + if (relation.rel_type_case() == ::substrait::proto::Rel::kAggregate) { + auto generatedFieldReferenceSize = + relationData->generatedFieldReferences.size(); + relationData->outputFieldReferences.clear(); // Start over. + for (auto item : mapping) { + if (item < generatedFieldReferenceSize) { + relationData->outputFieldReferences.push_back( + relationData->generatedFieldReferences[item]); + } else { + // TODO -- Add support for grouping fields (needs text syntax). + errorListener_->addError( + "Asked to emit a field (" + std::to_string(item) + + ") beyond what the aggregate produced."); + } + } return; } for (auto item : mapping) { diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index 3b0ae97f..935143e8 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -590,6 +590,59 @@ std::vector getTestCases() { " hashjoin -> root;\n" "}\n"))), }, + { + "aggregate with emits", + R"(extensions { + extension_function { + function_anchor: 0 + name: "sum:i32" + } + } + relations: { root: { input: { + aggregate: { common { emit { output_mapping: 1 } } input { + read: { base_schema { names: 'a' names: 'b' + struct { types { string {} } types { i32 {} } } } + local_files { items { uri_file: 'x.parquet' parquet { } } } + } } + measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } } } + measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 0 } } } } } } } + } } } })", + AllOf(WhenSerialized(EqSquashingWhitespace( + R"(pipelines { + read -> aggregate -> root; + } + + read relation read { + source local; + base_schema schema; + } + + aggregate relation aggregate { + measure { + measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename; + } + measure { + measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2; + } + + emit measurename2; + } + + schema schema { + a string; + b i32; + } + + source local_files local { + items = [ + {uri_file: "x.parquet" parquet: {}} + ] + } + + extension_space { + function sum:i32 as sum; + )"))), + }, }; return cases; } diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index 51525404..8b3dc157 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -316,6 +316,14 @@ bool isRelationEmitDetail(SubstraitPlanParser::Relation_detailContext* ctx) { nullptr; } +bool isAggregate(const SymbolInfo* symbol) { + // TODO: Remove once relation types have a unified type internally. + if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) { + return (typeCase == RelationType::kAggregate); + } + return false; +} + } // namespace std::any SubstraitPlanRelationVisitor::aggregateResult( @@ -819,7 +827,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationEmit( SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); + this->processingEmit_ = true; auto result = visitChildren(ctx); + this->processingEmit_ = false; auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); auto common = findCommonRelation(parentRelationType, &parentRelationData->relation); @@ -2023,6 +2033,9 @@ std::pair SubstraitPlanRelationVisitor::findFieldReferenceByName( std::shared_ptr& relationData, const std::string& name) { auto fieldReferencesSize = relationData->fieldReferences.size(); + if (isAggregate(symbol) && this->processingEmit_) { + fieldReferencesSize = 0; + } auto generatedField = std::find_if( relationData->generatedFieldReferences.rbegin(), @@ -2075,10 +2088,19 @@ void SubstraitPlanRelationVisitor::applyOutputMappingToSchema( if (common->emit().output_mapping_size() == 0) { common->mutable_direct(); } else { - if (!relationData->outputFieldReferences.empty()) { - // TODO -- Add support for aggregate relations. - errorListener_->addError( - token, "Aggregate relations do not yet support emit sections."); + if (relationData->relation.has_aggregate()) { + auto oldReferences = relationData->outputFieldReferences; + relationData->outputFieldReferences.clear(); + for (auto mapping : common->emit().output_mapping()) { + if (mapping < oldReferences.size()) { + relationData->outputFieldReferences.push_back(oldReferences[mapping]); + } else { + errorListener_->addError( + token, + "Field #" + std::to_string(mapping) + " requested but only " + + std::to_string(oldReferences.size()) + " are available."); + } + } return; } for (auto mapping : common->emit().output_mapping()) { diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h index b7ff411d..04308668 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h @@ -205,6 +205,7 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { bool hasSubquery(SubstraitPlanParser::ExpressionContext* ctx); const SymbolInfo* currentRelationScope_{nullptr}; // Not owned. + bool processingEmit_{false}; }; } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp index 29e91976..c9a5f8c2 100644 --- a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.cpp @@ -33,8 +33,6 @@ const std::string kAggregationInvocationPrefix = "aggregationinvocation"; const std::string kJoinTypePrefix = "jointype"; const std::string kSortDirectionPrefix = "sortdirection"; -const std::string kIntermediateNodeName = "intermediate"; - enum RelationFilterBehavior { kDefault = 0, kBestEffort = 1, @@ -374,6 +372,14 @@ comparisonToProto(const std::string& text) { Expression_Subquery_SetComparison_ComparisonOp_COMPARISON_OP_UNSPECIFIED; } +bool isAggregate(const SymbolInfo* symbol) { + // TODO: Remove after the relation type is one type internally. + if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) { + return (typeCase == RelationType::kAggregate); + } + return false; +} + } // namespace std::any SubstraitPlanSubqueryRelationVisitor::aggregateResult( @@ -871,7 +877,9 @@ std::any SubstraitPlanSubqueryRelationVisitor::visitRelationEmit( SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); + this->processingEmit_ = true; auto result = visitChildren(ctx); + this->processingEmit_ = false; auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); auto common = findCommonRelation(parentRelationType, &parentRelationData->relation); @@ -2163,6 +2171,9 @@ SubstraitPlanSubqueryRelationVisitor::findFieldReferenceByName( std::shared_ptr& relationData, const std::string& name) { auto fieldReferencesSize = relationData->fieldReferences.size(); + if (isAggregate(symbol) && this->processingEmit_) { + fieldReferencesSize = 0; + } auto generatedField = std::find_if( relationData->generatedFieldReferences.rbegin(), @@ -2234,10 +2245,19 @@ void SubstraitPlanSubqueryRelationVisitor::applyOutputMappingToSchema( if (common->emit().output_mapping_size() == 0) { common->mutable_direct(); } else { - if (!relationData->outputFieldReferences.empty()) { - // TODO -- Add support for aggregate relations. - errorListener_->addError( - token, "Aggregate relations do not yet support emit sections."); + if (relationData->relation.has_aggregate()) { + auto oldReferences = relationData->outputFieldReferences; + relationData->outputFieldReferences.clear(); + for (auto mapping : common->emit().output_mapping()) { + if (mapping < oldReferences.size()) { + relationData->outputFieldReferences.push_back(oldReferences[mapping]); + } else { + errorListener_->addError( + token, + "Field #" + std::to_string(mapping) + " requested but only " + + std::to_string(oldReferences.size()) + " are available."); + } + } return; } for (auto mapping : common->emit().output_mapping()) { diff --git a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h index 40971769..825f2a0e 100644 --- a/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanSubqueryRelationVisitor.h @@ -217,6 +217,7 @@ class SubstraitPlanSubqueryRelationVisitor : public SubstraitPlanTypeVisitor { bool isWithinSubquery(SubstraitPlanParser::RelationContext* ctx); const SymbolInfo* currentRelationScope_{nullptr}; // Not owned. + bool processingEmit_{false}; }; } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp index 9c14f828..2cf57933 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp @@ -502,6 +502,8 @@ std::any SubstraitPlanVisitor::visitFile_detail( item->set_length(parseUnsignedInteger(ctx->NUMBER()->getText())); } else if (ctx->ORC() != nullptr) { item->mutable_orc(); + } else if (ctx->PARQUET() != nullptr) { + item->mutable_parquet(); } else { return visitChildren(ctx); } diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 index 04371ab0..ccd5dbd0 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 @@ -71,6 +71,7 @@ PARTITION_INDEX: 'PARTITION_INDEX'; START: 'START'; LENGTH: 'LENGTH'; ORC: 'ORC'; +PARQUET: 'PARQUET'; NULLVAL: 'NULL'; TRUEVAL: 'TRUE'; FALSEVAL: 'FALSE'; diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 index 791c5d6d..f76d8afb 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 @@ -159,6 +159,7 @@ file_detail | START COLON NUMBER | LENGTH COLON NUMBER | ORC COLON LEFTBRACE RIGHTBRACE + | PARQUET COLON LEFTBRACE RIGHTBRACE | file_location ; diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp index ddfe8094..461dbd64 100644 --- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp +++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp @@ -1183,6 +1183,120 @@ std::vector getTestCases() { } })"))), }, + { + "test20-aggregation-emit", + R"(pipelines { + read -> aggregate -> root; + } + + read relation read { + source local; + base_schema schema; + } + + aggregate relation aggregate { + measure { + measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename; + } + measure { + measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2; + } + + emit measurename2; + } + + schema schema { + a string; + b i32; + } + + source local_files local { + items = [ + { uri_file: "x.parquet" parquet: {} } + ] + } + + extension_space { + function sum:i32 as sum; + })", + AllOf( + HasSymbolsWithTypes( + {"read", "aggregate", "root"}, {SymbolType::kRelation}), + HasErrors({}), + AsBinaryPlan(EqualsProto<::substrait::proto::Plan>(R"( + extensions { + extension_function { + name: "sum:i32" + } + } + relations { + root { + input { + aggregate { + common { + emit { + output_mapping: 1 + } + } + input { + read { + common { + direct { } + } + base_schema { + names: 'a' names: 'b' + struct { + types { + string { nullability: NULLABILITY_REQUIRED } + } + types { + i32 { nullability: NULLABILITY_REQUIRED } + } + nullability: NULLABILITY_REQUIRED + } + } + local_files { + items { uri_file: 'x.parquet' parquet { } } + } + } + } + groupings { + } + measures { + measure { + output_type { + i32 { nullability: NULLABILITY_REQUIRED } + } + arguments { + value { + selection { + direct_reference { struct_field { field: 1 } } + root_reference: { } + } + } + } + } + } + measures { + measure { + output_type { + i32 { nullability: NULLABILITY_REQUIRED } + } + arguments { + value { + selection { + direct_reference { struct_field { field: 0 } } + root_reference: { } + } + } + } + } + } + } + } + } + })"))), + }, }; return cases; }