Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add binary plan normalizer #74

Merged
merged 3 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,15 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
std::stringstream text;

std::map<uint32_t, std::string> spaceNames;
std::set<uint32_t> usedSpaces;
auto cmp = [&](uint32_t a, uint32_t b) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a small comment just explaining why we are doing this

if (spaceNames.find(a) == spaceNames.end()) {
return spaceNames.find(b) != spaceNames.end();
} else if (spaceNames.find(b) == spaceNames.end()) {
return false;
}
return spaceNames.at(a) < spaceNames.at(b);
};
std::set<uint32_t, decltype(cmp)> usedSpaces(cmp);

// Look at the existing spaces.
for (const SymbolInfo& info : symbolTable) {
Expand All @@ -278,7 +286,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
usedSpaces.insert(extension->extensionUriReference.value());
}

// Finally output the extensions by space in the order they were encountered.
// Finally output the extensions by space in alphabetical order.
bool hasPreviousOutput = false;
for (const uint32_t space : usedSpaces) {
if (hasPreviousOutput) {
Expand All @@ -291,6 +299,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
text << "extension_space " << spaceNames[space] << " {\n";
}

std::vector<std::pair<std::string, std::string>> functionsToOutput;
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kFunction) {
continue;
Expand All @@ -302,8 +311,11 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
continue;
}

text << " function " << functionData->name << " as " << info.name
<< ";\n";
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (const auto& item : functionsToOutput) {
text << " function " << item.second << " as " << item.first << ";\n";
}
text << "}\n";
hasPreviousOutput = true;
Expand Down
7 changes: 7 additions & 0 deletions src/substrait/textplan/converter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,10 @@ endif()
add_executable(planconverter Tool.cpp)

target_link_libraries(planconverter substrait_textplan_converter)

set(NORMALIZER_SRCS ReferenceNormalizer.cpp ReferenceNormalizer.h)

add_library(substrait_textplan_normalizer ${NORMALIZER_SRCS})

target_link_libraries(substrait_textplan_normalizer
substrait_textplan_converter)
292 changes: 292 additions & 0 deletions src/substrait/textplan/converter/ReferenceNormalizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/textplan/converter/ReferenceNormalizer.h"

#include <string>

#include "substrait/proto/algebra.pb.h"
#include "substrait/proto/plan.pb.h"

