From 390d2dc556dab43e1ab3b0b4b36a7ba3986cac19 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 21 Jun 2023 12:02:46 -0700 Subject: [PATCH] feat: add support for enum arguments in textplans (#73) --- .../textplan/converter/PlanPrinterVisitor.cpp | 18 +++++++++++------- .../parser/SubstraitPlanRelationVisitor.cpp | 19 +++++++++++++++++-- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp index 0b59c983..89577819 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp @@ -59,6 +59,12 @@ std::string invocationToString( return "unspecified"; } +std::string visitEnumArgument(const std::string& str) { + std::stringstream text; + text << str << "_enum"; + return text.str(); +} + } // namespace std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) { @@ -393,10 +399,7 @@ std::any PlanPrinterVisitor::visitScalarFunction( } switch (arg.arg_type_case()) { case ::substrait::proto::FunctionArgument::kEnum: - errorListener_->addError( - "Enum arguments not yet supported in scalar functions: " + - arg.ShortDebugString()); - text << "ENUM_NOT_SUPPORTED"; + text << visitEnumArgument(arg.enum_()); break; case ::substrait::proto::FunctionArgument::kType: text << ANY_CAST(std::string, visitType(arg.type())); @@ -523,8 +526,9 @@ std::any PlanPrinterVisitor::visitNested( std::any PlanPrinterVisitor::visitEnum( const ::substrait::proto::Expression_Enum& value) { errorListener_->addError( - "Enum expressions are not yet supported: " + value.ShortDebugString()); - return std::string("ENUM_NOT_YET_IMPLEMENTED"); + "Enum expressions are deprecated and not supported: " + + value.ShortDebugString()); + return std::string("ENUM_EXPRESSION_DEPRECATED"); } std::any PlanPrinterVisitor::visitStructSelect( @@ -592,7 +596,7 @@ std::any PlanPrinterVisitor::visitAggregateFunction( } switch (arg.arg_type_case()) { case ::substrait::proto::FunctionArgument::kEnum: - text << "ENUM_NOT_SUPPORTED"; + text << visitEnumArgument(arg.enum_()); break; case ::substrait::proto::FunctionArgument::kType: text << ANY_CAST(std::string, visitType(arg.type())); diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index 4e2c7bf5..471c1b29 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -12,6 +12,7 @@ #include "SubstraitPlanTypeVisitor.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" +#include "absl/strings/strip.h" #include "date/tz.h" #include "substrait/expression/DecimalLiteral.h" #include "substrait/proto/algebra.pb.h" @@ -45,8 +46,15 @@ std::string toLower(const std::string& str) { } // Yields true if the string 'haystack' starts with the string 'needle'. -bool startsWith(const std::string& haystack, std::string_view needle) { - return strncmp(haystack.c_str(), needle.data(), needle.size()) == 0; +bool startsWith(std::string_view haystack, std::string_view needle) { + return haystack.size() > needle.size() && + haystack.substr(0, needle.size()) == needle; +} + +// Returns true if the string 'haystack' ends with the string 'needle'. +bool endsWith(std::string_view haystack, std::string_view needle) { + return haystack.size() > needle.size() && + haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; } void setNullable(::substrait::proto::Type* type) { @@ -778,6 +786,13 @@ std::any SubstraitPlanRelationVisitor::visitExpressionFunctionUse( expr.mutable_scalar_function()->set_function_reference(funcReference); for (const auto& exp : ctx->expression()) { + if (endsWith(exp->getText(), "_enum")) { + auto str = exp->getText(); + str = absl::StripSuffix(str, "_enum"); + expr.mutable_scalar_function()->add_arguments()->set_enum_(str); + continue; + } + auto result = visitExpression(exp); if (result.type() != typeid(::substrait::proto::Expression)) { errorListener_->addError(