Skip to content

Commit

Permalink
feat: add binary plan normalizer
Browse files Browse the repository at this point in the history
The normalizer sorts the extension URI spaces and function extensions alphabetically (renumbering them accordingly) so that binary plans are easier to compare. The existing sort behavior is changed slightly to be consistent with this ordering.
  • Loading branch information
EpsilonPrime committed Jun 21, 2023
1 parent fee0829 commit 3b84ff5
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 7 deletions.
20 changes: 16 additions & 4 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,15 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
std::stringstream text;

std::map<uint32_t, std::string> spaceNames;
std::set<uint32_t> usedSpaces;
auto cmp = [&](uint32_t a, uint32_t b) {
if (spaceNames.find(a) == spaceNames.end()) {
return spaceNames.find(b) != spaceNames.end();
} else if (spaceNames.find(b) == spaceNames.end()) {
return false;
}
return spaceNames.at(a) < spaceNames.at(b);
};
std::set<uint32_t, decltype(cmp)> usedSpaces(cmp);

// Look at the existing spaces.
for (const SymbolInfo& info : symbolTable) {
Expand All @@ -278,7 +286,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
usedSpaces.insert(extension->extensionUriReference.value());
}

// Finally output the extensions by space in the order they were encountered.
// Finally output the extensions by space in alphabetical order.
bool hasPreviousOutput = false;
for (const uint32_t space : usedSpaces) {
if (hasPreviousOutput) {
Expand All @@ -291,6 +299,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
text << "extension_space " << spaceNames[space] << " {\n";
}

std::vector<std::pair<std::string, std::string>> functionsToOutput;
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kFunction) {
continue;
Expand All @@ -302,8 +311,11 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
continue;
}

text << " function " << functionData->name << " as " << info.name
<< ";\n";
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (const auto& item : functionsToOutput) {
text << " function " << item.second << " as " << item.first << ";\n";
}
text << "}\n";
hasPreviousOutput = true;
Expand Down
10 changes: 10 additions & 0 deletions src/substrait/textplan/converter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ endif()
add_executable(planconverter Tool.cpp)

target_link_libraries(planconverter substrait_textplan_converter)

set(NORMALIZER_SRCS
ReferenceNormalizer.cpp
ReferenceNormalizer.h)

add_library(substrait_textplan_normalizer ${NORMALIZER_SRCS})

target_link_libraries(
substrait_textplan_normalizer
substrait_textplan_converter)
292 changes: 292 additions & 0 deletions src/substrait/textplan/converter/ReferenceNormalizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "substrait/textplan/converter/ReferenceNormalizer.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <string>

#include "substrait/proto/algebra.pb.h"
#include "substrait/proto/plan.pb.h"