namespace io::substrait::textplan {

namespace {

bool compareExtensionFunctions(
const ::substrait::proto::extensions::SimpleExtensionDeclaration& a,
const ::substrait::proto::extensions::SimpleExtensionDeclaration& b) {
// First sort so that extension functions proceed any other kind of extension.
if (a.has_extension_function() && !b.has_extension_function()) {
return true;
} else if (!a.has_extension_function() && b.has_extension_function()) {
// Extension functions always come first.
return false;
} else if (!a.has_extension_function() && !b.has_extension_function()) {
// Both are not extension functions, no difference in ordering.
return false;
}
// Now sort by space.
if (a.extension_function().extension_uri_reference() <
b.extension_function().extension_uri_reference()) {
return true;
} else if (
a.extension_function().extension_uri_reference() >
b.extension_function().extension_uri_reference()) {
return false;
}
// Finally sort by name within a space.
return a.extension_function().name() < b.extension_function().name();
}

void normalizeFunctionsForExpression(
::substrait::proto::Expression* expr,
const std::map<uint32_t, uint32_t>& functionReferenceMapping);

void normalizeFunctionsForArgument(
::substrait::proto::FunctionArgument& argument,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (argument.has_value()) {
normalizeFunctionsForExpression(
argument.mutable_value(), functionReferenceMapping);
}
}

void normalizeFunctionsForMeasure(
::substrait::proto::AggregateRel_Measure& measure,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
measure.mutable_measure()->set_function_reference(
functionReferenceMapping.at(measure.measure().function_reference()));
}

void normalizeFunctionsForExpression(
::substrait::proto::Expression* expr,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (expr->has_scalar_function()) {
expr->mutable_scalar_function()->set_function_reference(
functionReferenceMapping.at(
expr->scalar_function().function_reference()));
for (auto& arg : *expr->mutable_scalar_function()->mutable_arguments()) {
normalizeFunctionsForArgument(arg, functionReferenceMapping);
}
} else if (expr->has_cast()) {
normalizeFunctionsForExpression(
expr->mutable_cast()->mutable_input(), functionReferenceMapping);
} else if (expr->has_if_then()) {
for (auto& ifthen : *expr->mutable_if_then()->mutable_ifs()) {
normalizeFunctionsForExpression(
ifthen.mutable_if_(), functionReferenceMapping);
normalizeFunctionsForExpression(
ifthen.mutable_then(), functionReferenceMapping);
}
if (expr->if_then().has_else_()) {
normalizeFunctionsForExpression(
expr->mutable_if_then()->mutable_else_(), functionReferenceMapping);
}
}
}

void normalizeFunctionsForRelation(
::substrait::proto::Rel* relation,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (relation->has_read()) {
if (relation->read().has_filter()) {
normalizeFunctionsForExpression(
relation->mutable_read()->mutable_filter(), functionReferenceMapping);
}
if (relation->read().has_best_effort_filter()) {
normalizeFunctionsForExpression(
relation->mutable_read()->mutable_best_effort_filter(),
functionReferenceMapping);
}
} else if (relation->has_filter()) {
normalizeFunctionsForRelation(
relation->mutable_filter()->mutable_input(), functionReferenceMapping);
if (relation->filter().has_condition()) {
normalizeFunctionsForExpression(
relation->mutable_filter()->mutable_condition(),
functionReferenceMapping);
}
} else if (relation->has_fetch()) {
normalizeFunctionsForRelation(
relation->mutable_fetch()->mutable_input(), functionReferenceMapping);
} else if (relation->has_aggregate()) {
normalizeFunctionsForRelation(
relation->mutable_aggregate()->mutable_input(),
functionReferenceMapping);
for (auto& measure : *relation->mutable_aggregate()->mutable_measures()) {
normalizeFunctionsForMeasure(measure, functionReferenceMapping);
}
} else if (relation->has_sort()) {
normalizeFunctionsForRelation(
relation->mutable_sort()->mutable_input(), functionReferenceMapping);
for (auto& sort : *relation->mutable_sort()->mutable_sorts()) {
normalizeFunctionsForExpression(
sort.mutable_expr(), functionReferenceMapping);
}
} else if (relation->has_join()) {
if (relation->join().has_expression()) {
normalizeFunctionsForExpression(
relation->mutable_join()->mutable_expression(),
functionReferenceMapping);
}
if (relation->join().has_post_join_filter()) {
normalizeFunctionsForExpression(
relation->mutable_join()->mutable_post_join_filter(),
functionReferenceMapping);
}
normalizeFunctionsForRelation(
relation->mutable_join()->mutable_left(), functionReferenceMapping);
normalizeFunctionsForRelation(
relation->mutable_join()->mutable_right(), functionReferenceMapping);
} else if (relation->has_project()) {
normalizeFunctionsForRelation(
relation->mutable_project()->mutable_input(), functionReferenceMapping);
for (auto& expr : *relation->mutable_project()->mutable_expressions()) {
normalizeFunctionsForExpression(&expr, functionReferenceMapping);
}
} else if (relation->has_set()) {
for (auto& input : *relation->mutable_set()->mutable_inputs()) {
normalizeFunctionsForRelation(&input, functionReferenceMapping);
}
} else if (relation->has_extension_single()) {
if (relation->extension_single().has_input()) {
normalizeFunctionsForRelation(
relation->mutable_extension_single()->mutable_input(),
functionReferenceMapping);
}
} else if (relation->has_extension_multi()) {
for (auto& input : *relation->mutable_extension_multi()->mutable_inputs()) {
normalizeFunctionsForRelation(&input, functionReferenceMapping);
}
} else if (relation->has_extension_leaf()) {
// Nothing to do here.
} else if (relation->has_cross()) {
if (relation->cross().has_left()) {
normalizeFunctionsForRelation(
relation->mutable_cross()->mutable_left(), functionReferenceMapping);
}
if (relation->cross().has_right()) {
normalizeFunctionsForRelation(
relation->mutable_cross()->mutable_right(), functionReferenceMapping);
}
} else if (relation->has_hash_join()) {
if (relation->hash_join().has_left()) {
normalizeFunctionsForRelation(
relation->mutable_hash_join()->mutable_left(),
functionReferenceMapping);
}
if (relation->hash_join().has_right()) {
normalizeFunctionsForRelation(
relation->mutable_hash_join()->mutable_right(),
functionReferenceMapping);
}
if (relation->hash_join().has_post_join_filter()) {
normalizeFunctionsForExpression(
relation->mutable_hash_join()->mutable_post_join_filter(),
functionReferenceMapping);
}
} else if (relation->has_merge_join()) {
if (relation->merge_join().has_left()) {
normalizeFunctionsForRelation(
relation->mutable_merge_join()->mutable_left(),
functionReferenceMapping);
}
if (relation->merge_join().has_right()) {
normalizeFunctionsForRelation(
relation->mutable_merge_join()->mutable_right(),
functionReferenceMapping);
}
if (relation->merge_join().has_post_join_filter()) {
normalizeFunctionsForExpression(
relation->mutable_merge_join()->mutable_post_join_filter(),
functionReferenceMapping);
}
}
}

void normalizeFunctionsForRootRelation(
::substrait::proto::RelRoot* relation,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (relation->has_input()) {
normalizeFunctionsForRelation(
relation->mutable_input(), functionReferenceMapping);
}
}

void normalizeFunctionsForPlanRelation(
::substrait::proto::PlanRel& relation,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (relation.has_root()) {
normalizeFunctionsForRootRelation(
relation.mutable_root(), functionReferenceMapping);
}
if (relation.has_rel()) {
normalizeFunctionsForRelation(
relation.mutable_rel(), functionReferenceMapping);
}
}

} // namespace

void ReferenceNormalizer::normalizeSpaces(::substrait::proto::Plan* plan) {
std::map<uint32_t, uint32_t> extensionSpaceReferenceMapping;

// Reorder the extension spaces and remember what we changed.
std::sort(
plan->mutable_extension_uris()->begin(),
plan->mutable_extension_uris()->end(),
[](const ::substrait::proto::extensions::SimpleExtensionURI& a,
const ::substrait::proto::extensions::SimpleExtensionURI& b) {
return a.uri() < b.uri();
});

// Now renumber the spaces.
uint32_t uriNum = 0;
for (auto& extensionUri : *plan->mutable_extension_uris()) {
extensionSpaceReferenceMapping[extensionUri.extension_uri_anchor()] =
++uriNum;
extensionUri.set_extension_uri_anchor(uriNum);
}

// Apply the space numbering changes to the functions.
for (auto& function : *plan->mutable_extensions()) {
if (function.has_extension_function()) {
auto newSpace = extensionSpaceReferenceMapping.find(
function.extension_function().extension_uri_reference());
if (newSpace != extensionSpaceReferenceMapping.end()) {
function.mutable_extension_function()->set_extension_uri_reference(
newSpace->second);
}
}
}
}

void ReferenceNormalizer::normalizeFunctions(::substrait::proto::Plan* plan) {
std::map<uint32_t, uint32_t> functionReferenceMapping;

// First sort the functions alphabetically by space.
std::sort(
plan->mutable_extensions()->begin(),
plan->mutable_extensions()->end(),
compareExtensionFunctions);

// Now renumber the functions starting with zero.
uint32_t functionNum = 0;
for (auto& function : *plan->mutable_extensions()) {
functionReferenceMapping[function.extension_function().function_anchor()] =
functionNum;
function.mutable_extension_function()->set_function_anchor(functionNum);
functionNum++;
}

// Now apply that reordering to the rest of the protobuf.
for (auto& relation : *plan->mutable_relations()) {
normalizeFunctionsForPlanRelation(relation, functionReferenceMapping);
}
}

void ReferenceNormalizer::normalize(::substrait::proto::Plan* plan) {
normalizeSpaces(plan);
normalizeFunctions(plan);
}

} // namespace io::substrait::textplan
24 changes: 24 additions & 0 deletions src/substrait/textplan/converter/ReferenceNormalizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include "substrait/proto/plan.pb.h"

namespace io::substrait::textplan {

// ReferenceNormalizer renumbers the extension space uri references
// and function references in a consistent manner. This makes it easier
// for differencing tools to compare two similar binary plans. The behavior
// of this tool is undefined on invalid plans.
class ReferenceNormalizer {
public:
ReferenceNormalizer() = default;

static void normalize(::substrait::proto::Plan* plan);

private:
static void normalizeSpaces(::substrait::proto::Plan* plan);
static void normalizeFunctions(::substrait::proto::Plan* plan);
};

} // namespace io::substrait::textplan
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ std::vector<TestCase> getTestCases() {
})",
WhenSerialized(EqSquashingWhitespace(
R"(extension_space {
function lte:fp64_fp64 as lte;
function sum:fp64_fp64 as sum;
function lt:fp64_fp64 as lt;
function is_not_null:fp64 as is_not_null;
function and:bool_bool as and;
function gte:fp64_fp64 as gte;
function is_not_null:fp64 as is_not_null;
function lt:fp64_fp64 as lt;
function lte:fp64_fp64 as lte;
function multiply:opt_fp64_fp64 as multiply;
function sum:fp64_fp64 as sum;
})")),
},
{
Expand Down
Loading