Skip to content

Commit

Permalink
feat: add support for emit in aggregate relations (#122)
Browse files Browse the repository at this point in the history
This adds support for emit in aggregations where there is no more than
one grouping section.

This addresses #121 .
  • Loading branch information
EpsilonPrime authored Oct 1, 2024
1 parent 3f633f2 commit 68036ca
Show file tree
Hide file tree
Showing 14 changed files with 334 additions and 15 deletions.
6 changes: 6 additions & 0 deletions src/substrait/textplan/Any.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,10 @@ any_cast(const std::any& value, const char* file, int line) { // NOLINT
#define ANY_CAST(ValueType, Value) \
::io::substrait::textplan::any_cast<ValueType>(Value, __FILE__, __LINE__)

// Casts the any if it matches the given type otherwise it returns nullopt.
#define ANY_CAST_IF(ValueType, value) \
value.type() != typeid(ValueType) \
? ::std::nullopt \
: ::std::make_optional(ANY_CAST(ValueType, value))

} // namespace io::substrait::textplan
80 changes: 79 additions & 1 deletion src/substrait/textplan/PlanPrinterVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ std::string visitEnumArgument(const std::string& str) {
return text.str();
}

bool isAggregate(const SymbolInfo* symbol) {
// TODO: Remove after the relation type is one type internally.
if (const auto typeCase =
ANY_CAST_IF(::substrait::proto::Rel::RelTypeCase, symbol->subtype)) {
return (typeCase == ::substrait::proto::Rel::kAggregate);
}
if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) {
return (typeCase == RelationType::kAggregate);
}
return false;
}

} // namespace

