diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp index b39b3464..bcbacfc3 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp @@ -973,6 +973,44 @@ std::any BasePlanProtoVisitor::visitCrossRelation( return std::nullopt; } +std::any BasePlanProtoVisitor::visitReferenceRelation( + const ::substrait::proto::ReferenceRel& relation) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitWriteRelation( + const ::substrait::proto::WriteRel& relation) { + // TODO -- Add support for write_type. + if (relation.has_table_schema()) { + visitNamedStruct(relation.table_schema()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitDdlRelation( + const ::substrait::proto::DdlRel& relation) { + // TODO -- Add support for write_type. + if (relation.has_table_schema()) { + visitNamedStruct(relation.table_schema()); + } + if (relation.has_table_defaults()) { + visitExpressionLiteralStruct(relation.table_defaults()); + } + if (relation.has_view_definition()) { + visitRelation(relation.view_definition()); + } + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitHashJoinRelation( const ::substrait::proto::HashJoinRel& relation) { if (relation.has_common()) { @@ -1025,6 +1063,26 @@ std::any BasePlanProtoVisitor::visitMergeJoinRelation( return std::nullopt; } +std::any BasePlanProtoVisitor::visitNestedLoopJoinRelation( + const ::substrait::proto::NestedLoopJoinRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_left()) { + visitRelation(relation.left()); + } + if (relation.has_right()) { + visitRelation(relation.right()); + } + if (relation.has_expression()) { + visitExpression(relation.expression()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitWindowRelation( const ::substrait::proto::ConsistentPartitionWindowRel& relation) { if (relation.has_common()) { @@ -1103,10 +1161,18 @@ std::any BasePlanProtoVisitor::visitRelation( return visitExtensionLeafRelation(relation.extension_leaf()); case ::substrait::proto::Rel::RelTypeCase::kCross: return visitCrossRelation(relation.cross()); + case ::substrait::proto::Rel::RelTypeCase::kReference: + return visitReferenceRelation(relation.reference()); + case ::substrait::proto::Rel::RelTypeCase::kWrite: + return visitWriteRelation(relation.write()); + case ::substrait::proto::Rel::RelTypeCase::kDdl: + return visitDdlRelation(relation.ddl()); case ::substrait::proto::Rel::RelTypeCase::kHashJoin: return visitHashJoinRelation(relation.hash_join()); case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: return visitMergeJoinRelation(relation.merge_join()); + case ::substrait::proto::Rel::RelTypeCase::kNestedLoopJoin: + return visitNestedLoopJoinRelation(relation.nested_loop_join()); case ::substrait::proto::Rel::kWindow: return visitWindowRelation(relation.window()); case ::substrait::proto::Rel::kExchange: diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.h b/src/substrait/textplan/converter/BasePlanProtoVisitor.h index 0eef3e00..04a20996 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.h +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.h @@ -170,10 +170,19 @@ class BasePlanProtoVisitor { const ::substrait::proto::ExtensionLeafRel& relation); virtual std::any visitCrossRelation( const ::substrait::proto::CrossRel& relation); + + virtual std::any visitReferenceRelation( + const ::substrait::proto::ReferenceRel& relation); + virtual std::any visitWriteRelation( + const ::substrait::proto::WriteRel& relation); + virtual std::any visitDdlRelation(const ::substrait::proto::DdlRel& relation); + virtual std::any visitHashJoinRelation( const ::substrait::proto::HashJoinRel& relation); virtual std::any visitMergeJoinRelation( const ::substrait::proto::MergeJoinRel& relation); + virtual std::any visitNestedLoopJoinRelation( + const ::substrait::proto::NestedLoopJoinRel& relation); virtual std::any visitWindowRelation( const ::substrait::proto::ConsistentPartitionWindowRel& relation); virtual std::any visitExchangeRelation( diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index 2b783996..d044acc6 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -68,6 +68,14 @@ void eraseInputs(::substrait::proto::Rel* relation) { relation->mutable_cross()->clear_left(); relation->mutable_cross()->clear_right(); break; + case ::substrait::proto::Rel::kReference: + break; + case ::substrait::proto::Rel::kWrite: + relation->mutable_write()->clear_input(); + break; + case ::substrait::proto::Rel::kDdl: + relation->mutable_ddl()->clear_view_definition(); + break; case ::substrait::proto::Rel::kHashJoin: relation->mutable_hash_join()->clear_left(); relation->mutable_hash_join()->clear_right(); @@ -76,6 +84,10 @@ void eraseInputs(::substrait::proto::Rel* relation) { relation->mutable_merge_join()->clear_left(); relation->mutable_merge_join()->clear_right(); break; + case ::substrait::proto::Rel::kNestedLoopJoin: + relation->mutable_nested_loop_join()->clear_left(); + relation->mutable_nested_loop_join()->clear_right(); + break; case ::substrait::proto::Rel::kWindow: relation->mutable_window()->clear_input(); break; @@ -117,10 +129,19 @@ ::google::protobuf::RepeatedField getOutputMapping( return relation.extension_leaf().common().emit().output_mapping(); case ::substrait::proto::Rel::kCross: return relation.cross().common().emit().output_mapping(); + case ::substrait::proto::Rel::kReference: + // There is no common message in a ReferenceRel. + break; + case ::substrait::proto::Rel::kWrite: + return relation.write().common().emit().output_mapping(); + case ::substrait::proto::Rel::kDdl: + return relation.ddl().common().emit().output_mapping(); case ::substrait::proto::Rel::kHashJoin: return relation.hash_join().common().emit().output_mapping(); case ::substrait::proto::Rel::kMergeJoin: return relation.merge_join().common().emit().output_mapping(); + case ::substrait::proto::Rel::kNestedLoopJoin: + return relation.nested_loop_join().common().emit().output_mapping(); case ::substrait::proto::Rel::kWindow: return relation.window().common().emit().output_mapping(); case ::substrait::proto::Rel::kExchange: @@ -566,6 +587,14 @@ void InitialPlanProtoVisitor::updateLocalSchema( addFieldsToRelation( relationData, relation.cross().left(), relation.cross().right()); break; + case ::substrait::proto::Rel::kReference: + break; + case ::substrait::proto::Rel::kWrite: + addFieldsToRelation(relationData, relation.write().input()); + break; + case ::substrait::proto::Rel::kDdl: + addFieldsToRelation(relationData, relation.ddl().view_definition()); + break; case ::substrait::proto::Rel::RelTypeCase::kHashJoin: addFieldsToRelation( relationData, @@ -578,6 +607,12 @@ void InitialPlanProtoVisitor::updateLocalSchema( relation.merge_join().left(), relation.merge_join().right()); break; + case ::substrait::proto::Rel::kNestedLoopJoin: + addFieldsToRelation( + relationData, + relation.nested_loop_join().left(), + relation.nested_loop_join().right()); + break; case ::substrait::proto::Rel::kWindow: addFieldsToRelation(relationData, relation.window().input()); break; diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index 8ffd1d6e..979f79a2 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -165,6 +165,17 @@ std::any PipelineVisitor::visitRelation( relationData->newPipelines.push_back(rightSymbol); break; } + case ::substrait::proto::Rel::kNestedLoopJoin: { + const auto* leftSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.nested_loop_join().left()), + SymbolType::kRelation); + const auto* rightSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.nested_loop_join().right()), + SymbolType::kRelation); + relationData->newPipelines.push_back(leftSymbol); + relationData->newPipelines.push_back(rightSymbol); + break; + } case ::substrait::proto::Rel::kWindow: { const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( PROTO_LOCATION(relation.window().input()), SymbolType::kRelation); @@ -183,6 +194,16 @@ std::any PipelineVisitor::visitRelation( relationData->continuingPipeline = inputSymbol; break; } + case ::substrait::proto::Rel::kReference: + // TODO -- Add support for references in text plans. + break; + case ::substrait::proto::Rel::kWrite: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.write().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kDdl: case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } diff --git a/third_party/substrait b/third_party/substrait index 31b99906..47344783 160000 --- a/third_party/substrait +++ b/third_party/substrait @@ -1 +1 @@ -Subproject commit 31b999060a6e014717f9ae3e6716986ad3066aaf +Subproject commit 47344783dce74645dcb636cb646cd3628df37ef0