-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add support for emit in aggregate relations #122
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -71,6 +71,19 @@ std::string visitEnumArgument(const std::string& str) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return text.str(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bool isAggregate(const SymbolInfo* symbol) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (symbol->subtype.type() == typeid(::substrait::proto::Rel::RelTypeCase) && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ANY_CAST(::substrait::proto::Rel::RelTypeCase, symbol->subtype) == | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
::substrait::proto::Rel::kAggregate) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (symbol->subtype.type() == typeid(RelationType) && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} // namespace | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -177,6 +190,65 @@ std::string PlanPrinterVisitor::lookupFieldReference( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return symbol->name; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::string PlanPrinterVisitor::lookupFieldReferenceForEmit( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uint32_t fieldReference, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const SymbolInfo* currentScope, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uint32_t stepsOut, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bool needFullyQualified) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (currentScope == nullptr || *currentScope_ == SymbolInfo::kUnknown) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
errorListener_->addError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Field number " + std::to_string(fieldReference) + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
" mysteriously requested outside of a relation."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "field#" + std::to_string(fieldReference); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto actualScope = currentScope; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (stepsOut > 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (auto stepsLeft = stepsOut; stepsLeft > 0; stepsLeft--) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto actualParentQueryLocation = getParentQueryLocation(actualScope); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (actualParentQueryLocation == Location::kUnknownLocation) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
errorListener_->addError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Requested steps out of " + std::to_string(stepsOut) + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
" but not within subquery depth that high."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "field#" + std::to_string(fieldReference); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
actualScope = symbolTable_->lookupSymbolByLocationAndType( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
actualParentQueryLocation, SymbolType::kRelation); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (actualScope == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
errorListener_->addError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Internal error: Missing previously encountered parent query symbol."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "field#" + std::to_string(fieldReference); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto relationData = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ANY_CAST(std::shared_ptr<RelationData>, actualScope->blob); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const SymbolInfo* symbol{nullptr}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto fieldReferencesSize = relationData->fieldReferences.size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (isAggregate(currentScope) && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
fieldReference < relationData->generatedFieldReferences.size()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
symbol = relationData->generatedFieldReferences[fieldReference]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (fieldReference < fieldReferencesSize) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
symbol = relationData->fieldReferences[fieldReference]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
fieldReference < | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
fieldReferencesSize + relationData->generatedFieldReferences.size()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
symbol = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
relationData | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
->generatedFieldReferences[fieldReference - fieldReferencesSize]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
errorListener_->addError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"Encountered field reference out of range: " + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::to_string(fieldReference)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "field#" + std::to_string(fieldReference); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+227
to
+243
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems odd to me because I'd expect a clean outer branch on the type of
Suggested change
in any case, please clarify the branching here |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!symbol->alias.empty()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return symbol->alias; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if (needFullyQualified && symbol->schema != nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return symbol->schema->name + "." + symbol->name; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return symbol->name; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::string PlanPrinterVisitor::lookupFunctionReference( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uint32_t function_reference) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (const auto& symbol : symbolTable_->getSymbols()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -813,7 +885,7 @@ std::any PlanPrinterVisitor::visitRelationCommon( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (const auto& mapping : common.emit().output_mapping()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
text << " emit " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
<< lookupFieldReference( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
<< lookupFieldReferenceForEmit( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mapping, currentScope_, /* stepsOut= */ 0, true) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
<< ";\n"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1028,6 +1100,7 @@ std::any PlanPrinterVisitor::visitAggregateRelation( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
text << " }\n"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
text << ANY_CAST(std::string, visitRelationCommon(relation.common())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return text.str(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -499,6 +499,7 @@ void InitialPlanProtoVisitor::updateLocalSchema( | |||||
std::nullopt); | ||||||
relationData->generatedFieldReferences.emplace_back(symbol); | ||||||
} | ||||||
// TODO -- If there are multiple groupings add the additional output. | ||||||
// Aggregate relations are different in that they alter the emitted fields | ||||||
// by default. | ||||||
relationData->outputFieldReferences.insert( | ||||||
|
@@ -629,9 +630,24 @@ void InitialPlanProtoVisitor::updateLocalSchema( | |||||
// Revamp the output based on the output mapping if present. | ||||||
auto mapping = getOutputMapping(relation); | ||||||
if (!mapping.empty()) { | ||||||
// TODO -- Use a more explicit check. | ||||||
if (!relationData->outputFieldReferences.empty()) { | ||||||
errorListener_->addError( | ||||||
"Aggregate relations do not yet support output mapping changes."); | ||||||
// We are processing an aggregate node which is the only relation with | ||||||
// output field references. | ||||||
auto generatedFieldReferenceSize = | ||||||
relationData->generatedFieldReferences.size(); | ||||||
relationData->outputFieldReferences.clear(); // Start over. | ||||||
for (auto item : mapping) { | ||||||
if (item < generatedFieldReferenceSize) { | ||||||
relationData->outputFieldReferences.push_back( | ||||||
relationData->generatedFieldReferences[item]); | ||||||
} else { | ||||||
// TODO -- Add support for grouping fields (needs text syntax). | ||||||
errorListener_->addError( | ||||||
"Asked to emit a field (" + std::to_string(item) + | ||||||
" beyond what the aggregate produced."); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
} | ||||||
} | ||||||
return; | ||||||
} | ||||||
for (auto item : mapping) { | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -33,8 +33,6 @@ const std::string kAggregationInvocationPrefix = "aggregationinvocation"; | |||||||
const std::string kJoinTypePrefix = "jointype"; | ||||||||
const std::string kSortDirectionPrefix = "sortdirection"; | ||||||||
|
||||||||
const std::string kIntermediateNodeName = "intermediate"; | ||||||||
|
||||||||
enum RelationFilterBehavior { | ||||||||
kDefault = 0, | ||||||||
kBestEffort = 1, | ||||||||
|
@@ -374,6 +372,11 @@ comparisonToProto(const std::string& text) { | |||||||
Expression_Subquery_SetComparison_ComparisonOp_COMPARISON_OP_UNSPECIFIED; | ||||||||
} | ||||||||
|
||||||||
bool isAggregate(const SymbolInfo* symbol) { | ||||||||
return symbol->subtype.type() == typeid(RelationType) && | ||||||||
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate; | ||||||||
Comment on lines
+376
to
+377
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
} | ||||||||
|
||||||||
} // namespace | ||||||||
|
||||||||
std::any SubstraitPlanSubqueryRelationVisitor::aggregateResult( | ||||||||
|
@@ -871,7 +874,9 @@ std::any SubstraitPlanSubqueryRelationVisitor::visitRelationEmit( | |||||||
SymbolType::kRelation); | ||||||||
auto parentRelationData = | ||||||||
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob); | ||||||||
this->processingEmit_ = true; | ||||||||
auto result = visitChildren(ctx); | ||||||||
this->processingEmit_ = false; | ||||||||
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); | ||||||||
auto common = | ||||||||
findCommonRelation(parentRelationType, &parentRelationData->relation); | ||||||||
|
@@ -2163,6 +2168,9 @@ SubstraitPlanSubqueryRelationVisitor::findFieldReferenceByName( | |||||||
std::shared_ptr<RelationData>& relationData, | ||||||||
const std::string& name) { | ||||||||
auto fieldReferencesSize = relationData->fieldReferences.size(); | ||||||||
if (isAggregate(symbol) && this->processingEmit_) { | ||||||||
fieldReferencesSize = 0; | ||||||||
} | ||||||||
|
||||||||
auto generatedField = std::find_if( | ||||||||
relationData->generatedFieldReferences.rbegin(), | ||||||||
|
@@ -2234,10 +2242,19 @@ void SubstraitPlanSubqueryRelationVisitor::applyOutputMappingToSchema( | |||||||
if (common->emit().output_mapping_size() == 0) { | ||||||||
common->mutable_direct(); | ||||||||
} else { | ||||||||
if (!relationData->outputFieldReferences.empty()) { | ||||||||
// TODO -- Add support for aggregate relations. | ||||||||
errorListener_->addError( | ||||||||
token, "Aggregate relations do not yet support emit sections."); | ||||||||
if (relationData->relation.has_aggregate()) { | ||||||||
auto oldReferences = relationData->outputFieldReferences; | ||||||||
relationData->outputFieldReferences.clear(); | ||||||||
for (auto mapping : common->emit().output_mapping()) { | ||||||||
if (mapping < oldReferences.size()) { | ||||||||
relationData->outputFieldReferences.push_back(oldReferences[mapping]); | ||||||||
} else { | ||||||||
errorListener_->addError( | ||||||||
token, | ||||||||
"Field #" + std::to_string(mapping) + " requested but only " + | ||||||||
std::to_string(oldReferences.size()) + " are available."); | ||||||||
} | ||||||||
} | ||||||||
return; | ||||||||
} | ||||||||
for (auto mapping : common->emit().output_mapping()) { | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For pattern matching like this, I recommend adding a conditional cast macro:
Then you can write
If you want to be even fancier, I can provide a
match()
helper: