diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp index dd6e869c..0b59c983 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp @@ -442,10 +442,33 @@ std::any PlanPrinterVisitor::visitWindowFunction( std::any PlanPrinterVisitor::visitIfThen( const ::substrait::proto::Expression::IfThen& ifthen) { - errorListener_->addError( - "If then expressions are not yet supported: " + - ifthen.ShortDebugString()); - return std::string("IFTHEN_NOT_YET_IMPLEMENTED"); + std::stringstream text; + text << "IFTHEN("; + bool hasPreviousText = false; + for (const auto& clause : ifthen.ifs()) { + if (!clause.has_if_() || !clause.has_then()) { + errorListener_->addError( + "If then clauses require both an if and a then expression: " + + clause.ShortDebugString()); + continue; + } + if (hasPreviousText) { + text << ", "; + } + text << ANY_CAST(std::string, visitExpression(clause.if_())); + text << ", "; + text << ANY_CAST(std::string, visitExpression(clause.then())); + hasPreviousText = true; + } + if (ifthen.has_else_()) { + if (hasPreviousText) { + text << ", "; + } + text << ANY_CAST(std::string, visitExpression(ifthen.else_())); + } + + text << ")"; + return text.str(); } std::any PlanPrinterVisitor::visitSwitchExpression( diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index 9bc1f689..700c7307 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -464,6 +464,38 @@ std::vector getTestCases() { filter functionref#4(field#2, 0.07_fp64); })"))), }, + { + "ifthen expression missing then", + R"(relations: { + root: { + input: { + filter: { + condition: { + if_then: { + ifs: { + if: { + literal: { + nullable: false, + fp64: 0.07 + } + } + } + else: { + literal: { + nullable: false, + fp64: 0.07 + } + } + } + } + } + } + } + })", + HasErrors( + {"If then clauses require both an if and a then expression: " + "if { literal { fp64: 0.07 } }"}), + }, { "cast expression", R"(relations: { diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index fc8ee51d..4e2c7bf5 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -10,6 +10,7 @@ #include "SubstraitPlanParser/SubstraitPlanParser.h" #include "SubstraitPlanTypeVisitor.h" +#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "date/tz.h" #include "substrait/expression/DecimalLiteral.h" @@ -731,10 +732,39 @@ std::any SubstraitPlanRelationVisitor::visitExpression( return defaultResult(); } +::substrait::proto::Expression +SubstraitPlanRelationVisitor::visitExpressionIfThenUse( + SubstraitPlanParser::ExpressionFunctionUseContext* ctx) { + ::substrait::proto::Expression expr; + size_t currExprNum = 0; + size_t totalExprCount = ctx->expression().size(); + while (currExprNum + 2 <= totalExprCount) { + // Peel off an if/then pair. + auto ifThen = expr.mutable_if_then()->add_ifs(); + *ifThen->mutable_if_() = ANY_CAST( + ::substrait::proto::Expression, + visitExpression(ctx->expression(currExprNum))); + *ifThen->mutable_then() = ANY_CAST( + ::substrait::proto::Expression, + visitExpression(ctx->expression(currExprNum + 1))); + currExprNum += 2; + } + if (currExprNum + 1 <= totalExprCount) { + // Use the last expression as the else clause. + *expr.mutable_if_then()->mutable_else_() = ANY_CAST( + ::substrait::proto::Expression, + visitExpression(ctx->expression(currExprNum))); + } + return expr; +} + std::any SubstraitPlanRelationVisitor::visitExpressionFunctionUse( SubstraitPlanParser::ExpressionFunctionUseContext* ctx) { ::substrait::proto::Expression expr; std::string funcName = ctx->id()->getText(); + if (absl::AsciiStrToLower(funcName) == "ifthen") { + return visitExpressionIfThenUse(ctx); + } uint32_t funcReference = 0; auto symbol = symbolTable_->lookupSymbolByName(funcName); if (symbol == nullptr || symbol->type != SymbolType::kFunction) { diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h index a41e8e2b..95a0f335 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h @@ -8,6 +8,7 @@ #include "substrait/textplan/parser/SubstraitPlanTypeVisitor.h" namespace substrait::proto { +class Expression; class Expression_Literal; class Expression_Literal_Map_KeyValue; class NamedStruct; @@ -71,6 +72,10 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { // visitExpression is a new method delegating to the methods below. std::any visitExpression(SubstraitPlanParser::ExpressionContext* ctx); + // visitExpressionIfThenUse handles the built-in IFTHEN function-like syntax. + ::substrait::proto::Expression visitExpressionIfThenUse( + SubstraitPlanParser::ExpressionFunctionUseContext* ctx); + std::any visitExpressionFunctionUse( SubstraitPlanParser::ExpressionFunctionUseContext* ctx) override; diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 57680238..6b807a10 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -69,6 +69,11 @@ TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { std::string outputText = SymbolTablePrinter::outputToText(textResult.getSymbolTable()); + ASSERT_THAT(textResult, AllOf(ParsesOk(), HasErrors({}))) + << std::endl + << "Intermediate result:" << std::endl + << addLineNumbers(outputText) << std::endl; + auto stream = loadTextString(outputText); auto result = parseStream(stream); ASSERT_NO_THROW(auto outputBinary = SymbolTablePrinter::outputToBinaryPlan(