diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index 9e87ad9c..121274a6 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -254,7 +254,15 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { std::stringstream text; std::map spaceNames; - std::set usedSpaces; + auto cmp = [&](uint32_t a, uint32_t b) { + if (spaceNames.find(a) == spaceNames.end()) { + return spaceNames.find(b) != spaceNames.end(); + } else if (spaceNames.find(b) == spaceNames.end()) { + return false; + } + return spaceNames.at(a) < spaceNames.at(b); + }; + std::set usedSpaces(cmp); // Look at the existing spaces. for (const SymbolInfo& info : symbolTable) { @@ -278,7 +286,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { usedSpaces.insert(extension->extensionUriReference.value()); } - // Finally output the extensions by space in the order they were encountered. + // Finally output the extensions by space in alphabetical order. bool hasPreviousOutput = false; for (const uint32_t space : usedSpaces) { if (hasPreviousOutput) { @@ -291,6 +299,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { text << "extension_space " << spaceNames[space] << " {\n"; } + std::vector> functionsToOutput; for (const SymbolInfo& info : symbolTable) { if (info.type != SymbolType::kFunction) { continue; @@ -302,8 +311,11 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { continue; } - text << " function " << functionData->name << " as " << info.name - << ";\n"; + functionsToOutput.emplace_back(info.name, functionData->name); + } + std::sort(functionsToOutput.begin(), functionsToOutput.end()); + for (const auto& item : functionsToOutput) { + text << " function " << item.second << " as " << item.first << ";\n"; } text << "}\n"; hasPreviousOutput = true; diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt index 44594716..86caf620 100644 --- a/src/substrait/textplan/converter/CMakeLists.txt +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -32,3 +32,13 @@ endif() add_executable(planconverter Tool.cpp) target_link_libraries(planconverter substrait_textplan_converter) + +set(NORMALIZER_SRCS + ReferenceNormalizer.cpp + ReferenceNormalizer.h) + +add_library(substrait_textplan_normalizer ${NORMALIZER_SRCS}) + +target_link_libraries( + substrait_textplan_normalizer + substrait_textplan_converter) diff --git a/src/substrait/textplan/converter/ReferenceNormalizer.cpp b/src/substrait/textplan/converter/ReferenceNormalizer.cpp new file mode 100644 index 00000000..38c4d209 --- /dev/null +++ b/src/substrait/textplan/converter/ReferenceNormalizer.cpp @@ -0,0 +1,292 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/textplan/converter/ReferenceNormalizer.h" + +#include + +#include "substrait/proto/algebra.pb.h" +#include "substrait/proto/plan.pb.h" + +namespace io::substrait::textplan { + +namespace { + +bool compareExtensionFunctions( + const ::substrait::proto::extensions::SimpleExtensionDeclaration& a, + const ::substrait::proto::extensions::SimpleExtensionDeclaration& b) { + // First sort so that extension functions proceed any other kind of extension. + if (a.has_extension_function() && !b.has_extension_function()) { + return true; + } else if (!a.has_extension_function() && b.has_extension_function()) { + // Extension functions always come first. + return false; + } else if (!a.has_extension_function() && !b.has_extension_function()) { + // Both are not extension functions, no difference in ordering. + return false; + } + // Now sort by space. + if (a.extension_function().extension_uri_reference() < + b.extension_function().extension_uri_reference()) { + return true; + } else if ( + a.extension_function().extension_uri_reference() > + b.extension_function().extension_uri_reference()) { + return false; + } + // Finally sort by name within a space. + return a.extension_function().name() < b.extension_function().name(); +} + +void normalizeFunctionsForExpression( + ::substrait::proto::Expression* expr, + const std::map& functionReferenceMapping); + +void normalizeFunctionsForArgument( + ::substrait::proto::FunctionArgument& argument, + const std::map& functionReferenceMapping) { + if (argument.has_value()) { + normalizeFunctionsForExpression( + argument.mutable_value(), functionReferenceMapping); + } +} + +void normalizeFunctionsForMeasure( + ::substrait::proto::AggregateRel_Measure& measure, + const std::map& functionReferenceMapping) { + measure.mutable_measure()->set_function_reference( + functionReferenceMapping.at(measure.measure().function_reference())); +} + +void normalizeFunctionsForExpression( + ::substrait::proto::Expression* expr, + const std::map& functionReferenceMapping) { + if (expr->has_scalar_function()) { + expr->mutable_scalar_function()->set_function_reference( + functionReferenceMapping.at( + expr->scalar_function().function_reference())); + for (auto& arg : *expr->mutable_scalar_function()->mutable_arguments()) { + normalizeFunctionsForArgument(arg, functionReferenceMapping); + } + } else if (expr->has_cast()) { + normalizeFunctionsForExpression( + expr->mutable_cast()->mutable_input(), functionReferenceMapping); + } else if (expr->has_if_then()) { + for (auto& ifthen : *expr->mutable_if_then()->mutable_ifs()) { + normalizeFunctionsForExpression( + ifthen.mutable_if_(), functionReferenceMapping); + normalizeFunctionsForExpression( + ifthen.mutable_then(), functionReferenceMapping); + } + if (expr->if_then().has_else_()) { + normalizeFunctionsForExpression( + expr->mutable_if_then()->mutable_else_(), functionReferenceMapping); + } + } +} + +void normalizeFunctionsForRelation( + ::substrait::proto::Rel* relation, + const std::map& functionReferenceMapping) { + if (relation->has_read()) { + if (relation->read().has_filter()) { + normalizeFunctionsForExpression( + relation->mutable_read()->mutable_filter(), functionReferenceMapping); + } + if (relation->read().has_best_effort_filter()) { + normalizeFunctionsForExpression( + relation->mutable_read()->mutable_best_effort_filter(), + functionReferenceMapping); + } + } else if (relation->has_filter()) { + normalizeFunctionsForRelation( + relation->mutable_filter()->mutable_input(), functionReferenceMapping); + if (relation->filter().has_condition()) { + normalizeFunctionsForExpression( + relation->mutable_filter()->mutable_condition(), + functionReferenceMapping); + } + } else if (relation->has_fetch()) { + normalizeFunctionsForRelation( + relation->mutable_fetch()->mutable_input(), functionReferenceMapping); + } else if (relation->has_aggregate()) { + normalizeFunctionsForRelation( + relation->mutable_aggregate()->mutable_input(), + functionReferenceMapping); + for (auto& measure : *relation->mutable_aggregate()->mutable_measures()) { + normalizeFunctionsForMeasure(measure, functionReferenceMapping); + } + } else if (relation->has_sort()) { + normalizeFunctionsForRelation( + relation->mutable_sort()->mutable_input(), functionReferenceMapping); + for (auto& sort : *relation->mutable_sort()->mutable_sorts()) { + normalizeFunctionsForExpression( + sort.mutable_expr(), functionReferenceMapping); + } + } else if (relation->has_join()) { + if (relation->join().has_expression()) { + normalizeFunctionsForExpression( + relation->mutable_join()->mutable_expression(), + functionReferenceMapping); + } + if (relation->join().has_post_join_filter()) { + normalizeFunctionsForExpression( + relation->mutable_join()->mutable_post_join_filter(), + functionReferenceMapping); + } + normalizeFunctionsForRelation( + relation->mutable_join()->mutable_left(), functionReferenceMapping); + normalizeFunctionsForRelation( + relation->mutable_join()->mutable_right(), functionReferenceMapping); + } else if (relation->has_project()) { + normalizeFunctionsForRelation( + relation->mutable_project()->mutable_input(), functionReferenceMapping); + for (auto& expr : *relation->mutable_project()->mutable_expressions()) { + normalizeFunctionsForExpression(&expr, functionReferenceMapping); + } + } else if (relation->has_set()) { + for (auto& input : *relation->mutable_set()->mutable_inputs()) { + normalizeFunctionsForRelation(&input, functionReferenceMapping); + } + } else if (relation->has_extension_single()) { + if (relation->extension_single().has_input()) { + normalizeFunctionsForRelation( + relation->mutable_extension_single()->mutable_input(), + functionReferenceMapping); + } + } else if (relation->has_extension_multi()) { + for (auto& input : *relation->mutable_extension_multi()->mutable_inputs()) { + normalizeFunctionsForRelation(&input, functionReferenceMapping); + } + } else if (relation->has_extension_leaf()) { + // Nothing to do here. + } else if (relation->has_cross()) { + if (relation->cross().has_left()) { + normalizeFunctionsForRelation( + relation->mutable_cross()->mutable_left(), functionReferenceMapping); + } + if (relation->cross().has_right()) { + normalizeFunctionsForRelation( + relation->mutable_cross()->mutable_right(), functionReferenceMapping); + } + } else if (relation->has_hash_join()) { + if (relation->hash_join().has_left()) { + normalizeFunctionsForRelation( + relation->mutable_hash_join()->mutable_left(), + functionReferenceMapping); + } + if (relation->hash_join().has_right()) { + normalizeFunctionsForRelation( + relation->mutable_hash_join()->mutable_right(), + functionReferenceMapping); + } + if (relation->hash_join().has_post_join_filter()) { + normalizeFunctionsForExpression( + relation->mutable_hash_join()->mutable_post_join_filter(), + functionReferenceMapping); + } + } else if (relation->has_merge_join()) { + if (relation->merge_join().has_left()) { + normalizeFunctionsForRelation( + relation->mutable_merge_join()->mutable_left(), + functionReferenceMapping); + } + if (relation->merge_join().has_right()) { + normalizeFunctionsForRelation( + relation->mutable_merge_join()->mutable_right(), + functionReferenceMapping); + } + if (relation->merge_join().has_post_join_filter()) { + normalizeFunctionsForExpression( + relation->mutable_merge_join()->mutable_post_join_filter(), + functionReferenceMapping); + } + } +} + +void normalizeFunctionsForRootRelation( + ::substrait::proto::RelRoot* relation, + const std::map& functionReferenceMapping) { + if (relation->has_input()) { + normalizeFunctionsForRelation( + relation->mutable_input(), functionReferenceMapping); + } +} + +void normalizeFunctionsForPlanRelation( + ::substrait::proto::PlanRel& relation, + const std::map& functionReferenceMapping) { + if (relation.has_root()) { + normalizeFunctionsForRootRelation( + relation.mutable_root(), functionReferenceMapping); + } + if (relation.has_rel()) { + normalizeFunctionsForRelation( + relation.mutable_rel(), functionReferenceMapping); + } +} + +} // namespace + +void ReferenceNormalizer::normalizeSpaces(::substrait::proto::Plan* plan) { + std::map extensionSpaceReferenceMapping; + + // Reorder the extension spaces and remember what we changed. + std::sort( + plan->mutable_extension_uris()->begin(), + plan->mutable_extension_uris()->end(), + [](const ::substrait::proto::extensions::SimpleExtensionURI& a, + const ::substrait::proto::extensions::SimpleExtensionURI& b) { + return a.uri() < b.uri(); + }); + + // Now renumber the spaces. + uint32_t uriNum = 0; + for (auto& extensionUri : *plan->mutable_extension_uris()) { + extensionSpaceReferenceMapping[extensionUri.extension_uri_anchor()] = + ++uriNum; + extensionUri.set_extension_uri_anchor(uriNum); + } + + // Apply the space numbering changes to the functions. + for (auto& function : *plan->mutable_extensions()) { + if (function.has_extension_function()) { + auto newSpace = extensionSpaceReferenceMapping.find( + function.extension_function().extension_uri_reference()); + if (newSpace != extensionSpaceReferenceMapping.end()) { + function.mutable_extension_function()->set_extension_uri_reference( + newSpace->second); + } + } + } +} + +void ReferenceNormalizer::normalizeFunctions(::substrait::proto::Plan* plan) { + std::map functionReferenceMapping; + + // First sort the functions alphabetically by space. + std::sort( + plan->mutable_extensions()->begin(), + plan->mutable_extensions()->end(), + compareExtensionFunctions); + + // Now renumber the functions starting with zero. + uint32_t functionNum = 0; + for (auto& function : *plan->mutable_extensions()) { + functionReferenceMapping[function.extension_function().function_anchor()] = + functionNum; + function.mutable_extension_function()->set_function_anchor(functionNum); + functionNum++; + } + + // Now apply that reordering to the rest of the protobuf. + for (auto& relation : *plan->mutable_relations()) { + normalizeFunctionsForPlanRelation(relation, functionReferenceMapping); + } +} + +void ReferenceNormalizer::normalize(::substrait::proto::Plan* plan) { + normalizeSpaces(plan); + normalizeFunctions(plan); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/ReferenceNormalizer.h b/src/substrait/textplan/converter/ReferenceNormalizer.h new file mode 100644 index 00000000..fbf722d7 --- /dev/null +++ b/src/substrait/textplan/converter/ReferenceNormalizer.h @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include "substrait/proto/plan.pb.h" + +namespace io::substrait::textplan { + +// ReferenceNormalizer renumbers the extension space uri references +// and function references in a consistent manner. This makes it easier +// for differencing tools to compare two similar binary plans. The behavior +// of this tool is undefined on invalid plans. +class ReferenceNormalizer { + public: + ReferenceNormalizer() = default; + + static void normalize(::substrait::proto::Plan* plan); + + private: + static void normalizeSpaces(::substrait::proto::Plan* plan); + static void normalizeFunctions(::substrait::proto::Plan* plan); +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp index 13e8eef6..d6c0f169 100644 --- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp +++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp @@ -228,8 +228,8 @@ std::vector getTestCases() { extension_space blah.yaml { function add:i8 as add; - function subtract:i8 as subtract; function concat:str as concat; + function subtract:i8 as subtract; })")), AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(extension_uris { diff --git a/src/substrait/textplan/tests/CMakeLists.txt b/src/substrait/textplan/tests/CMakeLists.txt index 40f48cec..16da68ba 100644 --- a/src/substrait/textplan/tests/CMakeLists.txt +++ b/src/substrait/textplan/tests/CMakeLists.txt @@ -30,6 +30,7 @@ if(${SUBSTRAIT_CPP_ROUNDTRIP_TESTING}) EXTRA_LINK_LIBS substrait_textplan_converter substrait_textplan_loader + substrait_textplan_normalizer substrait_common substrait_proto parse_result_matchers diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 57680238..c5190b65 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -15,6 +15,7 @@ #include "substrait/textplan/SymbolTablePrinter.h" #include "substrait/textplan/converter/LoadBinary.h" #include "substrait/textplan/converter/ParseBinary.h" +#include "substrait/textplan/converter/ReferenceNormalizer.h" #include "substrait/textplan/parser/ParseText.h" #include "substrait/textplan/tests/ParseResultMatchers.h" @@ -35,6 +36,13 @@ std::string addLineNumbers(const std::string& text) { return result.str(); } +::substrait::proto::Plan normalizePlan(const ::substrait::proto::Plan& plan) { + ::substrait::proto::Plan newPlan = plan; + ReferenceNormalizer normalizer; + normalizer.normalize(&newPlan); + return newPlan; +} + class RoundTripBinaryToTextFixture : public ::testing::TestWithParam {}; @@ -74,13 +82,13 @@ TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { ASSERT_NO_THROW(auto outputBinary = SymbolTablePrinter::outputToBinaryPlan( result.getSymbolTable());); + auto normalizedPlan = normalizePlan(plan); ASSERT_THAT( result, ::testing::AllOf( ParsesOk(), HasErrors({}), - AsBinaryPlan(IgnoringFieldPaths( - {"extension_uris", "extensions"}, EqualsProto(plan))))) + AsBinaryPlan(EqualsProto(normalizedPlan)))) << std::endl << "Intermediate result:" << std::endl << addLineNumbers(outputText);