std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) {
Expand Down Expand Up @@ -177,6 +189,71 @@ std::string PlanPrinterVisitor::lookupFieldReference(
return symbol->name;
}

std::string PlanPrinterVisitor::lookupFieldReferenceForEmit(
uint32_t fieldReference,
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified) {
if (currentScope == nullptr || *currentScope_ == SymbolInfo::kUnknown) {
errorListener_->addError(
"Field number " + std::to_string(fieldReference) +
" mysteriously requested outside of a relation.");
return "field#" + std::to_string(fieldReference);
}
auto actualScope = currentScope;
if (stepsOut > 0) {
for (auto stepsLeft = stepsOut; stepsLeft > 0; stepsLeft--) {
auto actualParentQueryLocation = getParentQueryLocation(actualScope);
if (actualParentQueryLocation == Location::kUnknownLocation) {
errorListener_->addError(
"Requested field#" + std::to_string(fieldReference) + " at " +
std::to_string(stepsOut) +
" steps out but subquery depth is only " +
std::to_string(stepsLeft));
return "field#" + std::to_string(fieldReference);
}
actualScope = symbolTable_->lookupSymbolByLocationAndType(
actualParentQueryLocation, SymbolType::kRelation);
if (actualScope == nullptr) {
errorListener_->addError(
"Internal error: Missing previously encountered parent query symbol.");
return "field#" + std::to_string(fieldReference);
}
}
}
auto relationData =
ANY_CAST(std::shared_ptr<RelationData>, actualScope->blob);
const SymbolInfo* symbol{nullptr};
const char* relationType = "non-aggregate";
if (isAggregate(currentScope)) {
relationType = "aggregate";
if (fieldReference < relationData->generatedFieldReferences.size()) {
symbol = relationData->generatedFieldReferences[fieldReference];
}
} else {
auto size = relationData->fieldReferences.size();
if (fieldReference < size) {
symbol = relationData->fieldReferences[fieldReference];
} else if (
fieldReference < size + relationData->generatedFieldReferences.size()) {
symbol = relationData->generatedFieldReferences[fieldReference - size];
}
}
if (symbol == nullptr) {
errorListener_->addError(
"Encountered field reference out of range in " +
std::string(relationType) +
" relation: " + std::to_string(fieldReference));
return "field#" + std::to_string(fieldReference);
}
if (!symbol->alias.empty()) {
return symbol->alias;
} else if (needFullyQualified && symbol->schema != nullptr) {
return symbol->schema->name + "." + symbol->name;
}
return symbol->name;
}

std::string PlanPrinterVisitor::lookupFunctionReference(
uint32_t function_reference) {
for (const auto& symbol : symbolTable_->getSymbols()) {
Expand Down Expand Up @@ -813,7 +890,7 @@ std::any PlanPrinterVisitor::visitRelationCommon(
}
for (const auto& mapping : common.emit().output_mapping()) {
text << " emit "
<< lookupFieldReference(
<< lookupFieldReferenceForEmit(
mapping, currentScope_, /* stepsOut= */ 0, true)
<< ";\n";
}
Expand Down Expand Up @@ -1028,6 +1105,7 @@ std::any PlanPrinterVisitor::visitAggregateRelation(
}
text << " }\n";
}
text << ANY_CAST(std::string, visitRelationCommon(relation.common()));
return text.str();
}

Expand Down
7 changes: 7 additions & 0 deletions src/substrait/textplan/PlanPrinterVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor {
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

std::string lookupFieldReferenceForEmit(
uint32_t fieldReference,
const SymbolInfo* currentScope,
uint32_t stepsOut,
bool needFullyQualified);

std::string lookupFunctionReference(uint32_t function_reference);

std::any visitSubqueryScalar(
Expand Down
2 changes: 1 addition & 1 deletion src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (auto [shortName, canonicalName] : functionsToOutput) {
for (const auto& [shortName, canonicalName] : functionsToOutput) {
text << " function " << canonicalName << " as " << shortName << ";\n";
}
text << "}\n";
Expand Down
19 changes: 16 additions & 3 deletions src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ void InitialPlanProtoVisitor::updateLocalSchema(
std::nullopt);
relationData->generatedFieldReferences.emplace_back(symbol);
}
// TODO -- If there are multiple groupings add the additional output.
// Aggregate relations are different in that they alter the emitted fields
// by default.
relationData->outputFieldReferences.insert(
Expand Down Expand Up @@ -629,9 +630,21 @@ void InitialPlanProtoVisitor::updateLocalSchema(
// Revamp the output based on the output mapping if present.
auto mapping = getOutputMapping(relation);
if (!mapping.empty()) {
if (!relationData->outputFieldReferences.empty()) {
errorListener_->addError(
"Aggregate relations do not yet support output mapping changes.");
if (relation.rel_type_case() == ::substrait::proto::Rel::kAggregate) {
auto generatedFieldReferenceSize =
relationData->generatedFieldReferences.size();
relationData->outputFieldReferences.clear(); // Start over.
for (auto item : mapping) {
if (item < generatedFieldReferenceSize) {
relationData->outputFieldReferences.push_back(
relationData->generatedFieldReferences[item]);
} else {
// TODO -- Add support for grouping fields (needs text syntax).
errorListener_->addError(
"Asked to emit a field (" + std::to_string(item) +
") beyond what the aggregate produced.");
}
}
return;
}
for (auto item : mapping) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,59 @@ std::vector<TestCase> getTestCases() {
" hashjoin -> root;\n"
"}\n"))),
},
{
"aggregate with emits",
R"(extensions {
extension_function {
function_anchor: 0
name: "sum:i32"
}
}
relations: { root: { input: {
aggregate: { common { emit { output_mapping: 1 } } input {
read: { base_schema { names: 'a' names: 'b'
struct { types { string {} } types { i32 {} } } }
local_files { items { uri_file: 'x.parquet' parquet { } } }
} }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } } }
measures { measure { output_type { i32 {} } arguments { value { selection { direct_reference { struct_field { field: 0 } } } } } } }
} } } })",
AllOf(WhenSerialized(EqSquashingWhitespace(
R"(pipelines {
read -> aggregate -> root;
}
read relation read {
source local;
base_schema schema;
}
aggregate relation aggregate {
measure {
measure sum(schema.b)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename;
}
measure {
measure sum(schema.a)->i32@AGGREGATION_PHASE_UNSPECIFIED NAMED measurename2;
}
emit measurename2;
}
schema schema {
a string;
b i32;
}
source local_files local {
items = [
{uri_file: "x.parquet" parquet: {}}
]
}
extension_space {
function sum:i32 as sum;
)"))),
},
};
return cases;
}
Expand Down
30 changes: 26 additions & 4 deletions src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,14 @@ bool isRelationEmitDetail(SubstraitPlanParser::Relation_detailContext* ctx) {
nullptr;
}

bool isAggregate(const SymbolInfo* symbol) {
// TODO: Remove once relation types have a unified type internally.
if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) {
return (typeCase == RelationType::kAggregate);
}
return false;
}

} // namespace

