Skip to content

Commit

Permalink
feat: add binary plan normalizer (#74)
Browse files Browse the repository at this point in the history
The normalizer orders the extension uri spaces and function extensions
alphabetically (and renumbers accordingly) so that it is easier to
compare binary plans. A slight change to the existing sort behavior is
made to be consistent with this ordering.
  • Loading branch information
EpsilonPrime committed Jun 21, 2023
1 parent 390d2dc commit 4601d5c
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 13 deletions.
20 changes: 16 additions & 4 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,15 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
std::stringstream text;

std::map<uint32_t, std::string> spaceNames;
std::set<uint32_t> usedSpaces;
auto cmp = [&](uint32_t a, uint32_t b) {
if (spaceNames.find(a) == spaceNames.end()) {
return spaceNames.find(b) != spaceNames.end();
} else if (spaceNames.find(b) == spaceNames.end()) {
return false;
}
return spaceNames.at(a) < spaceNames.at(b);
};
std::set<uint32_t, decltype(cmp)> usedSpaces(cmp);

// Look at the existing spaces.
for (const SymbolInfo& info : symbolTable) {
Expand All @@ -278,7 +286,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
usedSpaces.insert(extension->extensionUriReference.value());
}

// Finally output the extensions by space in the order they were encountered.
// Finally output the extensions by space in alphabetical order.
bool hasPreviousOutput = false;
for (const uint32_t space : usedSpaces) {
if (hasPreviousOutput) {
Expand All @@ -291,6 +299,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
text << "extension_space " << spaceNames[space] << " {\n";
}

std::vector<std::pair<std::string, std::string>> functionsToOutput;
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kFunction) {
continue;
Expand All @@ -302,8 +311,11 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
continue;
}

text << " function " << functionData->name << " as " << info.name
<< ";\n";
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (const auto& item : functionsToOutput) {
text << " function " << item.second << " as " << item.first << ";\n";
}
text << "}\n";
hasPreviousOutput = true;
Expand Down
7 changes: 7 additions & 0 deletions src/substrait/textplan/converter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,10 @@ endif()
add_executable(planconverter Tool.cpp)

target_link_libraries(planconverter substrait_textplan_converter)

set(NORMALIZER_SRCS ReferenceNormalizer.cpp ReferenceNormalizer.h)

add_library(substrait_textplan_normalizer ${NORMALIZER_SRCS})

target_link_libraries(substrait_textplan_normalizer
substrait_textplan_converter)
292 changes: 292 additions & 0 deletions src/substrait/textplan/converter/ReferenceNormalizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/textplan/converter/ReferenceNormalizer.h"

#include <string>

#include "substrait/proto/algebra.pb.h"
#include "substrait/proto/plan.pb.h"

namespace io::substrait::textplan {

namespace {

bool compareExtensionFunctions(
const ::substrait::proto::extensions::SimpleExtensionDeclaration& a,
const ::substrait::proto::extensions::SimpleExtensionDeclaration& b) {
// First sort so that extension functions proceed any other kind of extension.
if (a.has_extension_function() && !b.has_extension_function()) {
return true;
} else if (!a.has_extension_function() && b.has_extension_function()) {
// Extension functions always come first.
return false;
} else if (!a.has_extension_function() && !b.has_extension_function()) {
// Both are not extension functions, no difference in ordering.
return false;
}
// Now sort by space.
if (a.extension_function().extension_uri_reference() <
b.extension_function().extension_uri_reference()) {
return true;
} else if (
a.extension_function().extension_uri_reference() >
b.extension_function().extension_uri_reference()) {
return false;
}
// Finally sort by name within a space.
return a.extension_function().name() < b.extension_function().name();
}

void normalizeFunctionsForExpression(
::substrait::proto::Expression* expr,
const std::map<uint32_t, uint32_t>& functionReferenceMapping);

void normalizeFunctionsForArgument(
::substrait::proto::FunctionArgument& argument,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (argument.has_value()) {
normalizeFunctionsForExpression(
argument.mutable_value(), functionReferenceMapping);
}
}

void normalizeFunctionsForMeasure(
::substrait::proto::AggregateRel_Measure& measure,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
measure.mutable_measure()->set_function_reference(
functionReferenceMapping.at(measure.measure().function_reference()));
}

void normalizeFunctionsForExpression(
::substrait::proto::Expression* expr,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (expr->has_scalar_function()) {
expr->mutable_scalar_function()->set_function_reference(
functionReferenceMapping.at(
expr->scalar_function().function_reference()));
for (auto& arg : *expr->mutable_scalar_function()->mutable_arguments()) {
normalizeFunctionsForArgument(arg, functionReferenceMapping);
}
} else if (expr->has_cast()) {
normalizeFunctionsForExpression(
expr->mutable_cast()->mutable_input(), functionReferenceMapping);
} else if (expr->has_if_then()) {
for (auto& ifthen : *expr->mutable_if_then()->mutable_ifs()) {
normalizeFunctionsForExpression(
ifthen.mutable_if_(), functionReferenceMapping);
normalizeFunctionsForExpression(
ifthen.mutable_then(), functionReferenceMapping);
}
if (expr->if_then().has_else_()) {
normalizeFunctionsForExpression(
expr->mutable_if_then()->mutable_else_(), functionReferenceMapping);
}
}
}

void normalizeFunctionsForRelation(
::substrait::proto::Rel* relation,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (relation->has_read()) {
if (relation->read().has_filter()) {
normalizeFunctionsForExpression(
relation->mutable_read()->mutable_filter(), functionReferenceMapping);
}
if (relation->read().has_best_effort_filter()) {
normalizeFunctionsForExpression(
relation->mutable_read()->mutable_best_effort_filter(),
functionReferenceMapping);
}
} else if (relation->has_filter()) {
normalizeFunctionsForRelation(
relation->mutable_filter()->mutable_input(), functionReferenceMapping);
if (relation->filter().has_condition()) {
normalizeFunctionsForExpression(
relation->mutable_filter()->mutable_condition(),
functionReferenceMapping);
}
} else if (relation->has_fetch()) {
normalizeFunctionsForRelation(
relation->mutable_fetch()->mutable_input(), functionReferenceMapping);
} else if (relation->has_aggregate()) {
normalizeFunctionsForRelation(
relation->mutable_aggregate()->mutable_input(),
functionReferenceMapping);
for (auto& measure : *relation->mutable_aggregate()->mutable_measures()) {
normalizeFunctionsForMeasure(measure, functionReferenceMapping);
}
} else if (relation->has_sort()) {
normalizeFunctionsForRelation(
relation->mutable_sort()->mutable_input(), functionReferenceMapping);
for (auto& sort : *relation->mutable_sort()->mutable_sorts()) {
normalizeFunctionsForExpression(
sort.mutable_expr(), functionReferenceMapping);
}
} else if (relation->has_join()) {
if (relation->join().has_expression()) {
normalizeFunctionsForExpression(
relation->mutable_join()->mutable_expression(),
functionReferenceMapping);
}
if (relation->join().has_post_join_filter()) {
normalizeFunctionsForExpression(
relation->mutable_join()->mutable_post_join_filter(),
functionReferenceMapping);
}
normalizeFunctionsForRelation(
relation->mutable_join()->mutable_left(), functionReferenceMapping);
normalizeFunctionsForRelation(
relation->mutable_join()->mutable_right(), functionReferenceMapping);
} else if (relation->has_project()) {
normalizeFunctionsForRelation(
relation->mutable_project()->mutable_input(), functionReferenceMapping);
for (auto& expr : *relation->mutable_project()->mutable_expressions()) {
normalizeFunctionsForExpression(&expr, functionReferenceMapping);
}
} else if (relation->has_set()) {
for (auto& input : *relation->mutable_set()->mutable_inputs()) {
normalizeFunctionsForRelation(&input, functionReferenceMapping);
}
} else if (relation->has_extension_single()) {
if (relation->extension_single().has_input()) {
normalizeFunctionsForRelation(
relation->mutable_extension_single()->mutable_input(),
functionReferenceMapping);
}
} else if (relation->has_extension_multi()) {
for (auto& input : *relation->mutable_extension_multi()->mutable_inputs()) {
normalizeFunctionsForRelation(&input, functionReferenceMapping);
}
} else if (relation->has_extension_leaf()) {
// Nothing to do here.
} else if (relation->has_cross()) {
if (relation->cross().has_left()) {
normalizeFunctionsForRelation(
relation->mutable_cross()->mutable_left(), functionReferenceMapping);
}
if (relation->cross().has_right()) {
normalizeFunctionsForRelation(
relation->mutable_cross()->mutable_right(), functionReferenceMapping);
}
} else if (relation->has_hash_join()) {
if (relation->hash_join().has_left()) {
normalizeFunctionsForRelation(
relation->mutable_hash_join()->mutable_left(),
functionReferenceMapping);
}
if (relation->hash_join().has_right()) {
normalizeFunctionsForRelation(
relation->mutable_hash_join()->mutable_right(),
functionReferenceMapping);
}
if (relation->hash_join().has_post_join_filter()) {
normalizeFunctionsForExpression(
relation->mutable_hash_join()->mutable_post_join_filter(),
functionReferenceMapping);
}
} else if (relation->has_merge_join()) {
if (relation->merge_join().has_left()) {
normalizeFunctionsForRelation(
relation->mutable_merge_join()->mutable_left(),
functionReferenceMapping);
}
if (relation->merge_join().has_right()) {
normalizeFunctionsForRelation(
relation->mutable_merge_join()->mutable_right(),
functionReferenceMapping);
}
if (relation->merge_join().has_post_join_filter()) {
normalizeFunctionsForExpression(
relation->mutable_merge_join()->mutable_post_join_filter(),
functionReferenceMapping);
}
}
}

void normalizeFunctionsForRootRelation(
::substrait::proto::RelRoot* relation,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (relation->has_input()) {
normalizeFunctionsForRelation(
relation->mutable_input(), functionReferenceMapping);
}
}

void normalizeFunctionsForPlanRelation(
::substrait::proto::PlanRel& relation,
const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
if (relation.has_root()) {
normalizeFunctionsForRootRelation(
relation.mutable_root(), functionReferenceMapping);
}
if (relation.has_rel()) {
normalizeFunctionsForRelation(
relation.mutable_rel(), functionReferenceMapping);
}
}

} // namespace

void ReferenceNormalizer::normalizeSpaces(::substrait::proto::Plan* plan) {
std::map<uint32_t, uint32_t> extensionSpaceReferenceMapping;

// Reorder the extension spaces and remember what we changed.
std::sort(
plan->mutable_extension_uris()->begin(),
plan->mutable_extension_uris()->end(),
[](const ::substrait::proto::extensions::SimpleExtensionURI& a,
const ::substrait::proto::extensions::SimpleExtensionURI& b) {
return a.uri() < b.uri();
});

// Now renumber the spaces.
uint32_t uriNum = 0;
for (auto& extensionUri : *plan->mutable_extension_uris()) {
extensionSpaceReferenceMapping[extensionUri.extension_uri_anchor()] =
++uriNum;
extensionUri.set_extension_uri_anchor(uriNum);
}

// Apply the space numbering changes to the functions.
for (auto& function : *plan->mutable_extensions()) {
if (function.has_extension_function()) {
auto newSpace = extensionSpaceReferenceMapping.find(
function.extension_function().extension_uri_reference());
if (newSpace != extensionSpaceReferenceMapping.end()) {
function.mutable_extension_function()->set_extension_uri_reference(
newSpace->second);
}
}
}
}

void ReferenceNormalizer::normalizeFunctions(::substrait::proto::Plan* plan) {
std::map<uint32_t, uint32_t> functionReferenceMapping;

// First sort the functions alphabetically by space.
std::sort(
plan->mutable_extensions()->begin(),
plan->mutable_extensions()->end(),
compareExtensionFunctions);

// Now renumber the functions starting with zero.
uint32_t functionNum = 0;
for (auto& function : *plan->mutable_extensions()) {
functionReferenceMapping[function.extension_function().function_anchor()] =
functionNum;
function.mutable_extension_function()->set_function_anchor(functionNum);
functionNum++;
}

// Now apply that reordering to the rest of the protobuf.
for (auto& relation : *plan->mutable_relations()) {
normalizeFunctionsForPlanRelation(relation, functionReferenceMapping);
}
}

void ReferenceNormalizer::normalize(::substrait::proto::Plan* plan) {
normalizeSpaces(plan);
normalizeFunctions(plan);
}

} // namespace io::substrait::textplan
24 changes: 24 additions & 0 deletions src/substrait/textplan/converter/ReferenceNormalizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include "substrait/proto/plan.pb.h"

namespace io::substrait::textplan {

// ReferenceNormalizer renumbers the extension space uri references
// and function references in a consistent manner. This makes it easier
// for differencing tools to compare two similar binary plans. The behavior
// of this tool is undefined on invalid plans.
class ReferenceNormalizer {
public:
ReferenceNormalizer() = default;

static void normalize(::substrait::proto::Plan* plan);

private:
static void normalizeSpaces(::substrait::proto::Plan* plan);
static void normalizeFunctions(::substrait::proto::Plan* plan);
};

} // namespace io::substrait::textplan
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ std::vector<TestCase> getTestCases() {
})",
WhenSerialized(EqSquashingWhitespace(
R"(extension_space {
function lte:fp64_fp64 as lte;
function sum:fp64_fp64 as sum;
function lt:fp64_fp64 as lt;
function is_not_null:fp64 as is_not_null;
function and:bool_bool as and;
function gte:fp64_fp64 as gte;
function is_not_null:fp64 as is_not_null;
function lt:fp64_fp64 as lt;
function lte:fp64_fp64 as lte;
function multiply:opt_fp64_fp64 as multiply;
function sum:fp64_fp64 as sum;
})")),
},
{
Expand Down
Loading

0 comments on commit 4601d5c

Please sign in to comment.