Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for emit in aggregate relations #122

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion src/substrait/textplan/PlanPrinterVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ std::string visitEnumArgument(const std::string& str) {
return text.str();
}

bool isAggregate(const SymbolInfo* symbol) {
if (symbol->subtype.type() == typeid(::substrait::proto::Rel::RelTypeCase) &&
ANY_CAST(::substrait::proto::Rel::RelTypeCase, symbol->subtype) ==
::substrait::proto::Rel::kAggregate) {
return true;
}
Comment on lines +75 to +79
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For pattern matching like this, I recommend adding a conditional cast macro:

// Conditional counterpart of ANY_CAST: yields a std::optional<ValueType> that
// holds the cast result when |value| currently stores a ValueType, and an
// empty optional otherwise.
//
// The argument and the whole expansion are parenthesized so the macro can be
// used as an operand, e.g. `ANY_CAST_IF(T, v) == x`. Without the outer
// parentheses the trailing `== x` would bind to the last arm of the
// conditional operator only and the expression would not compile.
//
// Note: |value| is evaluated twice (once for the type test, once for the
// cast), so avoid passing expressions with side effects.
#define ANY_CAST_IF(ValueType, value)       \
  (((value).type() != typeid(ValueType))    \
       ? ::std::optional<ValueType>{}       \
       : ::std::make_optional(ANY_CAST(ValueType, (value))))

Then you can write

Suggested change
if (symbol->subtype.type() == typeid(::substrait::proto::Rel::RelTypeCase) &&
ANY_CAST(::substrait::proto::Rel::RelTypeCase, symbol->subtype) ==
::substrait::proto::Rel::kAggregate) {
return true;
}
if (auto type_case = ANY_CAST_IF(::substrait::proto::Rel::RelTypeCase, symbol->subtype)) {
return type_case == ::substrait::proto::Rel::kAggregate;
}

If you want to be even fancier, I can provide a match() helper:

bool isAggregate(const SymbolInfo* symbol) {
  return match(symbol->subtype,
    [](::substrait::proto::Rel::RelTypeCase type_case) {
      return type_case == ::substrait::proto::Rel::kAggregate;
    },
    [](RelationType rel_type) {
      return rel_type == RelationType::kAggregate;
    },
    [](...) {
      return false;
    });
}

if (symbol->subtype.type() == typeid(RelationType) &&
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate) {
return true;
}
return false;
}

} // namespace

std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) {
Expand Down Expand Up @@ -177,6 +190,65 @@ std::string PlanPrinterVisitor::lookupFieldReference(
return symbol->name;
}

std::string PlanPrinterVisitor::lookupFieldReferenceForEmit(
uint32_t fieldReference,
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified) {
if (currentScope == nullptr || *currentScope_ == SymbolInfo::kUnknown) {
errorListener_->addError(
"Field number " + std::to_string(fieldReference) +
" mysteriously requested outside of a relation.");
return "field#" + std::to_string(fieldReference);
}
auto actualScope = currentScope;
if (stepsOut > 0) {
for (auto stepsLeft = stepsOut; stepsLeft > 0; stepsLeft--) {
auto actualParentQueryLocation = getParentQueryLocation(actualScope);
if (actualParentQueryLocation == Location::kUnknownLocation) {
errorListener_->addError(
"Requested steps out of " + std::to_string(stepsOut) +
" but not within subquery depth that high.");
return "field#" + std::to_string(fieldReference);
}
actualScope = symbolTable_->lookupSymbolByLocationAndType(
actualParentQueryLocation, SymbolType::kRelation);
if (actualScope == nullptr) {
errorListener_->addError(
"Internal error: Missing previously encountered parent query symbol.");
return "field#" + std::to_string(fieldReference);
}
}
}
auto relationData =
ANY_CAST(std::shared_ptr<RelationData>, actualScope->blob);
const SymbolInfo* symbol{nullptr};
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(currentScope) &&
fieldReference < relationData->generatedFieldReferences.size()) {
symbol = relationData->generatedFieldReferences[fieldReference];
} else if (fieldReference < fieldReferencesSize) {
symbol = relationData->fieldReferences[fieldReference];
} else if (
fieldReference <
fieldReferencesSize + relationData->generatedFieldReferences.size()) {
symbol =
relationData
->generatedFieldReferences[fieldReference - fieldReferencesSize];
} else {
errorListener_->addError(
"Encountered field reference out of range: " +
std::to_string(fieldReference));
return "field#" + std::to_string(fieldReference);
}
Comment on lines +227 to +243
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems odd to me because I'd expect a clean outer branch on the type of currentScope before we look at the value of fieldReference. If currentScope is an aggregate but fieldReference is out of range, is it really reasonable to use the same code we would have used if currentScope were not an aggregate? I'd expect something like:

Suggested change
if (isAggregate(currentScope) &&
fieldReference < relationData->generatedFieldReferences.size()) {
symbol = relationData->generatedFieldReferences[fieldReference];
} else if (fieldReference < fieldReferencesSize) {
symbol = relationData->fieldReferences[fieldReference];
} else if (
fieldReference <
fieldReferencesSize + relationData->generatedFieldReferences.size()) {
symbol =
relationData
->generatedFieldReferences[fieldReference - fieldReferencesSize];
} else {
errorListener_->addError(
"Encountered field reference out of range: " +
std::to_string(fieldReference));
return "field#" + std::to_string(fieldReference);
}
if (isAggregate(currentScope)) {
if (fieldReference < relationData->generatedFieldReferences.size()) {
symbol = relationData->generatedFieldReferences[fieldReference];
}
} else {
auto size = relationData->fieldReferences.size();
if (fieldReference < size) {
symbol = relationData->fieldReferences[fieldReference];
} else if (
fieldReference <
size + relationData->generatedFieldReferences.size()) {
symbol =
relationData
->generatedFieldReferences[fieldReference - size];
}
}
if (symbol == nullptr) {
errorListener_->addError(
"Encountered field reference out of range: " +
std::to_string(fieldReference));
return "field#" + std::to_string(fieldReference);
}

In any case, please clarify the branching here.

if (!symbol->alias.empty()) {
return symbol->alias;
} else if (needFullyQualified && symbol->schema != nullptr) {
return symbol->schema->name + "." + symbol->name;
}
return symbol->name;
}

