diff --git a/src/substrait/textplan/ParseResult.cpp b/src/substrait/textplan/ParseResult.cpp index 75be2807..41ac974d 100644 --- a/src/substrait/textplan/ParseResult.cpp +++ b/src/substrait/textplan/ParseResult.cpp @@ -2,8 +2,7 @@ #include "substrait/textplan/ParseResult.h" -#include -#include +#include namespace io::substrait::textplan { diff --git a/src/substrait/textplan/SymbolTable.cpp b/src/substrait/textplan/SymbolTable.cpp index e3de4380..bdab07ff 100644 --- a/src/substrait/textplan/SymbolTable.cpp +++ b/src/substrait/textplan/SymbolTable.cpp @@ -132,13 +132,13 @@ const SymbolInfo* SymbolTable::lookupSymbolByName( return symbols_[itr->second].get(); } -const SymbolInfo& SymbolTable::lookupSymbolByLocation( +const SymbolInfo* SymbolTable::lookupSymbolByLocation( const Location& location) const { auto itr = symbolsByLocation_.find(location); if (itr == symbolsByLocation_.end()) { - return SymbolInfo::kUnknown; + return nullptr; } - return *symbols_[itr->second]; + return symbols_[itr->second].get(); } const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type) diff --git a/src/substrait/textplan/SymbolTable.h b/src/substrait/textplan/SymbolTable.h index f033be34..19e6d419 100644 --- a/src/substrait/textplan/SymbolTable.h +++ b/src/substrait/textplan/SymbolTable.h @@ -146,7 +146,7 @@ class SymbolTable { [[nodiscard]] const SymbolInfo* lookupSymbolByName( const std::string& name) const; - [[nodiscard]] const SymbolInfo& lookupSymbolByLocation( + [[nodiscard]] const SymbolInfo* lookupSymbolByLocation( const Location& location) const; [[nodiscard]] const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type) diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index c0217f84..9e87ad9c 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -74,83 +74,9 @@ void localFileToText( } std::string typeToText(const ::substrait::proto::Type& type) { - switch (type.kind_case()) { - case ::substrait::proto::Type::kBool: - if (type.bool_().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "bool?"; - } - return "bool"; - case ::substrait::proto::Type::kI8: - if (type.i8().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "i8?"; - } - return "i8"; - case ::substrait::proto::Type::kI16: - if (type.i16().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "i16?"; - } - return "i16"; - case ::substrait::proto::Type::kI32: - if (type.i32().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "i32?"; - } - return "i32"; - case ::substrait::proto::Type::kI64: - if (type.i64().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "i64?"; - } - return "i64"; - case ::substrait::proto::Type::kFp32: - if (type.fp32().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "fp32?"; - } - return "fp32"; - case ::substrait::proto::Type::kFp64: - if (type.fp64().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "fp64?"; - } - return "fp64"; - case ::substrait::proto::Type::kString: - if (type.string().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "string?"; - } - return "string"; - case ::substrait::proto::Type::kDecimal: - if (type.string().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "decimal?"; - } - return "decimal"; - case ::substrait::proto::Type::kVarchar: - if (type.varchar().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "varchar?"; - } - return "varchar"; - case ::substrait::proto::Type::kFixedChar: - if (type.fixed_char().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "fixedchar?"; - } - return "fixedchar"; - case ::substrait::proto::Type::kDate: - if (type.date().nullability() == - ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return "date?"; - } - return "date"; - case ::substrait::proto::Type::KIND_NOT_SET: - default: - return "UNSUPPORTED_TYPE"; - } + SymbolTable symbolTable; + PlanPrinterVisitor visitor(symbolTable); + return visitor.typeToText(type); }; std::string relationToText( @@ -386,6 +312,88 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { return text.str(); } +void outputExtensionSpacesToBinaryPlan( + const SymbolTable& symbolTable, + ::substrait::proto::Plan* plan) { + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kExtensionSpace) { + continue; + } + + auto extensionData = + ANY_CAST(std::shared_ptr, info.blob); + auto uri = plan->add_extension_uris(); + uri->set_uri(info.name); + uri->set_extension_uri_anchor(extensionData->anchorReference); + } +} + +void outputFunctionsToBinaryPlan( + const SymbolTable& symbolTable, + ::substrait::proto::Plan* plan) { + std::map spaceNames; + std::set usedSpaces; + + // Look at the existing spaces. + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kExtensionSpace) { + continue; + } + + auto extensionData = + ANY_CAST(std::shared_ptr, info.blob); + spaceNames.insert( + std::make_pair(extensionData->anchorReference, info.name)); + } + + // Find any spaces that are used but undefined. + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kFunction) { + continue; + } + + auto extension = ANY_CAST(std::shared_ptr, info.blob); + if (extension->extensionUriReference.has_value()) { + usedSpaces.insert(extension->extensionUriReference.value()); + } + } + + // Output the extensions by space in the order they were encountered. + for (const uint32_t space : usedSpaces) { + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kFunction) { + continue; + } + + auto functionData = ANY_CAST(std::shared_ptr, info.blob); + if (functionData->extensionUriReference != space) { + continue; + } + + auto func = plan->add_extensions()->mutable_extension_function(); + func->set_function_anchor(functionData->anchor); + func->set_name(functionData->name); + + if (spaceNames.find(space) != spaceNames.end()) { + func->set_extension_uri_reference(space); + } + } + } + + for (const SymbolInfo& info : symbolTable) { + if (info.type != SymbolType::kFunction) { + continue; + } + + auto functionData = ANY_CAST(std::shared_ptr, info.blob); + if (!functionData->extensionUriReference.has_value()) { + auto func = plan->add_extensions()->mutable_extension_function(); + func->set_function_anchor(functionData->anchor); + func->set_name(functionData->name); + } + } +} + } // namespace std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { @@ -631,6 +639,8 @@ void SymbolTablePrinter::addInputsToRelation( ::substrait::proto::Plan SymbolTablePrinter::outputToBinaryPlan( const SymbolTable& symbolTable) { ::substrait::proto::Plan plan; + outputExtensionSpacesToBinaryPlan(symbolTable, &plan); + outputFunctionsToBinaryPlan(symbolTable, &plan); for (const SymbolInfo& info : symbolTable) { if (info.type != SymbolType::kRelation) { continue; diff --git a/src/substrait/textplan/converter/CMakeLists.txt b/src/substrait/textplan/converter/CMakeLists.txt index 182225fa..44594716 100644 --- a/src/substrait/textplan/converter/CMakeLists.txt +++ b/src/substrait/textplan/converter/CMakeLists.txt @@ -17,8 +17,13 @@ set(TEXTPLAN_SRCS add_library(substrait_textplan_converter ${TEXTPLAN_SRCS}) target_link_libraries( - substrait_textplan_converter substrait_common substrait_expression - substrait_proto symbol_table error_listener) + substrait_textplan_converter + substrait_common + substrait_expression + substrait_proto + symbol_table + error_listener + date::date) if(${SUBSTRAIT_CPP_BUILD_TESTING}) add_subdirectory(tests) diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index aedfe9f7..53a5ce24 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -267,12 +267,12 @@ std::any InitialPlanProtoVisitor::visitNamedStruct( void InitialPlanProtoVisitor::addFieldsToRelation( const std::shared_ptr& relationData, const ::substrait::proto::Rel& relation) { - auto symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); - if (symbol == SymbolInfo::kUnknown || symbol.type != SymbolType::kRelation) { + auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); + if (symbol == nullptr || symbol->type != SymbolType::kRelation) { return; } auto symbolRelationData = - ANY_CAST(std::shared_ptr, symbol.blob); + ANY_CAST(std::shared_ptr, symbol->blob); for (const auto& field : symbolRelationData->fieldReferences) { relationData->fieldReferences.push_back(field); } diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index bf4058c2..da5cd87f 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -9,11 +9,11 @@ namespace io::substrait::textplan { std::shared_ptr PipelineVisitor::getRelationData( const google::protobuf::Message& relation) { - auto symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); - if (symbol == SymbolInfo::kUnknown) { + auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); + if (symbol == nullptr) { return nullptr; } - return ANY_CAST(std::shared_ptr, symbol.blob); + return ANY_CAST(std::shared_ptr, symbol->blob); } std::any PipelineVisitor::visitRelation( @@ -25,92 +25,92 @@ std::any PipelineVisitor::visitRelation( // No relations beyond this one. break; case ::substrait::proto::Rel::RelTypeCase::kFilter: { - const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.filter().input())); - relationData->continuingPipeline = &inputSymbol; + relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kFetch: { - const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.fetch().input())); - relationData->continuingPipeline = &inputSymbol; + relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kAggregate: { - const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.aggregate().input())); - relationData->continuingPipeline = &inputSymbol; + relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kSort: { - const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.sort().input())); - relationData->continuingPipeline = &inputSymbol; + relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kJoin: { - const auto& leftSymbol = symbolTable_->lookupSymbolByLocation( + const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.join().left())); - const auto& rightSymbol = symbolTable_->lookupSymbolByLocation( + const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.join().right())); - relationData->newPipelines.push_back(&leftSymbol); - relationData->newPipelines.push_back(&rightSymbol); + relationData->newPipelines.push_back(leftSymbol); + relationData->newPipelines.push_back(rightSymbol); break; } case ::substrait::proto::Rel::RelTypeCase::kProject: { - const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.project().input())); - relationData->continuingPipeline = &inputSymbol; + relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kSet: for (const auto& rel : relation.set().inputs()) { - const auto& inputSymbol = + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation(Location(&rel)); - relationData->newPipelines.push_back(&inputSymbol); + relationData->newPipelines.push_back(inputSymbol); } break; case ::substrait::proto::Rel::RelTypeCase::kExtensionSingle: { - const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.extension_single().input())); - relationData->continuingPipeline = &inputSymbol; + relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kExtensionMulti: for (const auto& rel : relation.extension_multi().inputs()) { - const auto& inputSymbol = + const auto* inputSymbol = symbolTable_->lookupSymbolByLocation(Location(&rel)); - relationData->newPipelines.push_back(&inputSymbol); + relationData->newPipelines.push_back(inputSymbol); } break; case ::substrait::proto::Rel::RelTypeCase::kExtensionLeaf: // No children. break; case ::substrait::proto::Rel::RelTypeCase::kCross: { - const auto& leftSymbol = symbolTable_->lookupSymbolByLocation( + const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.cross().left())); - const auto& rightSymbol = symbolTable_->lookupSymbolByLocation( + const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.cross().right())); - relationData->newPipelines.push_back(&leftSymbol); - relationData->newPipelines.push_back(&rightSymbol); + relationData->newPipelines.push_back(leftSymbol); + relationData->newPipelines.push_back(rightSymbol); break; } case ::substrait::proto::Rel::RelTypeCase::kHashJoin: { - const auto& leftSymbol = symbolTable_->lookupSymbolByLocation( + const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.hash_join().left())); - const auto& rightSymbol = symbolTable_->lookupSymbolByLocation( + const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.hash_join().right())); - relationData->newPipelines.push_back(&leftSymbol); - relationData->newPipelines.push_back(&rightSymbol); + relationData->newPipelines.push_back(leftSymbol); + relationData->newPipelines.push_back(rightSymbol); break; } case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: { - const auto& leftSymbol = symbolTable_->lookupSymbolByLocation( + const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.merge_join().left())); - const auto& rightSymbol = symbolTable_->lookupSymbolByLocation( + const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.merge_join().right())); - relationData->newPipelines.push_back(&leftSymbol); - relationData->newPipelines.push_back(&rightSymbol); + relationData->newPipelines.push_back(leftSymbol); + relationData->newPipelines.push_back(rightSymbol); break; } case ::substrait::proto::Rel::REL_TYPE_NOT_SET: @@ -122,19 +122,19 @@ std::any PipelineVisitor::visitRelation( std::any PipelineVisitor::visitPlanRelation( const ::substrait::proto::PlanRel& relation) { - auto symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); - auto relationData = ANY_CAST(std::shared_ptr, symbol.blob); + auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); + auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); switch (relation.rel_type_case()) { case ::substrait::proto::PlanRel::kRel: { const auto& relSymbol = symbolTable_->lookupSymbolByLocation(Location(&relation.rel())); - relationData->newPipelines.push_back(&relSymbol); + relationData->newPipelines.push_back(relSymbol); break; } case ::substrait::proto::PlanRel::kRoot: { const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( Location(&relation.root().input())); - relationData->newPipelines.push_back(&inputSymbol); + relationData->newPipelines.push_back(inputSymbol); break; } case ::substrait::proto::PlanRel::REL_TYPE_NOT_SET: diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp index 6e791255..dd6e869c 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp @@ -2,14 +2,13 @@ #include "substrait/textplan/converter/PlanPrinterVisitor.h" -#include #include #include +#include "date/date.h" #include "substrait/expression/DecimalLiteral.h" #include "substrait/proto/ProtoUtils.h" #include "substrait/proto/algebra.pb.h" -#include "substrait/proto/plan.pb.h" #include "substrait/textplan/Any.h" #include "substrait/textplan/Finally.h" #include "substrait/textplan/StructuredSymbolData.h" @@ -38,6 +37,28 @@ std::string stringEscape(std::string_view str) { return result.str(); } +std::string invocationToString( + ::substrait::proto::AggregateFunction_AggregationInvocation invocation) { + switch (invocation) { + case ::substrait::proto:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL: + return "all"; + case ::substrait::proto:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_DISTINCT: + return "distinct"; + case ::substrait::proto:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_UNSPECIFIED: + return "unspecified"; + case ::substrait::proto:: + AggregateFunction_AggregationInvocation_AggregateFunction_AggregationInvocation_INT_MIN_SENTINEL_DO_NOT_USE_: + case ::substrait::proto:: + AggregateFunction_AggregationInvocation_AggregateFunction_AggregationInvocation_INT_MAX_SENTINEL_DO_NOT_USE_: + break; + } + // We shouldn't reach here but return something to make the compiler happy. + return "unspecified"; +} + } // namespace std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) { @@ -68,6 +89,11 @@ std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) { return text.str(); } +std::string PlanPrinterVisitor::typeToText( + const ::substrait::proto::Type& type) { + return ANY_CAST(std::string, visitType(type)); +} + std::string PlanPrinterVisitor::lookupFieldReference(uint32_t field_reference) { if (*currentScope_ != SymbolInfo::kUnknown) { auto relationData = @@ -110,62 +136,87 @@ std::any PlanPrinterVisitor::visitType(const ::substrait::proto::Type& type) { case ::substrait::proto::Type::kBool: if (type.bool_().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_bool"); + return std::string("bool?"); } return std::string("bool"); case ::substrait::proto::Type::kI8: if (type.i8().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_i8"); + return std::string("i8?"); } return std::string("i8"); case ::substrait::proto::Type::kI16: if (type.i16().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_i16"); + return std::string("i16?"); } return std::string("i16"); case ::substrait::proto::Type::kI32: if (type.i32().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_i32"); + return std::string("i32?"); } return std::string("i32"); case ::substrait::proto::Type::kI64: if (type.i64().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_i64"); + return std::string("i64?"); } return std::string("i64"); case ::substrait::proto::Type::kFp32: if (type.fp32().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_fp32"); + return std::string("fp32?"); } return std::string("fp32"); case ::substrait::proto::Type::kFp64: if (type.fp64().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_fp64"); + return std::string("fp64?"); } return std::string("fp64"); case ::substrait::proto::Type::kString: if (type.string().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_string"); + return std::string("string?"); } return std::string("string"); - case ::substrait::proto::Type::kDecimal: - if (type.string().nullability() == + case ::substrait::proto::Type::kDecimal: { + std::stringstream result; + result << "decimal"; + if (type.decimal().nullability() == ::substrait::proto::Type::NULLABILITY_NULLABLE) { - return std::string("opt_decimal"); + result << '?'; } - return std::string("decimal"); - case ::substrait::proto::Type::kVarchar: - return std::string("varchar"); - case ::substrait::proto::Type::kFixedChar: - return std::string("fixedchar"); + result << "<" << type.decimal().precision() << ","; + result << type.decimal().scale() << ">"; + return result.str(); + } + case ::substrait::proto::Type::kVarchar: { + std::stringstream result; + result << "varchar"; + if (type.varchar().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) { + result << "?"; + } + result << "<" << type.varchar().length() << ">"; + return result.str(); + } + case ::substrait::proto::Type::kFixedChar: { + std::stringstream result; + result << "fixedchar"; + if (type.fixed_char().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) { + result << "?"; + } + result << "<" << type.fixed_char().length() << ">"; + return result.str(); + } case ::substrait::proto::Type::kDate: + if (type.date().nullability() == + ::substrait::proto::Type::NULLABILITY_NULLABLE) { + return std::string("date?"); + } return std::string("date"); case ::substrait::proto::Type::KIND_NOT_SET: errorListener_->addError( @@ -226,12 +277,9 @@ std::any PlanPrinterVisitor::visitLiteral( text << literal.fp64() << "_fp64"; break; case ::substrait::proto::Expression::Literal ::kDate: { - // TODO -- Format this as a date instead of a delta since an epoch. - if (literal.date() >= 0) { - text << "\"epoch+" << literal.date() << " days\"_date"; - } else { - text << "\"epoch" << literal.date() << " days\"_date"; - } + auto refDate = date::sys_days{}; + date::sys_days newDate = refDate + date::days{literal.date()}; + text << '"' << date::year_month_day{newDate} << "\"_date"; break; } case ::substrait::proto::Expression::Literal::kString: @@ -245,43 +293,17 @@ std::any PlanPrinterVisitor::visitLiteral( literal.ShortDebugString()); return std::string("UNSUPPORTED_LITERAL_TYPE"); case ::substrait::proto::Expression_Literal::kIntervalYearToMonth: { - text << "{"; - bool hasPreviousText = false; - if (literal.interval_year_to_month().years() != 0) { - text << literal.interval_year_to_month().years() << "years"; - hasPreviousText = true; - } - if (literal.interval_year_to_month().months() != 0) { - if (hasPreviousText) { - text << ", "; - } - text << literal.interval_year_to_month().months() << "months"; - } - text << "}_interval_year"; // TODO - Change spec to better name. + text << "{" << literal.interval_year_to_month().years() << "_years" + << ", " << literal.interval_year_to_month().months() << "_months" + << "}_interval_year"; // TODO - Change spec to better name. break; } case ::substrait::proto::Expression_Literal::kIntervalDayToSecond: { - text << "{"; - bool hasPreviousText = false; - if (literal.interval_day_to_second().days() != 0) { - text << literal.interval_day_to_second().days() << "days"; - hasPreviousText = true; - } - if (literal.interval_day_to_second().seconds() != 0) { - if (hasPreviousText) { - text << ", "; - } - text << literal.interval_day_to_second().seconds() << "seconds"; - hasPreviousText = true; - } - if (literal.interval_day_to_second().microseconds() != 0) { - if (hasPreviousText) { - text << ", "; - } - text << literal.interval_day_to_second().microseconds() - << "microseconds"; - } - text << "}_interval_day"; // TODO - Change spec to better name. + text << "{" << literal.interval_day_to_second().days() << "_days" + << ", " << literal.interval_day_to_second().seconds() << "_seconds" + << ", " << literal.interval_day_to_second().microseconds() + << "_microseconds" + << "}_interval_day"; // TODO - Change spec to better name. break; } case ::substrait::proto::Expression_Literal::kFixedChar: @@ -592,11 +614,6 @@ std::any PlanPrinterVisitor::visitReferenceSegment( std::any PlanPrinterVisitor::visitExpression( const ::substrait::proto::Expression& expression) { - if (expression.rex_type_case() == - ::substrait::proto::Expression::RexTypeCase::REX_TYPE_NOT_SET) { - // TODO -- Remove this check after expressions are finished. - return std::string("EXPR-NOT-YET-IMPLEMENTED"); - } return BasePlanProtoVisitor::visitExpression(expression); } @@ -613,10 +630,10 @@ std::any PlanPrinterVisitor::visitRelation( // Mark the current scope for any operations within this relation. auto previousScope = currentScope_; auto resetCurrentScope = finally([&]() { currentScope_ = previousScope; }); - const SymbolInfo& symbol = + const SymbolInfo* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); - if (symbol != SymbolInfo::kUnknown) { - currentScope_ = &symbol; + if (symbol != nullptr) { + currentScope_ = symbol; } auto result = BasePlanProtoVisitor::visitRelation(relation); @@ -645,17 +662,17 @@ std::any PlanPrinterVisitor::visitReadRelation( case ::substrait::proto::ReadRel::READ_TYPE_NOT_SET: return ""; } - const auto& symbol = + const auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(*msg)); - if (symbol != SymbolInfo::kUnknown) { - text << " source " << symbol.name << ";\n"; + if (symbol != nullptr) { + text << " source " << symbol->name << ";\n"; } if (relation.has_base_schema()) { - const auto& schemaSymbol = symbolTable_->lookupSymbolByLocation( + const auto* schemaSymbol = symbolTable_->lookupSymbolByLocation( PROTO_LOCATION(relation.base_schema())); - if (schemaSymbol != SymbolInfo::kUnknown) { - text << " base_schema " << schemaSymbol.name << ";\n"; + if (schemaSymbol != nullptr) { + text << " base_schema " << schemaSymbol->name << ";\n"; } } if (relation.has_filter()) { @@ -681,7 +698,7 @@ std::any PlanPrinterVisitor::visitFilterRelation( const ::substrait::proto::FilterRel& relation) { std::stringstream text; if (relation.has_condition()) { - text << " condition " + text << " filter " << ANY_CAST(std::string, visitExpression(relation.condition())) << ";\n"; } @@ -716,9 +733,14 @@ std::any PlanPrinterVisitor::visitAggregateRelation( << ANY_CAST(std::string, visitAggregateFunction(measure.measure())) << ";\n"; if (measure.has_filter()) { - text << " filter " + - ANY_CAST(std::string, visitExpression(measure.filter())) - << ";\n"; + text << " filter " + << ANY_CAST(std::string, visitExpression(measure.filter())) << ";\n"; + } + if (measure.measure().invocation() != + ::substrait::proto:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_UNSPECIFIED) { + text << " invocation " + << invocationToString(measure.measure().invocation()) << ";\n"; } text << " }\n"; } diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.h b/src/substrait/textplan/converter/PlanPrinterVisitor.h index 825b3b4b..2be16b09 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.h +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.h @@ -31,6 +31,7 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor { }; std::string printRelation(const SymbolInfo& symbol); + std::string typeToText(const ::substrait::proto::Type& type); private: std::string lookupFieldReference(uint32_t field_reference); diff --git a/src/substrait/textplan/converter/Tool.cpp b/src/substrait/textplan/converter/Tool.cpp index 44842055..9c3d652b 100644 --- a/src/substrait/textplan/converter/Tool.cpp +++ b/src/substrait/textplan/converter/Tool.cpp @@ -4,7 +4,7 @@ #include #endif -#include +#include #include "substrait/textplan/SymbolTablePrinter.h" #include "substrait/textplan/converter/LoadBinary.h" diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index e97b2d81..9bc1f689 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -408,7 +408,7 @@ std::vector getTestCases() { } filter relation filter { - condition functionref#4(field#2, 0.07_fp64); + filter functionref#4(field#2, 0.07_fp64); })"))), }, { @@ -461,7 +461,47 @@ std::vector getTestCases() { } filter relation filter { - condition functionref#4(field#2, 0.07_fp64); + filter functionref#4(field#2, 0.07_fp64); + })"))), + }, + { + "cast expression", + R"(relations: { + root: { + input: { + filter: { + condition: { + cast: { + type: { + fixed_char: { + length: 10, + type_variation_reference: 0, + nullability: NULLABILITY_REQUIRED + } + }, + input: { + literal: { + fixed_char: "HOUSEHOLD", + nullable: false, + type_variation_reference: 0 + } + }, + failure_behavior: FAILURE_BEHAVIOR_UNSPECIFIED + } + } + } + } + } + })", + AllOf( + HasSymbols({"filter", "root"}), + WhenSerialized(EqSquashingWhitespace( + R"(pipelines { + filter -> root; + } + + filter relation filter { + filter "HOUSEHOLD"_fixedchar<9> AS fixedchar<10>; })"))), }, { diff --git a/src/substrait/textplan/parser/CMakeLists.txt b/src/substrait/textplan/parser/CMakeLists.txt index dffd4bad..498d572c 100644 --- a/src/substrait/textplan/parser/CMakeLists.txt +++ b/src/substrait/textplan/parser/CMakeLists.txt @@ -10,6 +10,8 @@ add_library( SubstraitPlanPipelineVisitor.h SubstraitPlanRelationVisitor.cpp SubstraitPlanRelationVisitor.h + SubstraitPlanTypeVisitor.cpp + SubstraitPlanTypeVisitor.h ParseText.cpp ParseText.h SubstraitParserErrorListener.cpp) diff --git a/src/substrait/textplan/parser/ParseText.cpp b/src/substrait/textplan/parser/ParseText.cpp index 285c10de..554e4a26 100644 --- a/src/substrait/textplan/parser/ParseText.cpp +++ b/src/substrait/textplan/parser/ParseText.cpp @@ -2,14 +2,14 @@ #include "ParseText.h" +#include #include #include -#include -#include +#include +#include #include "SubstraitPlanLexer/SubstraitPlanLexer.h" #include "SubstraitPlanParser/SubstraitPlanParser.h" -#include "substrait/textplan/Any.h" #include "substrait/textplan/StructuredSymbolData.h" #include "substrait/textplan/parser/SubstraitParserErrorListener.h" #include "substrait/textplan/parser/SubstraitPlanPipelineVisitor.h" @@ -37,22 +37,28 @@ antlr4::ANTLRInputStream loadTextString(std::string_view text) { } ParseResult parseStream(antlr4::ANTLRInputStream stream) { + io::substrait::textplan::SubstraitParserErrorListener errorListener; + SubstraitPlanLexer lexer(&stream); + lexer.removeErrorListeners(); + lexer.addErrorListener(&errorListener); antlr4::CommonTokenStream tokens(&lexer); tokens.fill(); SubstraitPlanParser parser(&tokens); parser.removeErrorListeners(); - io::substrait::textplan::SubstraitParserErrorListener parserErrorListener; - parser.addErrorListener(&parserErrorListener); + parser.addErrorListener(&errorListener); auto* tree = parser.plan(); - auto visitor = std::make_shared(); + SymbolTable visitorSymbolTable; + auto visitorErrorListener = std::make_shared(); + auto visitor = std::make_shared( + visitorSymbolTable, visitorErrorListener); try { visitor->visitPlan(tree); } catch (...) { - parserErrorListener.syntaxError( + errorListener.syntaxError( &parser, nullptr, /*line=*/1, @@ -66,7 +72,7 @@ ParseResult parseStream(antlr4::ANTLRInputStream stream) { try { pipelineVisitor->visitPlan(tree); } catch (...) { - parserErrorListener.syntaxError( + errorListener.syntaxError( &parser, nullptr, /*line=*/1, @@ -79,8 +85,17 @@ ParseResult parseStream(antlr4::ANTLRInputStream stream) { *pipelineVisitor->getSymbolTable(), pipelineVisitor->getErrorListener()); try { relationVisitor->visitPlan(tree); + } catch (std::invalid_argument ex) { + // Catches the any_cast exception and logs a useful error message. + errorListener.syntaxError( + &parser, + nullptr, + /*line=*/1, + /*charPositionInLine=*/1, + ex.what(), + std::current_exception()); } catch (...) { - parserErrorListener.syntaxError( + errorListener.syntaxError( &parser, nullptr, /*line=*/1, @@ -92,7 +107,7 @@ ParseResult parseStream(antlr4::ANTLRInputStream stream) { auto finalSymbolTable = relationVisitor->getSymbolTable(); return { *finalSymbolTable, - parserErrorListener.getErrorMessages(), + errorListener.getErrorMessages(), relationVisitor->getErrorListener()->getErrorMessages()}; } diff --git a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp index 7cdbe3d2..a08c682e 100644 --- a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp @@ -2,10 +2,8 @@ #include "substrait/textplan/parser/SubstraitPlanPipelineVisitor.h" -#include #include -#include "SubstraitPlanLexer/SubstraitPlanLexer.h" #include "SubstraitPlanParser/SubstraitPlanParser.h" #include "substrait/textplan/Any.h" #include "substrait/textplan/Location.h" @@ -95,13 +93,13 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline( const SymbolInfo* leftSymbol = &SymbolInfo::kUnknown; if (ctx->pipeline() != nullptr) { leftSymbol = - &symbolTable_->lookupSymbolByLocation(PARSER_LOCATION(ctx->pipeline())); + symbolTable_->lookupSymbolByLocation(PARSER_LOCATION(ctx->pipeline())); } const SymbolInfo* rightSymbol = &SymbolInfo::kUnknown; if (dynamic_cast(ctx->parent)->getRuleIndex() == SubstraitPlanParser::RulePipeline) { rightSymbol = - &symbolTable_->lookupSymbolByLocation(PARSER_LOCATION(ctx->parent)); + symbolTable_->lookupSymbolByLocation(PARSER_LOCATION(ctx->parent)); } const SymbolInfo* rightmostSymbol = rightSymbol; if (*rightSymbol != SymbolInfo::kUnknown) { diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index 794ba660..fc8ee51d 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -9,6 +9,7 @@ #include #include "SubstraitPlanParser/SubstraitPlanParser.h" +#include "SubstraitPlanTypeVisitor.h" #include "absl/strings/numbers.h" #include "date/tz.h" #include "substrait/expression/DecimalLiteral.h" @@ -19,12 +20,15 @@ #include "substrait/textplan/Location.h" #include "substrait/textplan/StructuredSymbolData.h" #include "substrait/textplan/SymbolTable.h" -#include "substrait/type/Type.h" namespace io::substrait::textplan { namespace { +std::string kAggregationPhasePrefix = "aggregationphase"; +std::string kAggregationInvocationPrefix = "aggregationinvocation"; +std::string kSortDirectionPrefix = "sortdirection"; + enum RelationFilterBehavior { kDefault = 0, kBestEffort = 1, @@ -40,8 +44,8 @@ std::string toLower(const std::string& str) { } // Yields true if the string 'haystack' starts with the string 'needle'. -bool startsWith(const std::string& haystack, const std::string& needle) { - return strncmp(haystack.c_str(), needle.c_str(), needle.size()) == 0; +bool startsWith(const std::string& haystack, std::string_view needle) { + return strncmp(haystack.c_str(), needle.data(), needle.size()) == 0; } void setNullable(::substrait::proto::Type* type) { @@ -204,6 +208,27 @@ void setRelationType( } } +std::string normalizeProtoEnum(std::string_view text, std::string_view prefix) { + std::string result{text}; + // Remove non-alphabetic characters. + result.erase( + std::remove_if( + result.begin(), + result.end(), + [](auto const& c) -> bool { return !std::isalpha(c); }), + result.end()); + // Lowercase. + std::transform( + result.begin(), result.end(), result.begin(), [](unsigned char c) { + return std::tolower(c); + }); + // Remove the prefix if it exists. + if (startsWith(result, prefix)) { + result = result.substr(prefix.length()); + } + return result; +} + } // namespace std::any SubstraitPlanRelationVisitor::aggregateResult( @@ -219,25 +244,25 @@ std::any SubstraitPlanRelationVisitor::aggregateResult( std::any SubstraitPlanRelationVisitor::visitRelation( SubstraitPlanParser::RelationContext* ctx) { // Create the relation before visiting our children, so they can update it. - auto symbol = symbolTable_->lookupSymbolByLocation(Location(ctx)); - if (symbol == SymbolInfo::kUnknown) { + auto* symbol = symbolTable_->lookupSymbolByLocation(Location(ctx)); + if (symbol == nullptr) { // This error has been previously dealt with thus we can safely skip it. return defaultResult(); } - auto relationData = ANY_CAST(std::shared_ptr, symbol.blob); + auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); ::substrait::proto::Rel relation; - auto relationType = ANY_CAST(RelationType, symbol.subtype); + auto relationType = ANY_CAST(RelationType, symbol->subtype); setRelationType(relationType, &relation); relationData->relation = relation; - symbolTable_->updateLocation(symbol, PROTO_LOCATION(relationData->relation)); + symbolTable_->updateLocation(*symbol, PROTO_LOCATION(relationData->relation)); // Mark the current scope for any operations within this relation. auto previousScope = currentRelationScope_; auto resetCurrentScope = finally([&]() { currentRelationScope_ = previousScope; }); - currentRelationScope_ = &symbol; + currentRelationScope_ = symbol; visitChildren(ctx); @@ -276,12 +301,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationFilter( visitRelation_filter_behavior(ctx->relation_filter_behavior())); } - auto parentSymbol = symbolTable_->lookupSymbolByLocation( + auto* parentSymbol = symbolTable_->lookupSymbolByLocation( Location(dynamic_cast(ctx->parent))); auto parentRelationData = - ANY_CAST(std::shared_ptr, parentSymbol.blob); + ANY_CAST(std::shared_ptr, parentSymbol->blob); auto result = SubstraitPlanRelationVisitor::visitChildren(ctx); - auto parentRelationType = ANY_CAST(RelationType, parentSymbol.subtype); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); switch (parentRelationType) { case RelationType::kRead: switch (behavior) { @@ -346,6 +371,11 @@ std::any SubstraitPlanRelationVisitor::visitRelationFilter( "specified."); break; } + if (result.type() != typeid(::substrait::proto::Expression)) { + errorListener_->addError( + ctx->getStart(), "Could not parse as an expression."); + return defaultResult(); + } *parentRelationData->relation.mutable_filter()->mutable_condition() = ANY_CAST(::substrait::proto::Expression, result); } else { @@ -365,11 +395,11 @@ std::any SubstraitPlanRelationVisitor::visitRelationFilter( std::any SubstraitPlanRelationVisitor::visitRelationUsesSchema( SubstraitPlanParser::RelationUsesSchemaContext* ctx) { - auto parentSymbol = symbolTable_->lookupSymbolByLocation( + auto* parentSymbol = symbolTable_->lookupSymbolByLocation( Location(dynamic_cast(ctx->parent))); auto parentRelationData = - ANY_CAST(std::shared_ptr, parentSymbol.blob); - auto parentRelationType = ANY_CAST(RelationType, parentSymbol.subtype); + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); if (parentRelationType == RelationType::kRead) { auto schemaName = ctx->id()->getText(); @@ -385,9 +415,7 @@ std::any SubstraitPlanRelationVisitor::visitRelationUsesSchema( continue; } schema->add_names(sym.name); - auto typeText = ANY_CAST(std::string, sym.blob); - // TODO -- Use the location of the schema item for errors. - auto typeProto = textToTypeProto(ctx->getStart(), typeText); + auto typeProto = ANY_CAST(::substrait::proto::Type, sym.blob); if (typeProto.kind_case() != ::substrait::proto::Type::KIND_NOT_SET) { *schema->mutable_struct_()->add_types() = typeProto; } @@ -403,12 +431,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationUsesSchema( std::any SubstraitPlanRelationVisitor::visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) { - auto parentSymbol = symbolTable_->lookupSymbolByLocation( + auto* parentSymbol = symbolTable_->lookupSymbolByLocation( Location(dynamic_cast(ctx->parent))); auto parentRelationData = - ANY_CAST(std::shared_ptr, parentSymbol.blob); + ANY_CAST(std::shared_ptr, parentSymbol->blob); auto result = SubstraitPlanRelationVisitor::visitChildren(ctx); - auto parentRelationType = ANY_CAST(RelationType, parentSymbol.subtype); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); switch (parentRelationType) { case RelationType::kJoin: if (parentRelationData->relation.join().has_expression()) { @@ -434,35 +462,207 @@ std::any SubstraitPlanRelationVisitor::visitRelationExpression( return defaultResult(); } -std::any SubstraitPlanRelationVisitor::visitExpression( - SubstraitPlanParser::ExpressionContext* ctx) { - if (auto* funcUseCtx = - dynamic_cast( - ctx)) { - return visitExpressionFunctionUse(funcUseCtx); - } else if ( - auto* constantCtx = - dynamic_cast(ctx)) { - return visitExpressionConstant(constantCtx); - } else if ( - auto* columnCtx = - dynamic_cast(ctx)) { - return visitExpressionColumn(columnCtx); - } else if ( - auto* castCtx = - dynamic_cast(ctx)) { - return visitExpressionCast(castCtx); +std::any SubstraitPlanRelationVisitor::visitRelationGrouping( + SubstraitPlanParser::RelationGroupingContext* ctx) { + auto* parentSymbol = symbolTable_->lookupSymbolByLocation( + Location(dynamic_cast(ctx->parent))); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto result = SubstraitPlanRelationVisitor::visitChildren(ctx); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); + switch (parentRelationType) { + case RelationType::kAggregate: { + if (parentRelationData->relation.aggregate().groupings_size() == 0) { + parentRelationData->relation.mutable_aggregate()->add_groupings(); + } + // Always add new expressions to the first groupings group. + auto newExpr = parentRelationData->relation.mutable_aggregate() + ->mutable_groupings(0) + ->add_grouping_expressions(); + *newExpr = ANY_CAST(::substrait::proto::Expression, result); + if (newExpr->has_selection()) { + newExpr->mutable_selection()->mutable_root_reference(); + } + break; + } + default: + errorListener_->addError( + ctx->getStart(), + "Groupings are not permitted for this kind of relation."); + break; } return defaultResult(); } +std::any SubstraitPlanRelationVisitor::visitRelationMeasure( + SubstraitPlanParser::RelationMeasureContext* ctx) { + // Construct the measure. + ::substrait::proto::AggregateRel_Measure measure; + auto invocation = ::substrait::proto:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_UNSPECIFIED; + std::vector<::substrait::proto::SortField> sorts; + for (auto detail : ctx->measure_detail()) { + auto detailItem = ANY_CAST( + ::substrait::proto::AggregateRel_Measure, visitMeasure_detail(detail)); + if (detail->getStart()->getType() == SubstraitPlanParser::MEASURE) { + if (measure.has_measure()) { + errorListener_->addError( + detail->getStart(), + "A measure expression has already been provided for this measure."); + break; + } + *measure.mutable_measure() = detailItem.measure(); + } else if (detail->getStart()->getType() == SubstraitPlanParser::FILTER) { + if (measure.has_filter()) { + errorListener_->addError( + detail->getStart(), + "A filter has already been provided for this measure."); + break; + } + *measure.mutable_filter() = detailItem.filter(); + } else if ( + detail->getStart()->getType() == SubstraitPlanParser::INVOCATION) { + invocation = detailItem.measure().invocation(); + } else if (detail->getStart()->getType() == SubstraitPlanParser::SORT) { + auto newSorts = detailItem.measure().sorts(); + sorts.insert(sorts.end(), newSorts.begin(), newSorts.end()); + } + } + if (invocation != + ::substrait::proto:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_UNSPECIFIED) { + measure.mutable_measure()->set_invocation(invocation); + } + for (const auto& sort : sorts) { + *measure.mutable_measure()->add_sorts() = sort; + } + + // Add it to our relation. + auto* parentSymbol = symbolTable_->lookupSymbolByLocation( + Location(dynamic_cast(ctx->parent))); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); + switch (parentRelationType) { + case RelationType::kAggregate: + *parentRelationData->relation.mutable_aggregate()->add_measures() = + measure; + break; + default: + errorListener_->addError( + ctx->getStart(), + "Measures are not permitted for this kind of relation."); + break; + } + return defaultResult(); +} + +int32_t SubstraitPlanRelationVisitor::visitAggregationInvocation( + SubstraitPlanParser::IdContext* ctx) { + std::string text = + normalizeProtoEnum(ctx->getText(), kAggregationInvocationPrefix); + if (text == "unspecified") { + return ::substrait::proto::AggregateFunction:: + AGGREGATION_INVOCATION_UNSPECIFIED; + } else if (text == "all") { + return ::substrait::proto::AggregateFunction::AGGREGATION_INVOCATION_ALL; + } else if (text == "distinct") { + return ::substrait::proto::AggregateFunction:: + AGGREGATION_INVOCATION_DISTINCT; + } + this->errorListener_->addError( + ctx->getStart(), + "Unrecognized aggregation invocation: " + ctx->getText()); + return ::substrait::proto::AggregateFunction:: + AGGREGATION_INVOCATION_UNSPECIFIED; +} + +int32_t SubstraitPlanRelationVisitor::visitAggregationPhase( + SubstraitPlanParser::IdContext* ctx) { + std::string text = + normalizeProtoEnum(ctx->getText(), kAggregationPhasePrefix); + if (text == "unspecified") { + return ::substrait::proto::AGGREGATION_PHASE_UNSPECIFIED; + } else if (text == "initialtointermediate") { + return ::substrait::proto::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE; + } else if (text == "intermediatetointermediate") { + return ::substrait::proto::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE; + } else if (text == "initialtoresult") { + return ::substrait::proto::AGGREGATION_PHASE_INITIAL_TO_RESULT; + } else if (text == "intermediatetoresult") { + return ::substrait::proto::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT; + } + this->errorListener_->addError( + ctx->getStart(), "Unrecognized aggregation phase: " + ctx->getText()); + return ::substrait::proto::AGGREGATION_PHASE_UNSPECIFIED; +} + +std::any SubstraitPlanRelationVisitor::visitMeasure_detail( + SubstraitPlanParser::Measure_detailContext* ctx) { + ::substrait::proto::AggregateRel_Measure measure; + switch (ctx->getStart()->getType()) { + case SubstraitPlanParser::MEASURE: { + auto function = measure.mutable_measure(); + auto result = visitExpression(ctx->expression()); + auto expr = ANY_CAST(::substrait::proto::Expression, result); + if (expr.has_scalar_function()) { + const auto& scalarFunc = expr.scalar_function(); + function->set_function_reference(scalarFunc.function_reference()); + for (const auto& arg : scalarFunc.arguments()) { + *function->add_arguments() = arg; + } + for (const auto& option : scalarFunc.options()) { + *function->add_options() = option; + } + if (scalarFunc.has_output_type()) { + *function->mutable_output_type() = scalarFunc.output_type(); + } + if (ctx->literal_complex_type() != nullptr) { + // The version here overrides any that might be in the function. + *function->mutable_output_type() = ANY_CAST( + ::substrait::proto::Type, + visitLiteral_complex_type(ctx->literal_complex_type())); + } + if (ctx->id() != nullptr) { + measure.mutable_measure()->set_phase( + static_cast<::substrait::proto::AggregationPhase>( + visitAggregationPhase(ctx->id()))); + } + } else { + errorListener_->addError( + ctx->id()->getStart(), + "Expected an expression utilizing a function here."); + } + + return measure; + } + case SubstraitPlanParser::FILTER: + *measure.mutable_filter() = ANY_CAST( + ::substrait::proto::Expression, visitExpression(ctx->expression())); + return measure; + case SubstraitPlanParser::INVOCATION: + measure.mutable_measure()->set_invocation( + static_cast< + ::substrait::proto::AggregateFunction_AggregationInvocation>( + visitAggregationInvocation(ctx->id()))); + return measure; + case SubstraitPlanParser::SORT: + *measure.mutable_measure()->add_sorts() = ANY_CAST( + ::substrait::proto::SortField, visitSort_field(ctx->sort_field())); + return measure; + default: + // Alert that this kind of measure detail is not in the grammar. + return measure; + } +} + std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( SubstraitPlanParser::RelationSourceReferenceContext* ctx) { - auto parentSymbol = symbolTable_->lookupSymbolByLocation( + auto* parentSymbol = symbolTable_->lookupSymbolByLocation( Location(dynamic_cast(ctx->parent))); auto parentRelationData = - ANY_CAST(std::shared_ptr, parentSymbol.blob); - auto parentRelationType = ANY_CAST(RelationType, parentSymbol.subtype); + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); if (parentRelationType == RelationType::kRead) { auto sourceName = ctx->source_reference()->id()->getText(); @@ -488,13 +688,56 @@ std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( return defaultResult(); } +std::any SubstraitPlanRelationVisitor::visitRelationSort( + SubstraitPlanParser::RelationSortContext* ctx) { + auto* parentSymbol = symbolTable_->lookupSymbolByLocation( + Location(dynamic_cast(ctx->parent))); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); + switch (parentRelationType) { + case RelationType::kSort: + *parentRelationData->relation.mutable_sort()->add_sorts() = ANY_CAST( + ::substrait::proto::SortField, visitSort_field(ctx->sort_field())); + break; + default: + errorListener_->addError( + ctx->getStart(), + "Sorts are not permitted for this kind of relation."); + break; + } + return defaultResult(); +} + +std::any SubstraitPlanRelationVisitor::visitExpression( + SubstraitPlanParser::ExpressionContext* ctx) { + if (auto* funcUseCtx = + dynamic_cast( + ctx)) { + return visitExpressionFunctionUse(funcUseCtx); + } else if ( + auto* constantCtx = + dynamic_cast(ctx)) { + return visitExpressionConstant(constantCtx); + } else if ( + auto* columnCtx = + dynamic_cast(ctx)) { + return visitExpressionColumn(columnCtx); + } else if ( + auto* castCtx = + dynamic_cast(ctx)) { + return visitExpressionCast(castCtx); + } + return defaultResult(); +} + std::any SubstraitPlanRelationVisitor::visitExpressionFunctionUse( SubstraitPlanParser::ExpressionFunctionUseContext* ctx) { ::substrait::proto::Expression expr; std::string funcName = ctx->id()->getText(); uint32_t funcReference = 0; auto symbol = symbolTable_->lookupSymbolByName(funcName); - if (symbol->type != SymbolType::kFunction) { + if (symbol == nullptr || symbol->type != SymbolType::kFunction) { errorListener_->addError( ctx->id()->getStart(), ctx->id()->getText() + " is not a function reference."); @@ -505,8 +748,13 @@ std::any SubstraitPlanRelationVisitor::visitExpressionFunctionUse( expr.mutable_scalar_function()->set_function_reference(funcReference); for (const auto& exp : ctx->expression()) { - auto newExpr = - ANY_CAST(::substrait::proto::Expression, visitExpression(exp)); + auto result = visitExpression(exp); + if (result.type() != typeid(::substrait::proto::Expression)) { + errorListener_->addError( + ctx->id()->getStart(), "Could not parse as an expression."); + return expr; + } + auto newExpr = ANY_CAST(::substrait::proto::Expression, result); *expr.mutable_scalar_function()->add_arguments()->mutable_value() = newExpr; } return expr; @@ -593,22 +841,6 @@ std::any SubstraitPlanRelationVisitor::visitConstant( return literal; } -std::any SubstraitPlanRelationVisitor::visitLiteral_specifier( - SubstraitPlanParser::Literal_specifierContext* ctx) { - // Provides detail for the width of the type. - return visitChildren(ctx); -} - -std::any SubstraitPlanRelationVisitor::visitLiteral_basic_type( - SubstraitPlanParser::Literal_basic_typeContext* ctx) { - return textToTypeProto(ctx->getStart(), ctx->getText()); -} - -std::any SubstraitPlanRelationVisitor::visitLiteral_complex_type( - SubstraitPlanParser::Literal_complex_typeContext* ctx) { - return textToTypeProto(ctx->getStart(), ctx->getText()); -} - std::any SubstraitPlanRelationVisitor::visitMap_literal( SubstraitPlanParser::Map_literalContext* ctx) { ::substrait::proto::Expression_Literal literal; @@ -971,7 +1203,8 @@ SubstraitPlanRelationVisitor::visitNumber( case ::substrait::proto::Type::kI8: { int32_t val = std::stoi(node->getText()); literal.set_i8(val); - if (literalType.i8().nullability()) { + if (literalType.i8().nullability() == + ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE) { literal.set_nullable(true); } break; @@ -979,7 +1212,8 @@ SubstraitPlanRelationVisitor::visitNumber( case ::substrait::proto::Type::kI16: { int32_t val = std::stoi(node->getText()); literal.set_i16(val); - if (literalType.i16().nullability()) { + if (literalType.i16().nullability() == + ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE) { literal.set_nullable(true); } break; @@ -987,7 +1221,8 @@ SubstraitPlanRelationVisitor::visitNumber( case ::substrait::proto::Type::kI32: { int32_t val = std::stoi(node->getText()); literal.set_i32(val); - if (literalType.i32().nullability()) { + if (literalType.i32().nullability() == + ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE) { literal.set_nullable(true); } break; @@ -995,7 +1230,8 @@ SubstraitPlanRelationVisitor::visitNumber( case ::substrait::proto::Type::kI64: { int64_t val = std::stol(node->getText()); literal.set_i64(val); - if (literalType.i64().nullability()) { + if (literalType.i64().nullability() == + ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE) { literal.set_nullable(true); } break; @@ -1003,7 +1239,8 @@ SubstraitPlanRelationVisitor::visitNumber( case ::substrait::proto::Type::kFp32: { float val = std::stof(node->getText()); literal.set_fp32(val); - if (literalType.fp32().nullability()) { + if (literalType.fp32().nullability() == + ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE) { literal.set_nullable(true); } break; @@ -1011,7 +1248,8 @@ SubstraitPlanRelationVisitor::visitNumber( case ::substrait::proto::Type::kFp64: { double val = std::stod(node->getText()); literal.set_fp64(val); - if (literalType.fp64().nullability()) { + if (literalType.fp64().nullability() == + ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE) { literal.set_nullable(true); } break; @@ -1027,7 +1265,8 @@ SubstraitPlanRelationVisitor::visitNumber( break; } *literal.mutable_decimal() = decimal.toProto(); - if (literalType.decimal().nullability()) { + if (literalType.decimal().nullability() == + ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE) { literal.set_nullable(true); } break; @@ -1243,160 +1482,37 @@ ::substrait::proto::Expression_Literal SubstraitPlanRelationVisitor::visitTime( return literal; } -::substrait::proto::Type SubstraitPlanRelationVisitor::textToTypeProto( - const antlr4::Token* token, - const std::string& typeText) { - std::shared_ptr decodedType; - try { - decodedType = Type::decode(typeText); - } catch (...) { - errorListener_->addError(token, "Failed to decode type."); - return ::substrait::proto::Type{}; +std::any SubstraitPlanRelationVisitor::visitSort_field( + SubstraitPlanParser::Sort_fieldContext* ctx) { + ::substrait::proto::SortField sort; + *sort.mutable_expr() = ANY_CAST( + ::substrait::proto::Expression, visitExpression(ctx->expression())); + if (ctx->id() != nullptr) { + sort.set_direction(static_cast<::substrait::proto::SortField_SortDirection>( + visitSortDirection(ctx->id()))); } - return typeToProto(token, *decodedType); + return sort; } -::substrait::proto::Type SubstraitPlanRelationVisitor::typeToProto( - const antlr4::Token* token, - const ParameterizedType& decodedType) { - ::substrait::proto::Type type; - auto nullValue = ::substrait::proto::Type_Nullability_NULLABILITY_UNSPECIFIED; - if (decodedType.nullable()) { - nullValue = ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE; - } - switch (decodedType.kind()) { - case TypeKind::kBool: - type.mutable_bool_()->set_nullability(nullValue); - break; - case TypeKind::kI8: - type.mutable_i8()->set_nullability(nullValue); - break; - case TypeKind::kI16: - type.mutable_i16()->set_nullability(nullValue); - break; - case TypeKind::kI32: - type.mutable_i32()->set_nullability(nullValue); - break; - case TypeKind::kI64: - type.mutable_i64()->set_nullability(nullValue); - break; - case TypeKind::kFp32: - type.mutable_fp32()->set_nullability(nullValue); - break; - case TypeKind::kFp64: - type.mutable_fp64()->set_nullability(nullValue); - break; - case TypeKind::kString: - type.mutable_string()->set_nullability(nullValue); - break; - case TypeKind::kBinary: - type.mutable_binary()->set_nullability(nullValue); - break; - case TypeKind::kTimestamp: - type.mutable_timestamp()->set_nullability(nullValue); - break; - case TypeKind::kDate: - type.mutable_date()->set_nullability(nullValue); - break; - case TypeKind::kTime: - type.mutable_time()->set_nullability(nullValue); - break; - case TypeKind::kIntervalYear: - type.mutable_interval_year()->set_nullability(nullValue); - break; - case TypeKind::kIntervalDay: - type.mutable_interval_day()->set_nullability(nullValue); - break; - case TypeKind::kTimestampTz: - type.mutable_timestamp_tz()->set_nullability(nullValue); - break; - case TypeKind::kUuid: - type.mutable_uuid()->set_nullability(nullValue); - break; - case TypeKind::kFixedChar: { - auto fixedChar = - reinterpret_cast(&decodedType); - if (fixedChar == nullptr) { - break; - } - try { - int32_t length = std::stoi(fixedChar->length()->value()); - type.mutable_fixed_char()->set_length(length); - } catch (...) { - errorListener_->addError(token, "Could not parse fixedchar length."); - } - type.mutable_fixed_char()->set_nullability(nullValue); - break; - } - case TypeKind::kVarchar: { - auto varChar = - reinterpret_cast(&decodedType); - if (varChar == nullptr) { - break; - } - try { - int32_t length = std::stoi(varChar->length()->value()); - type.mutable_varchar()->set_length(length); - } catch (...) { - errorListener_->addError(token, "Could not parse varchar length."); - } - type.mutable_varchar()->set_nullability(nullValue); - break; - } - case TypeKind::kFixedBinary: - type.mutable_fixed_binary()->set_nullability(nullValue); - break; - case TypeKind::kDecimal: { - auto dec = reinterpret_cast(&decodedType); - if (dec == nullptr) { - break; - } - try { - int32_t precision = std::stoi(dec->precision()->value()); - int32_t scale = std::stoi(dec->scale()->value()); - type.mutable_decimal()->set_precision(precision); - type.mutable_decimal()->set_scale(scale); - } catch (...) { - errorListener_->addError( - token, "Could not parse decimal precision and scale."); - } - type.mutable_decimal()->set_nullability(nullValue); - break; - } - case TypeKind::kStruct: { - auto structure = - reinterpret_cast(&decodedType); - for (const auto& t : structure->children()) { - *type.mutable_struct_()->add_types() = typeToProto(token, *t); - } - type.mutable_struct_()->set_nullability(nullValue); - break; - } - case TypeKind::kList: { - auto list = reinterpret_cast(&decodedType); - *type.mutable_list()->mutable_type() = - typeToProto(token, *list->elementType()); - type.mutable_list()->set_nullability(nullValue); - break; - } - case TypeKind::kMap: { - auto map = reinterpret_cast(&decodedType); - if (map->keyType() == nullptr || map->valueType() == nullptr) { - errorListener_->addError( - token, "Maps require both a key and a value type."); - break; - } - *type.mutable_map()->mutable_key() = typeToProto(token, *map->keyType()); - *type.mutable_map()->mutable_value() = - typeToProto(token, *map->valueType()); - type.mutable_map()->set_nullability(nullValue); - break; - } - case TypeKind::kKindNotSet: - errorListener_->addError(token, "Unable to recognize requested type."); - break; +int32_t SubstraitPlanRelationVisitor::visitSortDirection( + SubstraitPlanParser::IdContext* ctx) { + std::string text = normalizeProtoEnum(ctx->getText(), kSortDirectionPrefix); + if (text == "unspecified") { + return ::substrait::proto::SortField::SORT_DIRECTION_UNSPECIFIED; + } else if (text == "ascnullsfirst") { + return ::substrait::proto::SortField::SORT_DIRECTION_ASC_NULLS_FIRST; + } else if (text == "ascnullslast") { + return ::substrait::proto::SortField::SORT_DIRECTION_ASC_NULLS_LAST; + } else if (text == "descnullsfirst") { + return ::substrait::proto::SortField::SORT_DIRECTION_DESC_NULLS_FIRST; + } else if (text == "descnullslast") { + return ::substrait::proto::SortField::SORT_DIRECTION_DESC_NULLS_LAST; + } else if (text == "clustered") { + return ::substrait::proto::SortField::SORT_DIRECTION_CLUSTERED; } - return type; + this->errorListener_->addError( + ctx->getStart(), "Unrecognized sort direction: " + ctx->getText()); + return ::substrait::proto::SortField::SORT_DIRECTION_UNSPECIFIED; } } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h index 17c78d66..a41e8e2b 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h @@ -3,28 +3,26 @@ #pragma once #include "SubstraitPlanParser/SubstraitPlanParser.h" -#include "SubstraitPlanParser/SubstraitPlanParserBaseVisitor.h" #include "substrait/textplan/SymbolTable.h" #include "substrait/textplan/parser/SubstraitParserErrorListener.h" -#include "substrait/type/Type.h" +#include "substrait/textplan/parser/SubstraitPlanTypeVisitor.h" namespace substrait::proto { class Expression_Literal; class Expression_Literal_Map_KeyValue; +class NamedStruct; class Type; class Type_Struct; } // namespace substrait::proto namespace io::substrait::textplan { -class SubstraitPlanRelationVisitor : public SubstraitPlanParserBaseVisitor { +class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { public: SubstraitPlanRelationVisitor( const SymbolTable& symbolTable, - std::shared_ptr errorListener) { - symbolTable_ = std::make_shared(symbolTable); - errorListener_ = std::move(errorListener); - } + std::shared_ptr errorListener) + : SubstraitPlanTypeVisitor(symbolTable, std::move(errorListener)) {} [[nodiscard]] std::shared_ptr getSymbolTable() const { return symbolTable_; @@ -51,9 +49,25 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanParserBaseVisitor { std::any visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) override; + std::any visitRelationGrouping( + SubstraitPlanParser::RelationGroupingContext* ctx) override; + + std::any visitRelationMeasure( + SubstraitPlanParser::RelationMeasureContext* ctx) override; + + int32_t visitAggregationInvocation(SubstraitPlanParser::IdContext* ctx); + + int32_t visitAggregationPhase(SubstraitPlanParser::IdContext* ctx); + + std::any visitMeasure_detail( + SubstraitPlanParser::Measure_detailContext* ctx) override; + std::any visitRelationSourceReference( SubstraitPlanParser::RelationSourceReferenceContext* ctx) override; + std::any visitRelationSort( + SubstraitPlanParser::RelationSortContext* ctx) override; + // visitExpression is a new method delegating to the methods below. std::any visitExpression(SubstraitPlanParser::ExpressionContext* ctx); @@ -71,15 +85,6 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanParserBaseVisitor { std::any visitConstant(SubstraitPlanParser::ConstantContext* ctx) override; - std::any visitLiteral_basic_type( - SubstraitPlanParser::Literal_basic_typeContext* ctx) override; - - std::any visitLiteral_complex_type( - SubstraitPlanParser::Literal_complex_typeContext* ctx) override; - - std::any visitLiteral_specifier( - SubstraitPlanParser::Literal_specifierContext* ctx) override; - std::any visitMap_literal( SubstraitPlanParser::Map_literalContext* ctx) override; @@ -92,6 +97,9 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanParserBaseVisitor { std::any visitColumn_name( SubstraitPlanParser::Column_nameContext* ctx) override; + std::any visitSort_field( + SubstraitPlanParser::Sort_fieldContext* ctx) override; + ::substrait::proto::Expression_Literal visitConstantWithType( SubstraitPlanParser::ConstantContext* ctx, const ::substrait::proto::Type& literalType); @@ -147,22 +155,13 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanParserBaseVisitor { ::substrait::proto::Expression_Literal visitTime( SubstraitPlanParser::ConstantContext* ctx); + int32_t visitSortDirection(SubstraitPlanParser::IdContext* ctx); + private: std::string escapeText( const antlr4::tree::TerminalNode* node, const std::string& str); - ::substrait::proto::Type textToTypeProto( - const antlr4::Token* token, - const std::string& typeText); - - ::substrait::proto::Type typeToProto( - const antlr4::Token* token, - const ParameterizedType& decodedType); - - std::shared_ptr symbolTable_; - std::shared_ptr errorListener_; - const SymbolInfo* currentRelationScope_{nullptr}; // Not owned. }; diff --git a/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp new file mode 100644 index 00000000..403caa22 --- /dev/null +++ b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.cpp @@ -0,0 +1,200 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "SubstraitPlanTypeVisitor.h" + +#include +#include +#include "SubstraitPlanParser/SubstraitPlanParser.h" +#include "substrait/proto/type.pb.h" +#include "substrait/textplan/SymbolTable.h" +#include "substrait/type/Type.h" + +namespace io::substrait::textplan { + +std::any SubstraitPlanTypeVisitor::visitLiteral_basic_type( + SubstraitPlanParser::Literal_basic_typeContext* ctx) { + return textToTypeProto(ctx, ctx->getText()); +} + +std::any SubstraitPlanTypeVisitor::visitLiteral_complex_type( + SubstraitPlanParser::Literal_complex_typeContext* ctx) { + return textToTypeProto(ctx, ctx->getText()); +} + +::substrait::proto::Type SubstraitPlanTypeVisitor::textToTypeProto( + const antlr4::ParserRuleContext* ctx, + const std::string& typeText) { + std::shared_ptr decodedType; + try { + decodedType = Type::decode(typeText); + } catch (...) { + errorListener_->addError(ctx->getStart(), "Failed to decode type."); + return ::substrait::proto::Type{}; + } + return typeToProto(ctx, *decodedType); +} + +::substrait::proto::Type SubstraitPlanTypeVisitor::typeToProto( + const antlr4::ParserRuleContext* ctx, + const ParameterizedType& decodedType) { + ::substrait::proto::Type type; + auto nullValue = ::substrait::proto::Type_Nullability_NULLABILITY_REQUIRED; + if (decodedType.nullable()) { + nullValue = ::substrait::proto::Type_Nullability_NULLABILITY_NULLABLE; + } + switch (decodedType.kind()) { + case TypeKind::kBool: + type.mutable_bool_()->set_nullability(nullValue); + break; + case TypeKind::kI8: + type.mutable_i8()->set_nullability(nullValue); + break; + case TypeKind::kI16: + type.mutable_i16()->set_nullability(nullValue); + break; + case TypeKind::kI32: + type.mutable_i32()->set_nullability(nullValue); + break; + case TypeKind::kI64: + type.mutable_i64()->set_nullability(nullValue); + break; + case TypeKind::kFp32: + type.mutable_fp32()->set_nullability(nullValue); + break; + case TypeKind::kFp64: + type.mutable_fp64()->set_nullability(nullValue); + break; + case TypeKind::kString: + type.mutable_string()->set_nullability(nullValue); + break; + case TypeKind::kBinary: + type.mutable_binary()->set_nullability(nullValue); + break; + case TypeKind::kTimestamp: + type.mutable_timestamp()->set_nullability(nullValue); + break; + case TypeKind::kDate: + type.mutable_date()->set_nullability(nullValue); + break; + case TypeKind::kTime: + type.mutable_time()->set_nullability(nullValue); + break; + case TypeKind::kIntervalYear: + type.mutable_interval_year()->set_nullability(nullValue); + break; + case TypeKind::kIntervalDay: + type.mutable_interval_day()->set_nullability(nullValue); + break; + case TypeKind::kTimestampTz: + type.mutable_timestamp_tz()->set_nullability(nullValue); + break; + case TypeKind::kUuid: + type.mutable_uuid()->set_nullability(nullValue); + break; + case TypeKind::kFixedChar: { + auto fixedChar = + reinterpret_cast(&decodedType); + if (fixedChar == nullptr) { + break; + } + try { + int32_t length = std::stoi(fixedChar->length()->value()); + type.mutable_fixed_char()->set_length(length); + } catch (...) { + errorListener_->addError( + ctx->getStart(), "Could not parse fixedchar length."); + } + type.mutable_fixed_char()->set_nullability(nullValue); + break; + } + case TypeKind::kVarchar: { + auto varChar = + reinterpret_cast(&decodedType); + if (varChar == nullptr) { + break; + } + try { + int32_t length = std::stoi(varChar->length()->value()); + type.mutable_varchar()->set_length(length); + } catch (...) { + errorListener_->addError( + ctx->getStart(), "Could not parse varchar length."); + } + type.mutable_varchar()->set_nullability(nullValue); + break; + } + case TypeKind::kFixedBinary: + type.mutable_fixed_binary()->set_nullability(nullValue); + break; + case TypeKind::kDecimal: { + auto dec = reinterpret_cast(&decodedType); + if (dec == nullptr) { + break; + } + try { + int32_t precision = std::stoi(dec->precision()->value()); + int32_t scale = std::stoi(dec->scale()->value()); + type.mutable_decimal()->set_precision(precision); + type.mutable_decimal()->set_scale(scale); + } catch (...) { + errorListener_->addError( + ctx->getStart(), "Could not parse decimal precision and scale."); + } + type.mutable_decimal()->set_nullability(nullValue); + break; + } + case TypeKind::kStruct: { + auto structure = + reinterpret_cast(&decodedType); + for (const auto& t : structure->children()) { + *type.mutable_struct_()->add_types() = typeToProto(ctx, *t); + } + type.mutable_struct_()->set_nullability(nullValue); + break; + } + case TypeKind::kList: { + auto list = reinterpret_cast(&decodedType); + *type.mutable_list()->mutable_type() = + typeToProto(ctx, *list->elementType()); + type.mutable_list()->set_nullability(nullValue); + break; + } + case TypeKind::kMap: { + auto map = reinterpret_cast(&decodedType); + if (map->keyType() == nullptr || map->valueType() == nullptr) { + errorListener_->addError( + ctx->getStart(), "Maps require both a key and a value type."); + break; + } + *type.mutable_map()->mutable_key() = typeToProto(ctx, *map->keyType()); + *type.mutable_map()->mutable_value() = + typeToProto(ctx, *map->valueType()); + type.mutable_map()->set_nullability(nullValue); + break; + } + case TypeKind::kKindNotSet: + if (!insideStructLiteralWithExternalType(ctx)) { + errorListener_->addError( + ctx->getStart(), "Unable to recognize requested type."); + } + break; + } + return type; +} + +bool SubstraitPlanTypeVisitor::insideStructLiteralWithExternalType( + const antlr4::RuleContext* ctx) { + if (ctx == nullptr) { + return false; + } + if (ctx->getRuleIndex() == SubstraitPlanParser::RuleConstant && + const_cast( + dynamic_cast(ctx)) + ->struct_literal() != nullptr) { + return true; + } + return insideStructLiteralWithExternalType( + dynamic_cast(ctx->parent)); +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.h b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.h new file mode 100644 index 00000000..6370b2be --- /dev/null +++ b/src/substrait/textplan/parser/SubstraitPlanTypeVisitor.h @@ -0,0 +1,50 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include "SubstraitPlanParser/SubstraitPlanParser.h" +#include "SubstraitPlanParser/SubstraitPlanParserBaseVisitor.h" +#include "substrait/textplan/SymbolTable.h" +#include "substrait/textplan/parser/SubstraitParserErrorListener.h" +#include "substrait/type/Type.h" + +namespace substrait::proto { +class Type; +} + +namespace io::substrait::textplan { + +class SubstraitPlanTypeVisitor : public SubstraitPlanParserBaseVisitor { + public: + SubstraitPlanTypeVisitor( + const SymbolTable& symbolTable, + std::shared_ptr errorListener) { + symbolTable_ = std::make_shared(symbolTable); + errorListener_ = std::move(errorListener); + } + + std::any visitLiteral_basic_type( + SubstraitPlanParser::Literal_basic_typeContext* ctx) override; + std::any visitLiteral_complex_type( + SubstraitPlanParser::Literal_complex_typeContext* ctx) override; + + protected: + ::substrait::proto::Type textToTypeProto( + const antlr4::ParserRuleContext* ctx, + const std::string& typeText); + + ::substrait::proto::Type typeToProto( + const antlr4::ParserRuleContext* ctx, + const ParameterizedType& decodedType); + + // Identifies whether the given context has a parent node of a constant + // including a struct. This allows {3years, 1month, + // 2days}_interval_year_month_day to have the optional label tags which are + // not real types. + bool insideStructLiteralWithExternalType(const antlr4::RuleContext* ctx); + + std::shared_ptr symbolTable_; + std::shared_ptr errorListener_; +}; + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp index 30dfb3a6..33657ad3 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp @@ -9,6 +9,7 @@ #include "substrait/textplan/Finally.h" #include "substrait/textplan/Location.h" #include "substrait/textplan/StructuredSymbolData.h" +#include "substrait/type/Type.h" namespace io::substrait::textplan { @@ -51,10 +52,14 @@ std::any SubstraitPlanVisitor::visitPipeline( std::any SubstraitPlanVisitor::visitExtensionspace( SubstraitPlanParser::ExtensionspaceContext* ctx) { + if (ctx->URI() == nullptr) { + // Nothing to keep track of at this level. + return visitChildren(ctx); + } + const std::string& uri = ctx->URI()->getText(); // TODO -- Transition to querying the symbol table for the space number. #42 - static uint32_t numSpacesSeen = 0; - uint32_t thisSpace = numSpacesSeen++; + uint32_t thisSpace = ++numSpacesSeen_; symbolTable_->defineSymbol( uri, Location(ctx), @@ -65,9 +70,9 @@ std::any SubstraitPlanVisitor::visitExtensionspace( // Update the contained functions to belong in this space. for (auto func : ctx->function()) { - auto funcSymbol = symbolTable_->lookupSymbolByLocation(Location(func)); + auto* funcSymbol = symbolTable_->lookupSymbolByLocation(Location(func)); auto functionData = - ANY_CAST(std::shared_ptr, funcSymbol.blob); + ANY_CAST(std::shared_ptr, funcSymbol->blob); functionData->extensionUriReference = thisSpace; } @@ -78,13 +83,36 @@ std::any SubstraitPlanVisitor::visitFunction( SubstraitPlanParser::FunctionContext* ctx) { // TODO -- Transition to using the symbol table for the function number. #42 // Let our enclosing extension space provide us with the detail. + std::string referenceName; + if (ctx->id() != nullptr) { + referenceName = ctx->id()->getText(); + } else if (ctx->name() != nullptr) { + referenceName = ctx->name()->getText(); + auto colonPos = referenceName.find_first_of(':'); + if (colonPos != std::string::npos) { + referenceName = referenceName.substr(0, colonPos); + } + } else { + referenceName = ""; + } + + // We do not yet examine the type of functions but we look for presence. + if (ctx->name() != nullptr) { + auto colonPos = ctx->name()->getText().find_first_of(':'); + if (colonPos == std::string::npos || + ctx->name()->getText().substr(colonPos + 1).empty()) { + errorListener_->addError( + ctx->getStart(), "Functions should have an associated type."); + } + } + symbolTable_->defineSymbol( - ctx->id()->getText(), + referenceName, Location(ctx), SymbolType::kFunction, defaultResult(), std::make_shared( - ctx->name()->getText(), std::nullopt, ++numFunctionsSeen_)); + ctx->name()->getText(), std::nullopt, numFunctionsSeen_++)); return visitChildren(ctx); } @@ -107,9 +135,10 @@ std::any SubstraitPlanVisitor::visitSchema_definition( defaultResult(), defaultResult()); + // Mark all of the schema items so we can find the ones related to this + // schema. for (const auto& item : ctx->schema_item()) { auto symbol = ANY_CAST(SymbolInfo*, visitSchema_item(item)); - // TODO -- Implement schemas instead of skipping them. if (symbol == nullptr) { continue; } @@ -119,16 +148,6 @@ std::any SubstraitPlanVisitor::visitSchema_definition( return nullptr; } -std::any SubstraitPlanVisitor::visitColumn_attribute( - SubstraitPlanParser::Column_attributeContext* ctx) { - return visitChildren(ctx); -} - -std::any SubstraitPlanVisitor::visitColumn_type( - SubstraitPlanParser::Column_typeContext* ctx) { - return visitChildren(ctx); -} - std::any SubstraitPlanVisitor::visitSchema_item( SubstraitPlanParser::Schema_itemContext* ctx) { return symbolTable_->defineSymbol( @@ -136,7 +155,7 @@ std::any SubstraitPlanVisitor::visitSchema_item( Location(ctx), SymbolType::kSchemaColumn, defaultResult(), - ctx->column_type()->getText()); + visitLiteral_complex_type(ctx->literal_complex_type())); } std::any SubstraitPlanVisitor::visitRelation( @@ -223,16 +242,6 @@ std::any SubstraitPlanVisitor::visitLiteral_specifier( return visitChildren(ctx); } -std::any SubstraitPlanVisitor::visitLiteral_basic_type( - SubstraitPlanParser::Literal_basic_typeContext* ctx) { - return visitChildren(ctx); -} - -std::any SubstraitPlanVisitor::visitLiteral_complex_type( - SubstraitPlanParser::Literal_complex_typeContext* ctx) { - return visitChildren(ctx); -} - std::any SubstraitPlanVisitor::visitMap_literal_value( SubstraitPlanParser::Map_literal_valueContext* ctx) { return visitChildren(ctx); @@ -250,7 +259,8 @@ std::any SubstraitPlanVisitor::visitStruct_literal( std::any SubstraitPlanVisitor::visitConstant( SubstraitPlanParser::ConstantContext* ctx) { - return visitChildren(ctx); + // No need to examine these just yet, we will do this in the next pass. + return defaultResult(); } std::any SubstraitPlanVisitor::visitColumn_name( @@ -311,6 +321,11 @@ std::any SubstraitPlanVisitor::visitRelation_filter_behavior( return visitChildren(ctx); } +std::any SubstraitPlanVisitor::visitMeasure_detail( + SubstraitPlanParser::Measure_detailContext* ctx) { + return visitChildren(ctx); +} + std::any SubstraitPlanVisitor::visitRelationFilter( SubstraitPlanParser::RelationFilterContext* ctx) { return visitChildren(ctx); @@ -318,7 +333,8 @@ std::any SubstraitPlanVisitor::visitRelationFilter( std::any SubstraitPlanVisitor::visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) { - return visitChildren(ctx); + visitChildren(ctx); + return nullptr; } std::any SubstraitPlanVisitor::visitRelationAdvancedExtension( @@ -331,6 +347,31 @@ std::any SubstraitPlanVisitor::visitRelationSourceReference( return visitChildren(ctx); } +std::any SubstraitPlanVisitor::visitRelationGrouping( + SubstraitPlanParser::RelationGroupingContext* ctx) { + return visitChildren(ctx); +} + +std::any SubstraitPlanVisitor::visitRelationMeasure( + SubstraitPlanParser::RelationMeasureContext* ctx) { + return visitChildren(ctx); +} + +std::any SubstraitPlanVisitor::visitRelationSort( + SubstraitPlanParser::RelationSortContext* ctx) { + return visitChildren(ctx); +} + +std::any SubstraitPlanVisitor::visitRelationCount( + SubstraitPlanParser::RelationCountContext* ctx) { + return visitChildren(ctx); +} + +std::any SubstraitPlanVisitor::visitRelationJoinType( + SubstraitPlanParser::RelationJoinTypeContext* ctx) { + return visitChildren(ctx); +} + std::any SubstraitPlanVisitor::visitFile_location( SubstraitPlanParser::File_locationContext* ctx) { return visitChildren(ctx); @@ -406,10 +447,20 @@ std::any SubstraitPlanVisitor::visitRelation_ref( return rel; } +std::any SubstraitPlanVisitor::visitSort_field( + SubstraitPlanParser::Sort_fieldContext* ctx) { + return defaultResult(); +} + std::any SubstraitPlanVisitor::visitId(SubstraitPlanParser::IdContext* ctx) { return ctx->getText(); } +std::any SubstraitPlanVisitor::visitSimple_id( + SubstraitPlanParser::Simple_idContext* ctx) { + return defaultResult(); +} + // NOLINTEND(readability-convert-member-functions-to-static) // NOLINTEND(readability-identifier-naming) diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.h b/src/substrait/textplan/parser/SubstraitPlanVisitor.h index c973024d..84f36b3e 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.h @@ -3,18 +3,18 @@ #pragma once #include "SubstraitPlanParser/SubstraitPlanParser.h" -#include "SubstraitPlanParser/SubstraitPlanParserVisitor.h" #include "substrait/textplan/SymbolTable.h" #include "substrait/textplan/parser/SubstraitParserErrorListener.h" +#include "substrait/textplan/parser/SubstraitPlanTypeVisitor.h" namespace io::substrait::textplan { -class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { +class SubstraitPlanVisitor : public SubstraitPlanTypeVisitor { public: - SubstraitPlanVisitor() { - symbolTable_ = std::make_shared(); - errorListener_ = std::make_shared(); - } + SubstraitPlanVisitor( + const SymbolTable& symbolTable, + std::shared_ptr errorListener) + : SubstraitPlanTypeVisitor(symbolTable, std::move(errorListener)) {} [[nodiscard]] std::shared_ptr getSymbolTable() const { return symbolTable_; @@ -37,10 +37,6 @@ class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { std::any visitSignature(SubstraitPlanParser::SignatureContext* ctx) override; std::any visitSchema_definition( SubstraitPlanParser::Schema_definitionContext* ctx) override; - std::any visitColumn_attribute( - SubstraitPlanParser::Column_attributeContext* ctx) override; - std::any visitColumn_type( - SubstraitPlanParser::Column_typeContext* ctx) override; std::any visitSchema_item( SubstraitPlanParser::Schema_itemContext* ctx) override; std::any visitRelation(SubstraitPlanParser::RelationContext* ctx) override; @@ -50,10 +46,6 @@ class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { SubstraitPlanParser::Source_definitionContext* ctx) override; std::any visitLiteral_specifier( SubstraitPlanParser::Literal_specifierContext* ctx) override; - std::any visitLiteral_basic_type( - SubstraitPlanParser::Literal_basic_typeContext* ctx) override; - std::any visitLiteral_complex_type( - SubstraitPlanParser::Literal_complex_typeContext* ctx) override; std::any visitMap_literal_value( SubstraitPlanParser::Map_literal_valueContext* ctx) override; std::any visitMap_literal( @@ -79,6 +71,8 @@ class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { SubstraitPlanParser::RelationUsesSchemaContext* ctx) override; std::any visitRelation_filter_behavior( SubstraitPlanParser::Relation_filter_behaviorContext* ctx) override; + std::any visitMeasure_detail( + SubstraitPlanParser::Measure_detailContext* ctx) override; std::any visitRelationFilter( SubstraitPlanParser::RelationFilterContext* ctx) override; std::any visitRelationExpression( @@ -87,6 +81,16 @@ class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { SubstraitPlanParser::RelationAdvancedExtensionContext* ctx) override; std::any visitRelationSourceReference( SubstraitPlanParser::RelationSourceReferenceContext* ctx) override; + std::any visitRelationGrouping( + SubstraitPlanParser::RelationGroupingContext* ctx) override; + std::any visitRelationMeasure( + SubstraitPlanParser::RelationMeasureContext* ctx) override; + std::any visitRelationSort( + SubstraitPlanParser::RelationSortContext* ctx) override; + std::any visitRelationCount( + SubstraitPlanParser::RelationCountContext* ctx) override; + std::any visitRelationJoinType( + SubstraitPlanParser::RelationJoinTypeContext* ctx) override; std::any visitFile_location( SubstraitPlanParser::File_locationContext* ctx) override; std::any visitFile_detail( @@ -106,14 +110,15 @@ class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { SubstraitPlanParser::Named_table_detailContext* ctx) override; std::any visitRelation_ref( SubstraitPlanParser::Relation_refContext* ctx) override; + std::any visitSort_field( + SubstraitPlanParser::Sort_fieldContext* ctx) override; std::any visitId(SubstraitPlanParser::IdContext* ctx) override; + std::any visitSimple_id(SubstraitPlanParser::Simple_idContext* ctx) override; private: - std::shared_ptr symbolTable_; - std::shared_ptr errorListener_; - const SymbolInfo* currentRelationScope_{nullptr}; // Not owned. + int numSpacesSeen_{0}; int numFunctionsSeen_{0}; }; diff --git a/src/substrait/textplan/parser/Tool.cpp b/src/substrait/textplan/parser/Tool.cpp index 9c980ec6..84302d9c 100644 --- a/src/substrait/textplan/parser/Tool.cpp +++ b/src/substrait/textplan/parser/Tool.cpp @@ -1,7 +1,7 @@ /* SPDX-License-Identifier: Apache-2.0 */ #include -#include +#include #include "substrait/textplan/SymbolTablePrinter.h" #include "substrait/textplan/parser/ParseText.h" diff --git a/src/substrait/textplan/parser/data/provided_sample1.splan b/src/substrait/textplan/parser/data/provided_sample1.splan index 80b78216..ed3a0dd3 100644 --- a/src/substrait/textplan/parser/data/provided_sample1.splan +++ b/src/substrait/textplan/parser/data/provided_sample1.splan @@ -19,7 +19,7 @@ read relation read { schema schema { r_regionkey i32; r_name string; - r_comment string; + r_comment string?; } source named_table named { diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 index c8e748c2..e96edf5f 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 @@ -28,7 +28,6 @@ FUNCTION: 'FUNCTION'; AS: 'AS'; SCHEMA: 'SCHEMA'; RELATION: 'RELATION'; -NULLABLE: 'NULLABLE'; PIPELINES: 'PIPELINES'; COMMON: 'COMMON'; @@ -37,6 +36,13 @@ FILTER: 'FILTER'; PROJECTION: 'PROJECTION'; EXPRESSION: 'EXPRESSION'; ADVANCED_EXTENSION: 'ADVANCED_EXTENSION'; +GROUPING: 'GROUPING'; +MEASURE: 'MEASURE'; +INVOCATION: 'INVOCATION'; +SORT: 'SORT'; +BY: 'BY'; +COUNT: 'COUNT'; +TYPE: 'TYPE'; VIRTUAL_TABLE: 'VIRTUAL_TABLE'; LOCAL_FILES: 'LOCAL_FILES'; @@ -53,6 +59,7 @@ URI_FOLDER: 'URI_FOLDER'; PARTITION_INDEX: 'PARTITION_INDEX'; START: 'START'; LENGTH: 'LENGTH'; +ORC: 'ORC'; NULLVAL: 'NULL'; TRUEVAL: 'TRUE'; FALSEVAL: 'FALSE'; @@ -79,6 +86,7 @@ MINUS: '-'; LEFTANGLEBRACKET: '<'; RIGHTANGLEBRACKET: '>'; QUESTIONMARK: '?'; +ATSIGN: '@'; IDENTIFIER : [A-Z][A-Z0-9]* @@ -105,11 +113,12 @@ SPACES: [ \u000B\t\r\n] -> channel(HIDDEN); mode EXTENSIONS; fragment SCHEME: [A-Z]+ ; fragment HOSTNAME: [A-Z0-9-.]+ ; -fragment FILENAME: [A-Z0-9-.]+; +fragment FILENAME: [A-Z0-9-._]+; +fragment PATH: FILENAME ( '/' FILENAME )*; URI - : SCHEME ':' ( '//' HOSTNAME '/')? FILENAME - | FILENAME + : SCHEME ':' ( '//' HOSTNAME '/' )? PATH + | '/'? PATH ; EXTENSIONS_LEFTBRACE: '{' -> mode(DEFAULT_MODE), type(LEFTBRACE); diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 index 851029ee..f5520241 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 @@ -56,20 +56,34 @@ relation_filter_behavior | id id ; +// TODO -- Can the type be determined automatically from the function definition? +// TODO -- Consider moving the run phase to an optional third detail line. +measure_detail + : MEASURE expression (ARROW literal_complex_type)? (ATSIGN id)? SEMICOLON + | FILTER expression SEMICOLON + | INVOCATION id SEMICOLON + | sort_field + ; + relation_detail : COMMON SEMICOLON # relationCommon | BASE_SCHEMA id SEMICOLON # relationUsesSchema | relation_filter_behavior? FILTER expression SEMICOLON # relationFilter - | EXPRESSION expression SEMICOLON (AS id)? # relationExpression - | ADVANCED_EXTENSION SEMICOLON # relationAdvancedExtension - | source_reference SEMICOLON # relationSourceReference + | EXPRESSION expression SEMICOLON (AS id)? # relationExpression + | ADVANCED_EXTENSION SEMICOLON # relationAdvancedExtension + | source_reference SEMICOLON # relationSourceReference + | GROUPING expression SEMICOLON # relationGrouping + | MEASURE LEFTBRACE measure_detail* RIGHTBRACE # relationMeasure + | sort_field # relationSort + | COUNT NUMBER SEMICOLON # relationCount + | TYPE id SEMICOLON # relationJoinType ; expression - : id LEFTPAREN expression (COMMA expression)? COMMA? RIGHTPAREN # expressionFunctionUse - | constant # expressionConstant - | column_name # expressionColumn - | expression AS literal_complex_type # expressionCast + : id LEFTPAREN (expression COMMA?)* RIGHTPAREN # expressionFunctionUse + | constant # expressionConstant + | column_name # expressionColumn + | expression AS literal_complex_type # expressionCast ; constant @@ -83,7 +97,7 @@ constant ; literal_basic_type - : id literal_specifier? QUESTIONMARK? + : id QUESTIONMARK? literal_specifier? ; literal_complex_type @@ -130,6 +144,7 @@ file_detail : PARTITION_INDEX COLON NUMBER | START COLON NUMBER | LENGTH COLON NUMBER + | ORC COLON LEFTBRACE RIGHTBRACE | file_location ; @@ -152,15 +167,7 @@ schema_definition ; schema_item - : id column_type SEMICOLON - ; - -column_type - : column_attribute* id - ; - -column_attribute - : NULLABLE + : id literal_complex_type SEMICOLON ; source_definition @@ -182,8 +189,12 @@ function : FUNCTION name (AS id)? SEMICOLON ; +sort_field + : SORT expression (BY id)? SEMICOLON + ; + name - : id COLON signature + : id COLON signature? ; signature @@ -192,7 +203,17 @@ signature // List keywords here to make them not reserved. id - : IDENTIFIER (UNDERSCORE+ IDENTIFIER)* + : simple_id (UNDERSCORE+ simple_id)* + ; + +simple_id + : IDENTIFIER | FILTER | SCHEMA + | NULLVAL + | SORT + | MEASURE + | GROUPING + | COUNT + | TYPE ; diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp index 901753c4..13e8eef6 100644 --- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp +++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp @@ -161,9 +161,9 @@ std::vector getTestCases() { { "test3-schema", R"(schema schema { - r_regionkey UNKNOWN; - r_name nullable UNKNOWN; - r_comment UNKNOWN; + r_regionkey i32; + r_name string?; + r_comment string; })", AllOf( HasSymbols({"schema", "r_regionkey", "r_name", "r_comment"}), @@ -232,7 +232,24 @@ std::vector getTestCases() { function concat:str as concat; })")), AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( - R"(relations { root { input { project { + R"(extension_uris { + extension_uri_anchor: 1 uri: "blah.yaml" + } + extensions { + extension_function { + extension_uri_reference: 1 name: "add:i8" } + } + extensions { + extension_function { + extension_uri_reference: 1 function_anchor: 1 + name: "subtract:i8" } + } + extensions { + extension_function { + extension_uri_reference: 1 function_anchor: 2 + name: "concat:str" } + } + relations { root { input { project { expressions { selection { direct_reference { @@ -261,15 +278,15 @@ std::vector getTestCases() { } } expressions { scalar_function { - function_reference: 1 arguments { value { selection { + function_reference: 0 arguments { value { selection { direct_reference { struct_field { } } } } } arguments { value { literal { i8: 1 } } } } } expressions { scalar_function { - function_reference: 2 arguments { value { selection { + function_reference: 1 arguments { value { selection { direct_reference { struct_field { } } } } } arguments { value { literal { i8: 1 } } } } } expressions { scalar_function { - function_reference: 3 arguments { value { selection { + function_reference: 2 arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } arguments { value { selection { direct_reference { struct_field { field: 1 } } } } } } } @@ -291,6 +308,16 @@ std::vector getTestCases() { })", AllOf(HasSymbols({"myread"}), ParsesOk()), }, + { + "test7-bad-filter-relation", + R"(filter relation filter { + condition true_bool; + })", + HasErrors( + {"2:22 → missing 'FILTER' at 'true'", + "2:12 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", + "2:12 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior."}), + }, { "test10-literals-boolean", R"(project relation literalexamples { @@ -414,8 +441,10 @@ std::vector getTestCases() { expression "two\nlines with \"escapes\""_varchar<80>; expression "abcde"_fixedchar<5>; })", - AsBinaryPlan(Partially(EqualsProto<::substrait::proto::Plan>( - R"(relations { root { input { project { + AllOf( + HasErrors({}), + AsBinaryPlan(Partially(EqualsProto<::substrait::proto::Plan>( + R"(relations { root { input { project { expressions { literal { string: "simple text" } } expressions { literal { string: "123" } } expressions { literal { @@ -434,7 +463,7 @@ std::vector getTestCases() { var_char: { value: "two\nlines with \"escapes\"" length: 80 } } } expressions { literal { fixed_char: "abcde" } } - } } } })"))), + } } } })")))), }, { "test10-literals-strings-nulls", @@ -554,7 +583,7 @@ std::vector getTestCases() { R"(project relation literalexamples { expression {"a", "b", "c"}_list; expression {null, "a", "b"}_list; - expression {{"a", "b"}, {"1", "2"}}_list>?; + expression {{"a", "b"}, {"1", "2"}}_list?>; expression {}_list; expression {}_list; expression {}_list?; @@ -578,10 +607,14 @@ std::vector getTestCases() { values { string: "1" } values { string: "2" } } } } } } - expressions { literal { empty_list { type { string { } } } } } expressions { literal { empty_list { type { string { - nullability: NULLABILITY_NULLABLE } } } } } - expressions { literal { empty_list { type { string { } } + nullability: NULLABILITY_REQUIRED } } + nullability: NULLABILITY_REQUIRED } } } + expressions { literal { empty_list { type { string { + nullability: NULLABILITY_NULLABLE } } + nullability: NULLABILITY_REQUIRED } } } + expressions { literal { empty_list { type { string { + nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } } } })")), }, @@ -595,17 +628,17 @@ std::vector getTestCases() { })", AsBinaryPlan((EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { - expressions { literal { null { list { type { string { } } + expressions { literal { null { list { type { string { nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } expressions { literal { null { list { type { list { type { string { - nullability: NULLABILITY_NULLABLE } } } } + nullability: NULLABILITY_NULLABLE } } nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } expressions { literal { null { list { type { - list { type { string { } } } } + list { type { string { nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } expressions { literal { null { list { type { - list { type { string { } } } } + list { type { string { nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } } } } })"))), }, @@ -622,7 +655,7 @@ std::vector getTestCases() { key_values { key { i16: 32 } value { string: "everything" } } } } } expressions { literal { - empty_map { key { fp32 {} } value { string { } } } } } + empty_map { key { fp32 { nullability: NULLABILITY_REQUIRED } } value { string {nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_REQUIRED} } } } } } })")), }, { @@ -681,18 +714,56 @@ std::vector getTestCases() { }} } })"))), }, { - "test11-bad-literals", + "test11-bad-numeric-literals", R"(project relation literalexamples { expression 1; expression 1.5; expression 1_potato; - expression "data"_potato; expression null; - expression "ddb287e8"_uuid; - expression "nothex"_uuid; expression 42_decimal; expression 42_decimal; expression 42_decimal<-5,-4>; + })", + HasErrors({ + "6:34 → mismatched input 'r5' expecting NUMBER", + "6:36 → mismatched input ',' expecting 'FILTER'", + "7:34 → mismatched input 'r' expecting NUMBER", + "7:35 → mismatched input ',' expecting 'FILTER'", + "2:23 → Literals should include a type.", + "3:23 → Literals should include a type.", + "4:25 → Unable to recognize requested type.", + "5:23 → Null literals require type.", + "6:26 → Failed to decode type.", + "6:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", + "6:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", + "6:34 → Filters are not permitted for this kind of relation.", + "7:26 → Failed to decode type.", + "7:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", + "7:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", + "7:34 → Filters are not permitted for this kind of relation.", + "8:23 → Could not parse literal as decimal.", + }), + }, + { + "test11-bad-stringlike-literals", + R"(project relation literalexamples { + expression "data"_potato; + expression "ddb287e8"_uuid; + expression "nothex"_uuid; + expression "unknown\escape"_string; + expression "abcde"_fixedchar; + })", + HasErrors({ + "2:30 → Unable to recognize requested type.", + "3:23 → UUIDs are 128 bits long and thus should be specified with exactly 32 hexadecimal digits.", + "4:23 → UUIDs should be be specified with hexadecimal characters with optional dashes only.", + "5:31 → Unknown slash escape sequence.", + "6:31 → Unable to recognize requested type.", + }), + }, + { + "test11-bad-complex-literals", + R"(project relation literalexamples { expression {}_list?; expression {}_struct; expression {}_struct<>; @@ -701,45 +772,23 @@ std::vector getTestCases() { expression {}_map<,string>; expression {}_map<,>; expression {}_list<>; - expression "unknown\escape"_string; expression {123_i8}_map; expression {123}_map; })", HasErrors({ - "9:34 → mismatched input 'r5' expecting NUMBER", - "9:36 → mismatched input ',' expecting 'FILTER'", - "10:34 → mismatched input 'r' expecting NUMBER", - "10:35 → mismatched input ',' expecting 'FILTER'", - "12:38 → extraneous input '?' expecting ';'", - "2:23 → Literals should include a type.", - "3:23 → Literals should include a type.", - "4:25 → Unable to recognize requested type.", - "5:30 → Unable to recognize requested type.", - "6:23 → Null literals require type.", - "7:23 → UUIDs are 128 bits long and thus should be specified with exactly 32 hexadecimal digits.", - "8:23 → UUIDs should be be specified with hexadecimal characters with optional dashes only.", - "9:26 → Failed to decode type.", - "9:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", - "9:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", - "9:34 → Filters are not permitted for this kind of relation.", - "10:26 → Failed to decode type.", - "10:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", - "10:34 → Best effort and post join are the only two legal filter behavior choices. You may also not provide one which will result to the default filter behavior.", - "10:34 → Filters are not permitted for this kind of relation.", - "11:23 → Could not parse literal as decimal.", - "13:26 → Unable to recognize requested type.", - "14:26 → Unable to recognize requested type.", - "15:26 → Maps require both a key and a value type.", - "15:23 → Unsupported type 0.", - "16:26 → Maps require both a key and a value type.", - "16:23 → Unsupported type 0.", - "17:26 → Unable to recognize requested type.", - "18:26 → Unable to recognize requested type.", - "18:26 → Unable to recognize requested type.", - "19:26 → Unable to recognize requested type.", - "20:31 → Unknown slash escape sequence.", - "21:23 → Map literals require pairs of values separated by colons.", - "22:23 → Map literals require pairs of values separated by colons.", + "2:38 → extraneous input '?' expecting ';'", + "3:26 → Unable to recognize requested type.", + "4:26 → Unable to recognize requested type.", + "5:26 → Maps require both a key and a value type.", + "5:23 → Unsupported type 0.", + "6:26 → Maps require both a key and a value type.", + "6:23 → Unsupported type 0.", + "7:26 → Unable to recognize requested type.", + "8:26 → Unable to recognize requested type.", + "8:26 → Unable to recognize requested type.", + "9:26 → Unable to recognize requested type.", + "10:23 → Map literals require pairs of values separated by colons.", + "11:23 → Map literals require pairs of values separated by colons.", }), }, { @@ -752,10 +801,13 @@ std::vector getTestCases() { HasErrors({}), AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { - expressions { cast { type { i32 {} } + expressions { cast { type { i32 { + nullability: NULLABILITY_REQUIRED } } input { literal { i8: 123 } } } } - expressions { cast { type { i64 {} } - input { cast { type { i32 {} } + expressions { cast { type { i64 { + nullability: NULLABILITY_REQUIRED } } + input { cast { type { i32 { + nullability: NULLABILITY_REQUIRED } } input { literal { i8: 123 } } } } } } } } } })"))), }, @@ -771,8 +823,8 @@ std::vector getTestCases() { })", AsBinaryPlan(Partially(EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { - expressions { scalar_function { function_reference: 1 - arguments { value { scalar_function { function_reference: 2 + expressions { scalar_function { function_reference: 0 + arguments { value { scalar_function { function_reference: 1 arguments { value { literal { i32: 1 } } } arguments { value { literal { i32: -2 } } } } } } @@ -780,6 +832,13 @@ std::vector getTestCases() { } } } } } })"))), }, + { + "test13-bad-functions", + R"(extension_space blah.yaml { + function sum: as sum; + })", + HasErrors({"2:12 → Functions should have an associated type."}), + }, { "test14-three-node-pipeline-with-fields", R"(pipelines { @@ -867,9 +926,9 @@ std::vector getTestCases() { names: "product_id" names: "count" struct { - types { i32 { } } - types { i32 { } } - types { i64 { } } } + types { i32 { nullability: NULLABILITY_REQUIRED } } + types { i32 { nullability: NULLABILITY_REQUIRED } } + types { i64 { nullability: NULLABILITY_REQUIRED } } } } named_table { names: "#1" } } @@ -880,8 +939,8 @@ std::vector getTestCases() { names: "product_id" names: "cost" struct { - types { i32 { } } - types { fp32 { } } } + types { i32 { nullability: NULLABILITY_REQUIRED } } + types { fp32 { nullability: NULLABILITY_REQUIRED } } } } named_table { names: "#2" } } @@ -912,8 +971,8 @@ std::vector getTestCases() { names: "company" names: "order_id" struct { - types { string { } } - types { i32 { } } + types { string { nullability: NULLABILITY_REQUIRED } } + types { i32 { nullability: NULLABILITY_REQUIRED } } } } named_table { names: "#3" } @@ -941,7 +1000,8 @@ std::vector getTestCases() { HasErrors({ "1:0 → extraneous input 'relation' expecting {, " "'EXTENSION_SPACE', 'SCHEMA', 'PIPELINES', 'FILTER', " - "'SOURCE', IDENTIFIER}", + "'GROUPING', 'MEASURE', 'SORT', 'COUNT', 'TYPE', 'SOURCE', " + "'NULL', IDENTIFIER}", "1:24 → mismatched input '{' expecting 'RELATION'", "1:9 → Unrecognized relation type: notyperelation", }), diff --git a/src/substrait/type/tests/TypeTest.cpp b/src/substrait/type/tests/TypeTest.cpp index 0f1c2c2d..f7de5859 100644 --- a/src/substrait/type/tests/TypeTest.cpp +++ b/src/substrait/type/tests/TypeTest.cpp @@ -2,6 +2,7 @@ #include +#include #include "substrait/type/Type.h" using namespace io::substrait; @@ -147,6 +148,14 @@ TEST_F(TypeTest, decodeTest) { ASSERT_EQ(typePtr->scale(), 2); }); + testDecode( + "decimal?<18,2>", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "dec<18,2>"); + ASSERT_EQ(typePtr->precision(), 18); + ASSERT_EQ(typePtr->scale(), 2); + ASSERT_TRUE(typePtr->nullable()); + }); + testDecode( "struct", [](const std::shared_ptr& typePtr) {