Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for emit in aggregate relations #122

Merged
merged 16 commits into from
Oct 1, 2024
Merged
6 changes: 6 additions & 0 deletions src/substrait/textplan/Any.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,10 @@ any_cast(const std::any& value, const char* file, int line) { // NOLINT
#define ANY_CAST(ValueType, Value) \
::io::substrait::textplan::any_cast<ValueType>(Value, __FILE__, __LINE__)

// Casts the any if it currently holds the given type, otherwise yields
// ::std::nullopt.  The result is a ::std::optional wrapping whatever
// ANY_CAST(ValueType, value) returns.
//
// Both the argument and the whole expansion are parenthesized so the macro
// behaves like a function call when used inside a larger expression (for
// example with `*ptr` as the value, or in `ANY_CAST_IF(...) == x`).
// Note: `value` is evaluated twice when the type matches, so it must not have
// side effects.
#define ANY_CAST_IF(ValueType, value)     \
  ((value).type() != typeid(ValueType)    \
       ? ::std::nullopt                   \
       : ::std::make_optional(ANY_CAST(ValueType, (value))))

} // namespace io::substrait::textplan
79 changes: 78 additions & 1 deletion src/substrait/textplan/PlanPrinterVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ std::string visitEnumArgument(const std::string& str) {
return text.str();
}

bool isAggregate(const SymbolInfo* symbol) {
if (const auto typeCase =
ANY_CAST_IF(::substrait::proto::Rel::RelTypeCase, symbol->subtype)) {
return (typeCase == ::substrait::proto::Rel::kAggregate);
}
if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) {
return (typeCase == RelationType::kAggregate);
}
return false;
}

} // namespace