std::string PlanPrinterVisitor::lookupFunctionReference(
uint32_t function_reference) {
for (const auto& symbol : symbolTable_->getSymbols()) {
Expand Down Expand Up @@ -813,7 +885,7 @@ std::any PlanPrinterVisitor::visitRelationCommon(
}
for (const auto& mapping : common.emit().output_mapping()) {
text << " emit "
<< lookupFieldReference(
<< lookupFieldReferenceForEmit(
mapping, currentScope_, /* stepsOut= */ 0, true)
<< ";\n";
}
Expand Down Expand Up @@ -1028,6 +1100,7 @@ std::any PlanPrinterVisitor::visitAggregateRelation(
}
text << " }\n";
}
text << ANY_CAST(std::string, visitRelationCommon(relation.common()));
return text.str();
}

Expand Down
7 changes: 7 additions & 0 deletions src/substrait/textplan/PlanPrinterVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor {
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

// Resolves the field index |fieldReference| used in an emit clause to the
// name that should be printed for it.
//
// |currentScope| is the relation symbol the emit belongs to; |stepsOut| is
// the number of parent-query levels to walk outward from it before looking
// the field up. When |currentScope| is an aggregate relation the index is
// checked against the relation's generated field references first; otherwise
// it is resolved against the regular field references followed by the
// generated ones.
//
// Returns the symbol's alias if it has one, else "<schema>.<name>" when
// |needFullyQualified| is set and the symbol has a schema, else the plain
// name. If the scope is missing, the requested depth cannot be reached, or
// the index is out of range, an error is recorded on the error listener and
// the placeholder "field#<fieldReference>" is returned.
std::string lookupFieldReferenceForEmit(
uint32_t fieldReference,
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

std::string lookupFunctionReference(uint32_t function_reference);

std::any visitSubqueryScalar(
Expand Down
2 changes: 1 addition & 1 deletion src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (auto [shortName, canonicalName] : functionsToOutput) {
for (const auto& [shortName, canonicalName] : functionsToOutput) {
text << " function " << canonicalName << " as " << shortName << ";\n";
}
text << "}\n";
Expand Down
20 changes: 18 additions & 2 deletions src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ void InitialPlanProtoVisitor::updateLocalSchema(
std::nullopt);
relationData->generatedFieldReferences.emplace_back(symbol);
}
// TODO -- If there are multiple groupings add the additional output.
// Aggregate relations are different in that they alter the emitted fields
// by default.
relationData->outputFieldReferences.insert(
Expand Down Expand Up @@ -629,9 +630,24 @@ void InitialPlanProtoVisitor::updateLocalSchema(
// Revamp the output based on the output mapping if present.
auto mapping = getOutputMapping(relation);
if (!mapping.empty()) {
// TODO -- Use a more explicit check.
if (!relationData->outputFieldReferences.empty()) {
errorListener_->addError(
"Aggregate relations do not yet support output mapping changes.");
// We are processing an aggregate node which is the only relation with
// output field references.
auto generatedFieldReferenceSize =
relationData->generatedFieldReferences.size();
relationData->outputFieldReferences.clear(); // Start over.
for (auto item : mapping) {
if (item < generatedFieldReferenceSize) {
relationData->outputFieldReferences.push_back(
relationData->generatedFieldReferences[item]);
} else {
// TODO -- Add support for grouping fields (needs text syntax).
errorListener_->addError(
"Asked to emit a field (" + std::to_string(item) +
" beyond what the aggregate produced.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
" beyond what the aggregate produced.");
") beyond what the aggregate produced.");

}
}
return;
}
for (auto item : mapping) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,59 @@ std::vector<TestCase> getTestCases() {
" hashjoin -> root;\n"
"}\n"))),
},
{
"aggregate with emits",
R"(extensions {
extension_function {
function_anchor: 0
name: "sum:i32"
}
}
relations: { root: { input: {
aggregate: { common { emit { output_mapping: 1 } } input {
read: { base_schema { names: 'a' names: 'b'
struct { types { string {} } types { i32 {} } } }
local_files { items { uri_file: 'x.parquet' parquet { } } }
} }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } } }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 0 } } } } } } }
} } } })",
AllOf(WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
read -> aggregate -> root;
}

