Skip to content

Commit

Permalink
feat: implement emit for aggregate relations
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime committed Sep 12, 2024
1 parent 0376088 commit 8a62ae6
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 14 deletions.
75 changes: 74 additions & 1 deletion src/substrait/textplan/PlanPrinterVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ std::string visitEnumArgument(const std::string& str) {
return text.str();
}

bool isAggregate(const SymbolInfo* symbol) {
if (symbol->subtype.type() == typeid(::substrait::proto::Rel::RelTypeCase) &&
ANY_CAST(::substrait::proto::Rel::RelTypeCase, symbol->subtype) ==
::substrait::proto::Rel::kAggregate) {
return true;
}
if (symbol->subtype.type() == typeid(RelationType) &&
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate) {
return true;
}
return false;
}

} // namespace

std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) {
Expand Down Expand Up @@ -177,6 +190,65 @@ std::string PlanPrinterVisitor::lookupFieldReference(
return symbol->name;
}

std::string PlanPrinterVisitor::lookupFieldReferenceForEmit(
uint32_t fieldReference,
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified) {
if (currentScope == nullptr || *currentScope_ == SymbolInfo::kUnknown) {
errorListener_->addError(
"Field number " + std::to_string(fieldReference) +
" mysteriously requested outside of a relation.");
return "field#" + std::to_string(fieldReference);
}
auto actualScope = currentScope;
if (stepsOut > 0) {
for (auto stepsLeft = stepsOut; stepsLeft > 0; stepsLeft--) {
auto actualParentQueryLocation = getParentQueryLocation(actualScope);
if (actualParentQueryLocation == Location::kUnknownLocation) {
errorListener_->addError(
"Requested steps out of " + std::to_string(stepsOut) +
" but not within subquery depth that high.");
return "field#" + std::to_string(fieldReference);
}
actualScope = symbolTable_->lookupSymbolByLocationAndType(
actualParentQueryLocation, SymbolType::kRelation);
if (actualScope == nullptr) {
errorListener_->addError(
"Internal error: Missing previously encountered parent query symbol.");
return "field#" + std::to_string(fieldReference);
}
}
}
auto relationData =
ANY_CAST(std::shared_ptr<RelationData>, actualScope->blob);
const SymbolInfo* symbol{nullptr};
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(currentScope) &&
fieldReference < relationData->generatedFieldReferences.size()) {
symbol = relationData->generatedFieldReferences[fieldReference];
} else if (fieldReference < fieldReferencesSize) {
symbol = relationData->fieldReferences[fieldReference];
} else if (
fieldReference <
fieldReferencesSize + relationData->generatedFieldReferences.size()) {
symbol =
relationData
->generatedFieldReferences[fieldReference - fieldReferencesSize];
} else {
errorListener_->addError(
"Encountered field reference out of range: " +
std::to_string(fieldReference));
return "field#" + std::to_string(fieldReference);
}
if (!symbol->alias.empty()) {
return symbol->alias;
} else if (needFullyQualified && symbol->schema != nullptr) {
return symbol->schema->name + "." + symbol->name;
}
return symbol->name;
}