std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) {
Expand Down Expand Up @@ -177,6 +188,71 @@ std::string PlanPrinterVisitor::lookupFieldReference(
return symbol->name;
}

// Resolves an emit output-mapping index to the name printed for it.
//
// fieldReference     - index being resolved.
// currentScope       - relation symbol the lookup starts from.
// stepsOut           - number of parent queries to walk outward before
//                      resolving the index.
// needFullyQualified - when true and the symbol has no alias, the name is
//                      prefixed with its schema name (if it has a schema).
//
// For aggregate relations the index addresses generatedFieldReferences only;
// for every other relation it addresses fieldReferences followed by
// generatedFieldReferences.
//
// Returns the symbol's alias if set, otherwise its (optionally
// schema-qualified) name.  On failure an error is added to errorListener_ and
// the placeholder "field#<fieldReference>" is returned.
std::string PlanPrinterVisitor::lookupFieldReferenceForEmit(
    uint32_t fieldReference,
    const SymbolInfo* currentScope,
    uint32_t stepsOut,
    bool needFullyQualified) {
  const std::string placeholder = "field#" + std::to_string(fieldReference);

  // Validate the scope that was actually passed in (not the currentScope_
  // member), so the null check guards the dereference that follows it.
  if (currentScope == nullptr || *currentScope == SymbolInfo::kUnknown) {
    errorListener_->addError(
        "Field number " + std::to_string(fieldReference) +
        " mysteriously requested outside of a relation.");
    return placeholder;
  }

  // Walk outward through the enclosing queries, one per requested step.
  auto actualScope = currentScope;
  for (auto stepsLeft = stepsOut; stepsLeft > 0; stepsLeft--) {
    auto actualParentQueryLocation = getParentQueryLocation(actualScope);
    if (actualParentQueryLocation == Location::kUnknownLocation) {
      // stepsOut - stepsLeft is the number of levels successfully walked,
      // i.e. the depth that is really available.
      errorListener_->addError(
          "Requested field#" + std::to_string(fieldReference) + " at " +
          std::to_string(stepsOut) +
          " steps out but subquery depth is only " +
          std::to_string(stepsOut - stepsLeft));
      return placeholder;
    }
    actualScope = symbolTable_->lookupSymbolByLocationAndType(
        actualParentQueryLocation, SymbolType::kRelation);
    if (actualScope == nullptr) {
      errorListener_->addError(
          "Internal error: Missing previously encountered parent query symbol.");
      return placeholder;
    }
  }

  auto relationData =
      ANY_CAST(std::shared_ptr<RelationData>, actualScope->blob);
  const SymbolInfo* symbol{nullptr};
  const char* relationType = "non-aggregate";
  // Classify the same scope whose field lists are indexed below; with
  // stepsOut == 0 this is identical to currentScope.
  if (isAggregate(actualScope)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much nicer, thanks. It might be worthwhile to include the type of currentScope (and therefore the range of fieldReferences which would have been valid) below in the out of range error message

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do agree that better error messages are useful but I believe the substrait validator is better suited for helping developers figure what is wrong with their plan generation (allowing this code to be simpler). I've added the type here though.

    relationType = "aggregate";
    if (fieldReference < relationData->generatedFieldReferences.size()) {
      symbol = relationData->generatedFieldReferences[fieldReference];
    }
  } else {
    // Input fields come first, generated fields are appended after them.
    auto size = relationData->fieldReferences.size();
    if (fieldReference < size) {
      symbol = relationData->fieldReferences[fieldReference];
    } else if (
        fieldReference < size + relationData->generatedFieldReferences.size()) {
      symbol = relationData->generatedFieldReferences[fieldReference - size];
    }
  }

  if (symbol == nullptr) {
    errorListener_->addError(
        "Encountered field reference out of range in " +
        std::string(relationType) +
        " relation: " + std::to_string(fieldReference));
    return placeholder;
  }

  if (!symbol->alias.empty()) {
    return symbol->alias;
  }
  if (needFullyQualified && symbol->schema != nullptr) {
    return symbol->schema->name + "." + symbol->name;
  }
  return symbol->name;
}

std::string PlanPrinterVisitor::lookupFunctionReference(
uint32_t function_reference) {
for (const auto& symbol : symbolTable_->getSymbols()) {
Expand Down Expand Up @@ -813,7 +889,7 @@ std::any PlanPrinterVisitor::visitRelationCommon(
}
for (const auto& mapping : common.emit().output_mapping()) {
text << " emit "
<< lookupFieldReference(
<< lookupFieldReferenceForEmit(
mapping, currentScope_, /* stepsOut= */ 0, true)
<< ";\n";
}
Expand Down Expand Up @@ -1028,6 +1104,7 @@ std::any PlanPrinterVisitor::visitAggregateRelation(
}
text << " }\n";
}
text << ANY_CAST(std::string, visitRelationCommon(relation.common()));
return text.str();
}

Expand Down
7 changes: 7 additions & 0 deletions src/substrait/textplan/PlanPrinterVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor {
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

std::string lookupFieldReferenceForEmit(
uint32_t fieldReference,
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

std::string lookupFunctionReference(uint32_t function_reference);

std::any visitSubqueryScalar(
Expand Down
2 changes: 1 addition & 1 deletion src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (auto [shortName, canonicalName] : functionsToOutput) {
for (const auto& [shortName, canonicalName] : functionsToOutput) {
text << " function " << canonicalName << " as " << shortName << ";\n";
}
text << "}\n";
Expand Down
19 changes: 16 additions & 3 deletions src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ void InitialPlanProtoVisitor::updateLocalSchema(
std::nullopt);
relationData->generatedFieldReferences.emplace_back(symbol);
}
// TODO -- If there are multiple groupings add the additional output.
// Aggregate relations are different in that they alter the emitted fields
// by default.
relationData->outputFieldReferences.insert(
Expand Down Expand Up @@ -629,9 +630,21 @@ void InitialPlanProtoVisitor::updateLocalSchema(
// Revamp the output based on the output mapping if present.
auto mapping = getOutputMapping(relation);
if (!mapping.empty()) {
if (!relationData->outputFieldReferences.empty()) {
errorListener_->addError(
"Aggregate relations do not yet support output mapping changes.");
if (relation.rel_type_case() == ::substrait::proto::Rel::kAggregate) {
auto generatedFieldReferenceSize =
relationData->generatedFieldReferences.size();
relationData->outputFieldReferences.clear(); // Start over.
for (auto item : mapping) {
if (item < generatedFieldReferenceSize) {
relationData->outputFieldReferences.push_back(
relationData->generatedFieldReferences[item]);
} else {
// TODO -- Add support for grouping fields (needs text syntax).
errorListener_->addError(
"Asked to emit a field (" + std::to_string(item) +
") beyond what the aggregate produced.");
}
}
return;
}
for (auto item : mapping) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,59 @@ std::vector<TestCase> getTestCases() {
" hashjoin -> root;\n"
"}\n"))),
},
{
"aggregate with emits",
R"(extensions {
extension_function {
function_anchor: 0
name: "sum:i32"
}
}
relations: { root: { input: {
aggregate: { common { emit { output_mapping: 1 } } input {
read: { base_schema { names: 'a' names: 'b'
struct { types { string {} } types { i32 {} } } }
local_files { items { uri_file: 'x.parquet' parquet { } } }
} }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } } }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 0 } } } } } } }
} } } })",
AllOf(WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
read -> aggregate -> root;
}

read relation read {
source local;
base_schema schema;
}

aggregate relation aggregate {
measure {
measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename;
}
measure {
measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2;
}

emit measurename2;
}

schema schema {
a string;
b i32;
}

source local_files local {
items = [
{uri_file: "x.parquet" parquet: {}}
]
}