read relation read {
source local;
base_schema schema;
}

aggregate relation aggregate {
measure {
measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename;
}
measure {
measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2;
}

emit measurename2;
}

schema schema {
a string;
b i32;
}

source local_files local {
items = [
{uri_file: "x.parquet" parquet: {}}
]
}

extension_space {
function sum:i32 as sum;
)"))),
},
};
return cases;
}
Expand Down
27 changes: 23 additions & 4 deletions src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ bool isRelationEmitDetail(SubstraitPlanParser::Relation_detailContext* ctx) {
nullptr;
}

// Reports whether |symbol| describes an aggregate relation, i.e. its subtype
// holds a RelationType whose value is RelationType::kAggregate. A subtype
// holding any other kind of payload yields false.
bool isAggregate(const SymbolInfo* symbol) {
  const auto& subtype = symbol->subtype;
  if (subtype.type() != typeid(RelationType)) {
    return false;
  }
  return ANY_CAST(RelationType, subtype) == RelationType::kAggregate;
}

} // namespace

std::any SubstraitPlanRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -819,7 +824,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit_ = true;
auto result = visitChildren(ctx);
this->processingEmit_ = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2023,6 +2030,9 @@ std::pair<int, int> SubstraitPlanRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit_) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2075,10 +2085,19 @@ void SubstraitPlanRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor {
bool hasSubquery(SubstraitPlanParser::ExpressionContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit_{false};
};

} // namespace io::substrait::textplan
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ const std::string kAggregationInvocationPrefix = "aggregationinvocation";
const std::string kJoinTypePrefix = "jointype";
const std::string kSortDirectionPrefix = "sortdirection";

const std::string kIntermediateNodeName = "intermediate";

enum RelationFilterBehavior {
kDefault = 0,
kBestEffort = 1,
Expand Down Expand Up @@ -374,6 +372,11 @@ comparisonToProto(const std::string& text) {
Expression_Subquery_SetComparison_ComparisonOp_COMPARISON_OP_UNSPECIFIED;
}

bool isAggregate(const SymbolInfo* symbol) {
return symbol->subtype.type() == typeid(RelationType) &&
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate;
Comment on lines +376 to +377
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return symbol->subtype.type() == typeid(RelationType) &&
ANY_CAST(RelationType, symbol->subtype) == RelationType::kAggregate;
return ANY_CAST_IF(RelationType, symbol->subtype) == RelationType::kAggregate;

}

} // namespace

std::any SubstraitPlanSubqueryRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -871,7 +874,9 @@ std::any SubstraitPlanSubqueryRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit_ = true;
auto result = visitChildren(ctx);
this->processingEmit_ = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2163,6 +2168,9 @@ SubstraitPlanSubqueryRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit_) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2234,10 +2242,19 @@ void SubstraitPlanSubqueryRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class SubstraitPlanSubqueryRelationVisitor : public SubstraitPlanTypeVisitor {
bool isWithinSubquery(SubstraitPlanParser::RelationContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit_{false};
};

} // namespace io::substrait::textplan
2 changes: 2 additions & 0 deletions src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ std::any SubstraitPlanVisitor::visitFile_detail(
item->set_length(parseUnsignedInteger(ctx->NUMBER()->getText()));
} else if (ctx->ORC() != nullptr) {
item->mutable_orc();
} else if (ctx->PARQUET() != nullptr) {
item->mutable_parquet();
} else {
return visitChildren(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ PARTITION_INDEX: 'PARTITION_INDEX';
START: 'START';
LENGTH: 'LENGTH';
ORC: 'ORC';
PARQUET: 'PARQUET';
NULLVAL: 'NULL';
TRUEVAL: 'TRUE';
FALSEVAL: 'FALSE';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ file_detail
| START COLON NUMBER
| LENGTH COLON NUMBER
| ORC COLON LEFTBRACE RIGHTBRACE
| PARQUET COLON LEFTBRACE RIGHTBRACE
| file_location
;

Expand Down
Loading
Loading