std::string PlanPrinterVisitor::lookupFunctionReference(
uint32_t function_reference) {
for (const auto& symbol : symbolTable_->getSymbols()) {
Expand Down Expand Up @@ -813,7 +885,7 @@ std::any PlanPrinterVisitor::visitRelationCommon(
}
for (const auto& mapping : common.emit().output_mapping()) {
text << " emit "
<< lookupFieldReference(
<< lookupFieldReferenceForEmit(
mapping, currentScope_, /* stepsOut= */ 0, true)
<< ";\n";
}
Expand Down Expand Up @@ -1028,6 +1100,7 @@ std::any PlanPrinterVisitor::visitAggregateRelation(
}
text << " }\n";
}
text << ANY_CAST(std::string, visitRelationCommon(relation.common()));
return text.str();
}

Expand Down
7 changes: 7 additions & 0 deletions src/substrait/textplan/PlanPrinterVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor {
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

std::string lookupFieldReferenceForEmit(
uint32_t fieldReference,
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

std::string lookupFunctionReference(uint32_t function_reference);

std::any visitSubqueryScalar(
Expand Down
2 changes: 1 addition & 1 deletion src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (auto [shortName, canonicalName] : functionsToOutput) {
for (const auto& [shortName, canonicalName] : functionsToOutput) {
text << " function " << canonicalName << " as " << shortName << ";\n";
}
text << "}\n";
Expand Down
20 changes: 18 additions & 2 deletions src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ void InitialPlanProtoVisitor::updateLocalSchema(
std::nullopt);
relationData->generatedFieldReferences.emplace_back(symbol);
}
// TODO -- If there are multiple groupings add the additional output.
// Aggregate relations are different in that they alter the emitted fields
// by default.
relationData->outputFieldReferences.insert(
Expand Down Expand Up @@ -629,9 +630,24 @@ void InitialPlanProtoVisitor::updateLocalSchema(
// Revamp the output based on the output mapping if present.
auto mapping = getOutputMapping(relation);
if (!mapping.empty()) {
// TODO -- Use a more explicit check.
if (!relationData->outputFieldReferences.empty()) {
errorListener_->addError(
"Aggregate relations do not yet support output mapping changes.");
// We are processing an aggregate node which is the only relation with
// output field references.
auto generatedFieldReferenceSize =
relationData->generatedFieldReferences.size();
relationData->outputFieldReferences.clear(); // Start over.
for (auto item : mapping) {
if (item < generatedFieldReferenceSize) {
relationData->outputFieldReferences.push_back(
relationData->generatedFieldReferences[item]);
} else {
// TODO -- Add support for grouping fields (needs text syntax).
errorListener_->addError(
"Asked to emit a field (" + std::to_string(item) +
" beyond what the aggregate produced.");
}
}
return;
}
for (auto item : mapping) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,60 @@ std::vector<TestCase> getTestCases() {
" hashjoin -> root;\n"
"}\n"))),
},
{
"aggregate with emits",
R"(extensions {
extension_function {
function_anchor: 0
name: "sum:i32"
}
}
relations: { root: { input: {
aggregate: { common { emit { output_mapping: 1 } } input {
read: { base_schema { names: 'a' names: 'b'
struct { types { string {} } types { i32 {} } } }
local_files { items { uri_file: 'x.parquet' parquet { } } }
} }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } } }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 0 } } } } } } }
} } } })",
AllOf(
WhenSerialized(EqSquashingWhitespace(
"pipelines {\n"
" read -> aggregate -> root;\n"
"}\n"
"\n"
"read relation read {\n"
" source local;\n"
" base_schema schema;\n"
"}\n"
"\n"
"aggregate relation aggregate {\n"
" measure {\n"
" measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename;\n"
" }\n"
" measure {\n"
" measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2;\n"
" }\n"
"\n"
" emit measurename2;\n"
"}\n"
"\n"
"schema schema {\n"
" a string;\n"
" b i32;\n"
"}\n"
"\n"
"source local_files local {\n"
" items = [\n"
" {uri_file: \"x.parquet\" parquet: {}}\n"
" ]\n"
"}\n"
"\n"
"extension_space {\n"
" function sum:i32 as sum;\n"
"}\n"))),
},
};
return cases;
}
Expand Down
27 changes: 23 additions & 4 deletions src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ bool isRelationEmitDetail(SubstraitPlanParser::Relation_detailContext* ctx) {
nullptr;
}

bool isAggregate(const SymbolInfo* symbol) {
return symbol->subtype.type() == typeid(RelationType) &&
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate;
}

} // namespace

std::any SubstraitPlanRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -819,7 +824,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit = true;
auto result = visitChildren(ctx);
this->processingEmit = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2023,6 +2030,9 @@ std::pair<int, int> SubstraitPlanRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2075,10 +2085,19 @@ void SubstraitPlanRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor {
bool hasSubquery(SubstraitPlanParser::ExpressionContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit{false};
};

} // namespace io::substrait::textplan
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ const std::string kAggregationInvocationPrefix = "aggregationinvocation";
const std::string kJoinTypePrefix = "jointype";
const std::string kSortDirectionPrefix = "sortdirection";

const std::string kIntermediateNodeName = "intermediate";

enum RelationFilterBehavior {
kDefault = 0,
kBestEffort = 1,
Expand Down Expand Up @@ -374,6 +372,11 @@ comparisonToProto(const std::string& text) {
Expression_Subquery_SetComparison_ComparisonOp_COMPARISON_OP_UNSPECIFIED;
}

bool isAggregate(const SymbolInfo* symbol) {
return symbol->subtype.type() == typeid(RelationType) &&
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate;
}

} // namespace

std::any SubstraitPlanSubqueryRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -871,7 +874,9 @@ std::any SubstraitPlanSubqueryRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit = true;
auto result = visitChildren(ctx);
this->processingEmit = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2163,6 +2168,9 @@ SubstraitPlanSubqueryRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2234,10 +2242,19 @@ void SubstraitPlanSubqueryRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class SubstraitPlanSubqueryRelationVisitor : public SubstraitPlanTypeVisitor {
bool isWithinSubquery(SubstraitPlanParser::RelationContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit{false};
};

} // namespace io::substrait::textplan
2 changes: 2 additions & 0 deletions src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ std::any SubstraitPlanVisitor::visitFile_detail(
item->set_length(parseUnsignedInteger(ctx->NUMBER()->getText()));
} else if (ctx->ORC() != nullptr) {
item->mutable_orc();
} else if (ctx->PARQUET() != nullptr) {
item->mutable_parquet();
} else {
return visitChildren(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ PARTITION_INDEX: 'PARTITION_INDEX';
START: 'START';
LENGTH: 'LENGTH';
ORC: 'ORC';
PARQUET: 'PARQUET';
NULLVAL: 'NULL';
TRUEVAL: 'TRUE';
FALSEVAL: 'FALSE';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ file_detail
| START COLON NUMBER
| LENGTH COLON NUMBER
| ORC COLON LEFTBRACE RIGHTBRACE
| PARQUET COLON LEFTBRACE RIGHTBRACE
| file_location
;

Expand Down
Loading

0 comments on commit 8a62ae6

Please sign in to comment.