namespace io::substrait::textplan {

namespace {

// Ordering predicate for extension declarations.
//
// Declarations that hold an extension function sort ahead of every other kind
// of declaration; two declarations that are both not functions compare as
// equivalent. Among functions, the lower extension uri reference wins, with
// the function name as the tie breaker.
//
// Returns true when `a` must be placed before `b`.
bool compareExtensionFunctions(
    const ::substrait::proto::extensions::SimpleExtensionDeclaration& a,
    const ::substrait::proto::extensions::SimpleExtensionDeclaration& b) {
  const bool aIsFunction = a.has_extension_function();
  const bool bIsFunction = b.has_extension_function();

  // Mixed kinds: whichever side is the function precedes the other.
  if (aIsFunction != bIsFunction) {
    return aIsFunction;
  }
  // Neither side is a function: no ordering between them.
  if (!aIsFunction) {
    return false;
  }

  const auto& lhs = a.extension_function();
  const auto& rhs = b.extension_function();

  // Group by space first.
  if (lhs.extension_uri_reference() != rhs.extension_uri_reference()) {
    return lhs.extension_uri_reference() < rhs.extension_uri_reference();
  }
  // Same space: order by name.
  return lhs.name() < rhs.name();
}

// Forward declaration: the argument and measure helpers below call into the
// expression normalizer, which in turn calls back into the argument helper.
// Rewrites the function references found in `expr` (and the sub-expressions
// it visits) using `functionReferenceMapping` (old anchor -> new anchor).
void normalizeFunctionsForExpression(
::substrait::proto::Expression* expr,
const std::map<uint32_t, uint32_t>& functionReferenceMapping);

// Renumbers the function references inside a single function argument.
// Only arguments that carry a value expression are visited; every other kind
// of argument is left untouched.
void normalizeFunctionsForArgument(
    ::substrait::proto::FunctionArgument& argument,
    const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
  if (!argument.has_value()) {
    return;
  }
  auto* valueExpression = argument.mutable_value();
  normalizeFunctionsForExpression(valueExpression, functionReferenceMapping);
}

// Renumbers every function reference reachable from an aggregate measure.
//
// The aggregate function's own reference is rewritten through
// `functionReferenceMapping` (old anchor -> new anchor); std::map::at throws
// std::out_of_range when the reference has no entry in the mapping. The
// expressions nested in the aggregate's arguments and sort fields, and the
// measure's optional filter, are visited as well so that function calls
// inside them do not keep pointing at the old anchors.
void normalizeFunctionsForMeasure(
    ::substrait::proto::AggregateRel_Measure& measure,
    const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
  auto* aggregate = measure.mutable_measure();
  aggregate->set_function_reference(
      functionReferenceMapping.at(aggregate->function_reference()));

  // Arguments may themselves contain function calls, e.g. sum(multiply(a, b)).
  for (auto& argument : *aggregate->mutable_arguments()) {
    normalizeFunctionsForArgument(argument, functionReferenceMapping);
  }

  // Ordered aggregates carry sort expressions of their own.
  for (auto& sort : *aggregate->mutable_sorts()) {
    if (sort.has_expr()) {
      normalizeFunctionsForExpression(
          sort.mutable_expr(), functionReferenceMapping);
    }
  }

  // The per-measure filter is an arbitrary expression.
  if (measure.has_filter()) {
    normalizeFunctionsForExpression(
        measure.mutable_filter(), functionReferenceMapping);
  }
}

// Renumbers the function references found in an expression tree.
//
// Three expression kinds are handled: scalar functions (the reference itself
// plus each argument), casts (the input expression), and if/then expressions
// (every if/then clause plus the else branch when present). Any other kind of
// expression is left unchanged. A scalar function whose reference has no
// entry in `functionReferenceMapping` makes std::map::at throw
// std::out_of_range.
void normalizeFunctionsForExpression(
    ::substrait::proto::Expression* expr,
    const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
  if (expr->has_scalar_function()) {
    auto* call = expr->mutable_scalar_function();
    const uint32_t renumbered =
        functionReferenceMapping.at(call->function_reference());
    call->set_function_reference(renumbered);
    for (auto& argument : *call->mutable_arguments()) {
      normalizeFunctionsForArgument(argument, functionReferenceMapping);
    }
    return;
  }

  if (expr->has_cast()) {
    normalizeFunctionsForExpression(
        expr->mutable_cast()->mutable_input(), functionReferenceMapping);
    return;
  }

  if (!expr->has_if_then()) {
    return;
  }

  auto* conditional = expr->mutable_if_then();
  for (auto& clause : *conditional->mutable_ifs()) {
    normalizeFunctionsForExpression(
        clause.mutable_if_(), functionReferenceMapping);
    normalizeFunctionsForExpression(
        clause.mutable_then(), functionReferenceMapping);
  }
  if (conditional->has_else_()) {
    normalizeFunctionsForExpression(
        conditional->mutable_else_(), functionReferenceMapping);
  }
}

// Recursively renumbers the function references inside a relation tree.
//
// Dispatches on the kind of relation held by `relation`, descends into its
// input relation(s) and rewrites the expressions attached to it (filters,
// join conditions, projections, sort keys, aggregate grouping keys and
// measures) using `functionReferenceMapping` (old anchor -> new anchor).
// Relation kinds not listed below are left unchanged.
void normalizeFunctionsForRelation(
    ::substrait::proto::Rel* relation,
    const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
  if (relation->has_read()) {
    // Reads are leaves; only their filters can reference functions.
    if (relation->read().has_filter()) {
      normalizeFunctionsForExpression(
          relation->mutable_read()->mutable_filter(), functionReferenceMapping);
    }
    if (relation->read().has_best_effort_filter()) {
      normalizeFunctionsForExpression(
          relation->mutable_read()->mutable_best_effort_filter(),
          functionReferenceMapping);
    }
  } else if (relation->has_filter()) {
    normalizeFunctionsForRelation(
        relation->mutable_filter()->mutable_input(), functionReferenceMapping);
    if (relation->filter().has_condition()) {
      normalizeFunctionsForExpression(
          relation->mutable_filter()->mutable_condition(),
          functionReferenceMapping);
    }
  } else if (relation->has_fetch()) {
    normalizeFunctionsForRelation(
        relation->mutable_fetch()->mutable_input(), functionReferenceMapping);
  } else if (relation->has_aggregate()) {
    normalizeFunctionsForRelation(
        relation->mutable_aggregate()->mutable_input(),
        functionReferenceMapping);
    // Grouping keys are arbitrary expressions and may call functions too, so
    // they need the same renumbering as the measures.
    for (auto& grouping : *relation->mutable_aggregate()->mutable_groupings()) {
      for (auto& expr : *grouping.mutable_grouping_expressions()) {
        normalizeFunctionsForExpression(&expr, functionReferenceMapping);
      }
    }
    for (auto& measure : *relation->mutable_aggregate()->mutable_measures()) {
      normalizeFunctionsForMeasure(measure, functionReferenceMapping);
    }
  } else if (relation->has_sort()) {
    normalizeFunctionsForRelation(
        relation->mutable_sort()->mutable_input(), functionReferenceMapping);
    for (auto& sort : *relation->mutable_sort()->mutable_sorts()) {
      normalizeFunctionsForExpression(
          sort.mutable_expr(), functionReferenceMapping);
    }
  } else if (relation->has_join()) {
    if (relation->join().has_expression()) {
      normalizeFunctionsForExpression(
          relation->mutable_join()->mutable_expression(),
          functionReferenceMapping);
    }
    if (relation->join().has_post_join_filter()) {
      normalizeFunctionsForExpression(
          relation->mutable_join()->mutable_post_join_filter(),
          functionReferenceMapping);
    }
    normalizeFunctionsForRelation(
        relation->mutable_join()->mutable_left(), functionReferenceMapping);
    normalizeFunctionsForRelation(
        relation->mutable_join()->mutable_right(), functionReferenceMapping);
  } else if (relation->has_project()) {
    normalizeFunctionsForRelation(
        relation->mutable_project()->mutable_input(), functionReferenceMapping);
    for (auto& expr : *relation->mutable_project()->mutable_expressions()) {
      normalizeFunctionsForExpression(&expr, functionReferenceMapping);
    }
  } else if (relation->has_set()) {
    for (auto& input : *relation->mutable_set()->mutable_inputs()) {
      normalizeFunctionsForRelation(&input, functionReferenceMapping);
    }
  } else if (relation->has_extension_single()) {
    if (relation->extension_single().has_input()) {
      normalizeFunctionsForRelation(
          relation->mutable_extension_single()->mutable_input(),
          functionReferenceMapping);
    }
  } else if (relation->has_extension_multi()) {
    for (auto& input : *relation->mutable_extension_multi()->mutable_inputs()) {
      normalizeFunctionsForRelation(&input, functionReferenceMapping);
    }
  } else if (relation->has_extension_leaf()) {
    // Nothing to do here.
  } else if (relation->has_cross()) {
    if (relation->cross().has_left()) {
      normalizeFunctionsForRelation(
          relation->mutable_cross()->mutable_left(), functionReferenceMapping);
    }
    if (relation->cross().has_right()) {
      normalizeFunctionsForRelation(
          relation->mutable_cross()->mutable_right(), functionReferenceMapping);
    }
  } else if (relation->has_hash_join()) {
    if (relation->hash_join().has_left()) {
      normalizeFunctionsForRelation(
          relation->mutable_hash_join()->mutable_left(),
          functionReferenceMapping);
    }
    if (relation->hash_join().has_right()) {
      normalizeFunctionsForRelation(
          relation->mutable_hash_join()->mutable_right(),
          functionReferenceMapping);
    }
    if (relation->hash_join().has_post_join_filter()) {
      normalizeFunctionsForExpression(
          relation->mutable_hash_join()->mutable_post_join_filter(),
          functionReferenceMapping);
    }
  } else if (relation->has_merge_join()) {
    if (relation->merge_join().has_left()) {
      normalizeFunctionsForRelation(
          relation->mutable_merge_join()->mutable_left(),
          functionReferenceMapping);
    }
    if (relation->merge_join().has_right()) {
      normalizeFunctionsForRelation(
          relation->mutable_merge_join()->mutable_right(),
          functionReferenceMapping);
    }
    if (relation->merge_join().has_post_join_filter()) {
      normalizeFunctionsForExpression(
          relation->mutable_merge_join()->mutable_post_join_filter(),
          functionReferenceMapping);
    }
  }
}

// Renumbers the function references beneath a root relation. A root without
// an input relation is left untouched.
void normalizeFunctionsForRootRelation(
    ::substrait::proto::RelRoot* relation,
    const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
  if (!relation->has_input()) {
    return;
  }
  auto* rootInput = relation->mutable_input();
  normalizeFunctionsForRelation(rootInput, functionReferenceMapping);
}

// Renumbers the function references in one top-level plan relation, handling
// both the root form and the plain relation form.
void normalizeFunctionsForPlanRelation(
    ::substrait::proto::PlanRel& relation,
    const std::map<uint32_t, uint32_t>& functionReferenceMapping) {
  // Root form: delegate to the root helper, which descends into its input.
  if (relation.has_root()) {
    auto* root = relation.mutable_root();
    normalizeFunctionsForRootRelation(root, functionReferenceMapping);
  }
  // Plain form: walk the relation tree directly.
  if (relation.has_rel()) {
    auto* rel = relation.mutable_rel();
    normalizeFunctionsForRelation(rel, functionReferenceMapping);
  }
}

} // namespace

// Sorts the plan's extension uris alphabetically, renumbers their anchors
// (the first uri receives anchor 1) and rewrites every extension declaration
// that refers to a renumbered uri.
//
// Functions, types and type variations all carry an extension uri reference,
// so all three kinds of declaration are remapped; leaving types or type
// variations untouched would make them point at the wrong uri after the
// renumbering. References with no matching uri anchor are left unchanged.
void ReferenceNormalizer::normalizeSpaces(::substrait::proto::Plan* plan) {
  std::map<uint32_t, uint32_t> extensionSpaceReferenceMapping;

  // Reorder the extension spaces and remember what we changed.
  std::sort(
      plan->mutable_extension_uris()->begin(),
      plan->mutable_extension_uris()->end(),
      [](const ::substrait::proto::extensions::SimpleExtensionURI& a,
         const ::substrait::proto::extensions::SimpleExtensionURI& b) {
        return a.uri() < b.uri();
      });

  // Now renumber the spaces.
  uint32_t uriNum = 0;
  for (auto& extensionUri : *plan->mutable_extension_uris()) {
    ++uriNum;
    extensionSpaceReferenceMapping[extensionUri.extension_uri_anchor()] =
        uriNum;
    extensionUri.set_extension_uri_anchor(uriNum);
  }

  // Rewrites the uri reference of one declaration body when its old anchor is
  // known; unknown references are deliberately left as they are.
  const auto remapSpace = [&extensionSpaceReferenceMapping](auto* declaration) {
    const auto newSpace = extensionSpaceReferenceMapping.find(
        declaration->extension_uri_reference());
    if (newSpace != extensionSpaceReferenceMapping.end()) {
      declaration->set_extension_uri_reference(newSpace->second);
    }
  };

  // Apply the space numbering changes to every kind of declaration.
  for (auto& extension : *plan->mutable_extensions()) {
    if (extension.has_extension_function()) {
      remapSpace(extension.mutable_extension_function());
    } else if (extension.has_extension_type()) {
      remapSpace(extension.mutable_extension_type());
    } else if (extension.has_extension_type_variation()) {
      remapSpace(extension.mutable_extension_type_variation());
    }
  }
}

// Sorts the extension declarations (functions first, by space and then by
// name), renumbers the function anchors starting at zero and rewrites the
// function references in the plan's relations to match.
//
// Declarations that are not functions are never renumbered: calling
// mutable_extension_function() on them would replace the type or type
// variation they hold with an empty function declaration, and would also
// register a bogus mapping entry for anchor 0.
void ReferenceNormalizer::normalizeFunctions(::substrait::proto::Plan* plan) {
  std::map<uint32_t, uint32_t> functionReferenceMapping;

  // First sort the functions alphabetically by space. A stable sort keeps the
  // declarations that compare as equivalent (everything that is not a
  // function) in their original relative order.
  std::stable_sort(
      plan->mutable_extensions()->begin(),
      plan->mutable_extensions()->end(),
      compareExtensionFunctions);

  // Now renumber the functions starting with zero.
  uint32_t functionNum = 0;
  for (auto& extension : *plan->mutable_extensions()) {
    if (!extension.has_extension_function()) {
      continue;
    }
    auto* function = extension.mutable_extension_function();
    functionReferenceMapping[function->function_anchor()] = functionNum;
    function->set_function_anchor(functionNum);
    ++functionNum;
  }

  // Now apply that reordering to the rest of the protobuf.
  for (auto& relation : *plan->mutable_relations()) {
    normalizeFunctionsForPlanRelation(relation, functionReferenceMapping);
  }
}

// Entry point: normalizes `plan` in place. The extension spaces are handled
// before the functions because the function ordering compares the extension
// uri references that the space pass assigns.
void ReferenceNormalizer::normalize(::substrait::proto::Plan* plan) {
  ReferenceNormalizer::normalizeSpaces(plan);
  ReferenceNormalizer::normalizeFunctions(plan);
}

} // namespace io::substrait::textplan
24 changes: 24 additions & 0 deletions src/substrait/textplan/converter/ReferenceNormalizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include "substrait/proto/plan.pb.h"