std::any SubstraitPlanRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -819,7 +827,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit_ = true;
auto result = visitChildren(ctx);
this->processingEmit_ = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2023,6 +2033,9 @@ std::pair<int, int> SubstraitPlanRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit_) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2075,10 +2088,19 @@ void SubstraitPlanRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor {
bool hasSubquery(SubstraitPlanParser::ExpressionContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit_{false};
};

} // namespace io::substrait::textplan
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ const std::string kAggregationInvocationPrefix = "aggregationinvocation";
const std::string kJoinTypePrefix = "jointype";
const std::string kSortDirectionPrefix = "sortdirection";

const std::string kIntermediateNodeName = "intermediate";

enum RelationFilterBehavior {
kDefault = 0,
kBestEffort = 1,
Expand Down Expand Up @@ -374,6 +372,14 @@ comparisonToProto(const std::string& text) {
Expression_Subquery_SetComparison_ComparisonOp_COMPARISON_OP_UNSPECIFIED;
}

bool isAggregate(const SymbolInfo* symbol) {
// TODO: Remove after the relation type is one type internally.
if (const auto typeCase = ANY_CAST_IF(RelationType, symbol->subtype)) {
return (typeCase == RelationType::kAggregate);
}
return false;
}

} // namespace

std::any SubstraitPlanSubqueryRelationVisitor::aggregateResult(
Expand Down Expand Up @@ -871,7 +877,9 @@ std::any SubstraitPlanSubqueryRelationVisitor::visitRelationEmit(
SymbolType::kRelation);
auto parentRelationData =
ANY_CAST(std::shared_ptr<RelationData>, parentSymbol->blob);
this->processingEmit_ = true;
auto result = visitChildren(ctx);
this->processingEmit_ = false;
auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype);
auto common =
findCommonRelation(parentRelationType, &parentRelationData->relation);
Expand Down Expand Up @@ -2163,6 +2171,9 @@ SubstraitPlanSubqueryRelationVisitor::findFieldReferenceByName(
std::shared_ptr<RelationData>& relationData,
const std::string& name) {
auto fieldReferencesSize = relationData->fieldReferences.size();
if (isAggregate(symbol) && this->processingEmit_) {
fieldReferencesSize = 0;
}

auto generatedField = std::find_if(
relationData->generatedFieldReferences.rbegin(),
Expand Down Expand Up @@ -2234,10 +2245,19 @@ void SubstraitPlanSubqueryRelationVisitor::applyOutputMappingToSchema(
if (common->emit().output_mapping_size() == 0) {
common->mutable_direct();
} else {
if (!relationData->outputFieldReferences.empty()) {
// TODO -- Add support for aggregate relations.
errorListener_->addError(
token, "Aggregate relations do not yet support emit sections.");
if (relationData->relation.has_aggregate()) {
auto oldReferences = relationData->outputFieldReferences;
relationData->outputFieldReferences.clear();
for (auto mapping : common->emit().output_mapping()) {
if (mapping < oldReferences.size()) {
relationData->outputFieldReferences.push_back(oldReferences[mapping]);
} else {
errorListener_->addError(
token,
"Field #" + std::to_string(mapping) + " requested but only " +
std::to_string(oldReferences.size()) + " are available.");
}
}
return;
}
for (auto mapping : common->emit().output_mapping()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class SubstraitPlanSubqueryRelationVisitor : public SubstraitPlanTypeVisitor {
bool isWithinSubquery(SubstraitPlanParser::RelationContext* ctx);

const SymbolInfo* currentRelationScope_{nullptr}; // Not owned.
bool processingEmit_{false};
};

} // namespace io::substrait::textplan
2 changes: 2 additions & 0 deletions src/substrait/textplan/parser/SubstraitPlanVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ std::any SubstraitPlanVisitor::visitFile_detail(
item->set_length(parseUnsignedInteger(ctx->NUMBER()->getText()));
} else if (ctx->ORC() != nullptr) {
item->mutable_orc();
} else if (ctx->PARQUET() != nullptr) {
item->mutable_parquet();
} else {
return visitChildren(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ PARTITION_INDEX: 'PARTITION_INDEX';
START: 'START';
LENGTH: 'LENGTH';
ORC: 'ORC';
PARQUET: 'PARQUET';
NULLVAL: 'NULL';
TRUEVAL: 'TRUE';
FALSEVAL: 'FALSE';
Expand Down
Loading

0 comments on commit 68036ca

Please sign in to comment.