diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp index a5473641..b39b3464 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp @@ -411,6 +411,21 @@ std::any BasePlanProtoVisitor::visitWindowFunction( return std::nullopt; } +std::any BasePlanProtoVisitor::visitWindowRelFunction( + const ::substrait::proto::ConsistentPartitionWindowRel::WindowRelFunction& + function) { + for (const auto& arg : function.arguments()) { + visitFunctionArgument(arg); + } + for (const auto& arg : function.options()) { + visitFunctionOption(arg); + } + if (function.has_output_type()) { + visitType(function.output_type()); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitIfThen( const ::substrait::proto::Expression::IfThen& ifthen) { for (const auto& ifThenIf : ifthen.ifs()) { @@ -649,7 +664,6 @@ std::any BasePlanProtoVisitor::visitExpression( case ::substrait::proto::Expression::RexTypeCase::REX_TYPE_NOT_SET: break; } - // TODO -- Use an error listener instead. SUBSTRAIT_UNSUPPORTED( "Unsupported expression type encountered: " + std::to_string(expression.rex_type_case())); @@ -736,6 +750,25 @@ std::any BasePlanProtoVisitor::visitFieldReference( return std::nullopt; } +std::any BasePlanProtoVisitor::visitExpandField( + const ::substrait::proto::ExpandRel::ExpandField& field) { + switch (field.field_type_case()) { + case ::substrait::proto::ExpandRel_ExpandField::kSwitchingField: + for (const auto& switchingField : field.switching_field().duplicates()) { + visitExpression(switchingField); + } + break; + case ::substrait::proto::ExpandRel_ExpandField::kConsistentField: + if (field.has_consistent_field()) { + visitExpression(field.consistent_field()); + } + break; + case ::substrait::proto::ExpandRel_ExpandField::FIELD_TYPE_NOT_SET: + break; + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitReadRelation( const ::substrait::proto::ReadRel& relation) { if (relation.has_common()) { @@ -992,6 +1025,57 @@ std::any BasePlanProtoVisitor::visitMergeJoinRelation( return std::nullopt; } +std::any BasePlanProtoVisitor::visitWindowRelation( + const ::substrait::proto::ConsistentPartitionWindowRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + for (const auto& func : relation.window_functions()) { + visitWindowRelFunction(func); + } + for (const auto& exp : relation.partition_expressions()) { + visitExpression(exp); + } + for (const auto& sort : relation.sorts()) { + visitSortField(sort); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExchangeRelation( + const ::substrait::proto::ExchangeRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitExpandRelation( + const ::substrait::proto::ExpandRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + for (const auto& expandField : relation.fields()) { + visitExpandField(expandField); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitRelation( const ::substrait::proto::Rel& relation) { switch (relation.rel_type_case()) { @@ -1023,6 +1107,12 @@ std::any BasePlanProtoVisitor::visitRelation( return visitHashJoinRelation(relation.hash_join()); case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: return visitMergeJoinRelation(relation.merge_join()); + case ::substrait::proto::Rel::kWindow: + return visitWindowRelation(relation.window()); + case ::substrait::proto::Rel::kExchange: + return visitExchangeRelation(relation.exchange()); + case ::substrait::proto::Rel::kExpand: + return visitExpandRelation(relation.expand()); case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.h b/src/substrait/textplan/converter/BasePlanProtoVisitor.h index af1fb138..0eef3e00 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.h +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.h @@ -4,6 +4,7 @@ #include +#include "substrait/proto/algebra.pb.h" #include "substrait/proto/plan.pb.h" namespace io::substrait::textplan { @@ -84,6 +85,9 @@ class BasePlanProtoVisitor { const ::substrait::proto::Expression::ScalarFunction& function); virtual std::any visitWindowFunction( const ::substrait::proto::Expression::WindowFunction& function); + virtual std::any visitWindowRelFunction( + const ::substrait::proto::ConsistentPartitionWindowRel::WindowRelFunction& + function); virtual std::any visitIfThen( const ::substrait::proto::Expression::IfThen& ifthen); virtual std::any visitSwitchExpression( @@ -140,6 +144,8 @@ class BasePlanProtoVisitor { virtual std::any visitSortField(const ::substrait::proto::SortField& sort); virtual std::any visitFieldReference( const ::substrait::proto::Expression::FieldReference& ref); + virtual std::any visitExpandField( + const ::substrait::proto::ExpandRel::ExpandField& field); virtual std::any visitReadRelation( const ::substrait::proto::ReadRel& relation); @@ -168,6 +174,12 @@ class BasePlanProtoVisitor { const ::substrait::proto::HashJoinRel& relation); virtual std::any visitMergeJoinRelation( const ::substrait::proto::MergeJoinRel& relation); + virtual std::any visitWindowRelation( + const ::substrait::proto::ConsistentPartitionWindowRel& relation); + virtual std::any visitExchangeRelation( + const ::substrait::proto::ExchangeRel& relation); + virtual std::any visitExpandRelation( + const ::substrait::proto::ExpandRel& relation); virtual std::any visitRelation(const ::substrait::proto::Rel& relation); virtual std::any visitRelationRoot( diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index dfbdf38f..7fd3eb0a 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -76,6 +76,15 @@ void eraseInputs(::substrait::proto::Rel* relation) { relation->mutable_merge_join()->clear_left(); relation->mutable_merge_join()->clear_right(); break; + case ::substrait::proto::Rel::kWindow: + relation->mutable_window()->clear_input(); + break; + case ::substrait::proto::Rel::kExchange: + relation->mutable_exchange()->clear_input(); + break; + case ::substrait::proto::Rel::kExpand: + relation->mutable_expand()->clear_input(); + break; case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } @@ -112,6 +121,12 @@ ::google::protobuf::RepeatedField getOutputMapping( return relation.hash_join().common().emit().output_mapping(); case ::substrait::proto::Rel::kMergeJoin: return relation.merge_join().common().emit().output_mapping(); + case ::substrait::proto::Rel::kWindow: + return relation.window().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExchange: + return relation.exchange().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExpand: + return relation.expand().common().emit().output_mapping(); case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } @@ -521,6 +536,15 @@ void InitialPlanProtoVisitor::updateLocalSchema( relation.merge_join().left(), relation.merge_join().right()); break; + case ::substrait::proto::Rel::kWindow: + addFieldsToRelation(relationData, relation.window().input()); + break; + case ::substrait::proto::Rel::kExchange: + addFieldsToRelation(relationData, relation.exchange().input()); + break; + case ::substrait::proto::Rel::kExpand: + addFieldsToRelation(relationData, relation.expand().input()); + break; case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index d63a6f56..57d5dcea 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -107,6 +107,24 @@ std::any PipelineVisitor::visitRelation( relationData->newPipelines.push_back(rightSymbol); break; } + case ::substrait::proto::Rel::kWindow: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.window().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kExchange: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.exchange().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kExpand: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.expand().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/third_party/substrait b/third_party/substrait index 07e4feb5..31b99906 160000 --- a/third_party/substrait +++ b/third_party/substrait @@ -1 +1 @@ -Subproject commit 07e4feb5983478e7d0d95dc1d9b5e176685dbdc3 +Subproject commit 31b999060a6e014717f9ae3e6716986ad3066aaf