extension_space {
function sum:i32 as sum;
)"))),
},
};
return cases;
}
Expand Down
29 changes: 25 additions & 4 deletions src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,13 @@ bool isRelationEmitDetail(SubstraitPlanParser::Relation_detailContext* ctx) {
nullptr;
}

// Returns true when symbol->subtype holds RelationType::kAggregate.
//
// Stopgap: per the review discussion on this change, the intermediate form can
// currently carry more than one relation-type representation depending on how
// it was parsed, and long term there should be a single type (making this
// helper unnecessary).  This version only recognizes RelationType; the
// same-named helper in PlanPrinterVisitor.cpp additionally accepts
// ::substrait::proto::Rel::RelTypeCase.
bool isAggregate(const SymbolInfo* symbol) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are there two implementations of this function (PlanPrinterVisitor.cpp:74)? The definitions are even different. If the difference is intentional, that's definitely worth a comment or at least renaming the functions to clarify their distinct behaviors. If the difference is not intentional, could this become SymbolInfo::isAggregate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Long term there shouldn't be this function at all and just a single type in the representation. Right now the intermediate form can have both types (depending on how it was parsed) which needs to be addressed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, a comment would definitely be advisable just to give future readers a clue that this is a stopgap and that in the future there should only be RelationType::kAggregate

if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) {
return (typeCase == RelationType::kAggregate);
}
// subtype holds something other than a RelationType.
return false;
}

} // namespace

std::any SubstraitPlanRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -819,7 +826,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit_ = true;
auto result = visitChildren(ctx);
this->processingEmit_ = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2023,6 +2032,9 @@ std::pair<int, int> SubstraitPlanRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit_) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2075,10 +2087,19 @@ void SubstraitPlanRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very similar to the code at the end of InitialPlanProtoVisitor::updateLocalSchema. Any chance it could be extracted to a helper function?

auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor {
bool hasSubquery(SubstraitPlanParser::ExpressionContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit_{false};
};

} // namespace io::substrait::textplan
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ const std::string kAggregationInvocationPrefix = "aggregationinvocation";
const std::string kJoinTypePrefix = "jointype";
const std::string kSortDirectionPrefix = "sortdirection";

const std::string kIntermediateNodeName = "intermediate";

enum RelationFilterBehavior {
kDefault = 0,
kBestEffort = 1,
Expand Down Expand Up @@ -374,6 +372,13 @@ comparisonToProto(const std::string& text) {
Expression_Subquery_SetComparison_ComparisonOp_COMPARISON_OP_UNSPECIFIED;
}

// Returns true when symbol->subtype holds RelationType::kAggregate.
//
// Stopgap: per the review discussion on this change, the intermediate form can
// currently carry more than one relation-type representation depending on how
// it was parsed, and long term there should be a single type (making this
// helper unnecessary).  This version only recognizes RelationType; the
// same-named helper in PlanPrinterVisitor.cpp additionally accepts
// ::substrait::proto::Rel::RelTypeCase.
bool isAggregate(const SymbolInfo* symbol) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) {
return (typeCase == RelationType::kAggregate);
}
// subtype holds something other than a RelationType.
return false;
}

} // namespace

std::any SubstraitPlanSubqueryRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -871,7 +876,9 @@ std::any SubstraitPlanSubqueryRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit_ = true;
auto result = visitChildren(ctx);
this->processingEmit_ = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2163,6 +2170,9 @@ SubstraitPlanSubqueryRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit_) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2234,10 +2244,19 @@ void SubstraitPlanSubqueryRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more ditto

auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class SubstraitPlanSubqueryRelationVisitor : public SubstraitPlanTypeVisitor {
bool isWithinSubquery(SubstraitPlanParser::RelationContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit_{false};
};

} // namespace io::substrait::textplan
2 changes: 2 additions & 0 deletions src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ std::any SubstraitPlanVisitor::visitFile_detail(
item->set_length(parseUnsignedInteger(ctx->NUMBER()->getText()));
} else if (ctx->ORC() != nullptr) {
item->mutable_orc();
} else if (ctx->PARQUET() != nullptr) {
item->mutable_parquet();
} else {
return visitChildren(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ PARTITION_INDEX: 'PARTITION_INDEX';
START: 'START';
LENGTH: 'LENGTH';
ORC: 'ORC';
PARQUET: 'PARQUET';
NULLVAL: 'NULL';
TRUEVAL: 'TRUE';
FALSEVAL: 'FALSE';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ file_detail
| START COLON NUMBER
| LENGTH COLON NUMBER
| ORC COLON LEFTBRACE RIGHTBRACE
| PARQUET COLON LEFTBRACE RIGHTBRACE
| file_location
;

Expand Down
Loading
Loading