namespace io::substrait::textplan {

// ReferenceNormalizer renumbers the extension space URI references and the
// function references of a plan in a consistent manner. This makes it easier
// for differencing tools to compare two similar binary plans. The behavior
// of this tool is undefined on invalid plans.
//
// All operations are static and modify the provided plan in place.
class ReferenceNormalizer {
public:
ReferenceNormalizer() = default;

// Normalizes the extension spaces first and then the functions of `plan`.
static void normalize(::substrait::proto::Plan* plan);

private:
// Sorts the extension URIs by their uri string, renumbers their anchors
// starting at 1, and updates the URI references of the extension functions.
static void normalizeSpaces(::substrait::proto::Plan* plan);
// Sorts the extensions (functions first, then by space and name), renumbers
// the function anchors starting at 0, and updates the function references
// used by the plan's relations.
static void normalizeFunctions(::substrait::proto::Plan* plan);
};

} // namespace io::substrait::textplan
2 changes: 1 addition & 1 deletion src/substrait/textplan/parser/tests/TextPlanParserTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ std::vector<TestCase> getTestCases() {
extension_space blah.yaml {
function add:i8 as add;
function subtract:i8 as subtract;
function concat:str as concat;
function subtract:i8 as subtract;
})")),
AsBinaryPlan(EqualsProto<::substrait::proto::Plan>(
R"(extension_uris {
Expand Down
1 change: 1 addition & 0 deletions src/substrait/textplan/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ if(${SUBSTRAIT_CPP_ROUNDTRIP_TESTING})
EXTRA_LINK_LIBS
substrait_textplan_converter
substrait_textplan_loader
substrait_textplan_normalizer
substrait_common
substrait_proto
parse_result_matchers
Expand Down
Loading

0 comments on commit 3b84ff5

Please sign in to comment.