From 7942c8e1fda0f001ca3c49e74f19530d4e3480f5 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 14 Feb 2024 00:41:36 -0800 Subject: [PATCH] Add support for new relation types. --- .../converter/BasePlanProtoVisitor.cpp | 66 +++++++++++++++++++ .../textplan/converter/BasePlanProtoVisitor.h | 10 +++ .../textplan/converter/PipelineVisitor.cpp | 21 ++++++ 3 files changed, 97 insertions(+) diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp index b39b3464..bcbacfc3 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.cpp @@ -973,6 +973,44 @@ std::any BasePlanProtoVisitor::visitCrossRelation( return std::nullopt; } +std::any BasePlanProtoVisitor::visitReferenceRelation( + const ::substrait::proto::ReferenceRel& relation) { + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitWriteRelation( + const ::substrait::proto::WriteRel& relation) { + // TODO -- Add support for write_type. + if (relation.has_table_schema()) { + visitNamedStruct(relation.table_schema()); + } + if (relation.has_input()) { + visitRelation(relation.input()); + } + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + return std::nullopt; +} + +std::any BasePlanProtoVisitor::visitDdlRelation( + const ::substrait::proto::DdlRel& relation) { + // TODO -- Add support for write_type. + if (relation.has_table_schema()) { + visitNamedStruct(relation.table_schema()); + } + if (relation.has_table_defaults()) { + visitExpressionLiteralStruct(relation.table_defaults()); + } + if (relation.has_view_definition()) { + visitRelation(relation.view_definition()); + } + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitHashJoinRelation( const ::substrait::proto::HashJoinRel& relation) { if (relation.has_common()) { @@ -1025,6 +1063,26 @@ std::any BasePlanProtoVisitor::visitMergeJoinRelation( return std::nullopt; } +std::any BasePlanProtoVisitor::visitNestedLoopJoinRelation( + const ::substrait::proto::NestedLoopJoinRel& relation) { + if (relation.has_common()) { + visitRelationCommon(relation.common()); + } + if (relation.has_left()) { + visitRelation(relation.left()); + } + if (relation.has_right()) { + visitRelation(relation.right()); + } + if (relation.has_expression()) { + visitExpression(relation.expression()); + } + if (relation.has_advanced_extension()) { + visitAdvancedExtension(relation.advanced_extension()); + }; + return std::nullopt; +} + std::any BasePlanProtoVisitor::visitWindowRelation( const ::substrait::proto::ConsistentPartitionWindowRel& relation) { if (relation.has_common()) { @@ -1103,10 +1161,18 @@ std::any BasePlanProtoVisitor::visitRelation( return visitExtensionLeafRelation(relation.extension_leaf()); case ::substrait::proto::Rel::RelTypeCase::kCross: return visitCrossRelation(relation.cross()); + case ::substrait::proto::Rel::RelTypeCase::kReference: + return visitReferenceRelation(relation.reference()); + case ::substrait::proto::Rel::RelTypeCase::kWrite: + return visitWriteRelation(relation.write()); + case ::substrait::proto::Rel::RelTypeCase::kDdl: + return visitDdlRelation(relation.ddl()); case ::substrait::proto::Rel::RelTypeCase::kHashJoin: return visitHashJoinRelation(relation.hash_join()); case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: return visitMergeJoinRelation(relation.merge_join()); + case ::substrait::proto::Rel::RelTypeCase::kNestedLoopJoin: + return visitNestedLoopJoinRelation(relation.nested_loop_join()); case ::substrait::proto::Rel::kWindow: return visitWindowRelation(relation.window()); case ::substrait::proto::Rel::kExchange: diff --git a/src/substrait/textplan/converter/BasePlanProtoVisitor.h b/src/substrait/textplan/converter/BasePlanProtoVisitor.h index 0eef3e00..f78f3501 100644 --- a/src/substrait/textplan/converter/BasePlanProtoVisitor.h +++ b/src/substrait/textplan/converter/BasePlanProtoVisitor.h @@ -170,10 +170,20 @@ class BasePlanProtoVisitor { const ::substrait::proto::ExtensionLeafRel& relation); virtual std::any visitCrossRelation( const ::substrait::proto::CrossRel& relation); + + virtual std::any visitReferenceRelation( + const ::substrait::proto::ReferenceRel& relation); + virtual std::any visitWriteRelation( + const ::substrait::proto::WriteRel& relation); + virtual std::any visitDdlRelation( + const ::substrait::proto::DdlRel& relation); + virtual std::any visitHashJoinRelation( const ::substrait::proto::HashJoinRel& relation); virtual std::any visitMergeJoinRelation( const ::substrait::proto::MergeJoinRel& relation); + virtual std::any visitNestedLoopJoinRelation( + const ::substrait::proto::NestedLoopJoinRel& relation); virtual std::any visitWindowRelation( const ::substrait::proto::ConsistentPartitionWindowRel& relation); virtual std::any visitExchangeRelation( diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index 8ffd1d6e..979f79a2 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -165,6 +165,17 @@ std::any PipelineVisitor::visitRelation( relationData->newPipelines.push_back(rightSymbol); break; } + case ::substrait::proto::Rel::kNestedLoopJoin: { + const auto* leftSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.nested_loop_join().left()), + SymbolType::kRelation); + const auto* rightSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.nested_loop_join().right()), + SymbolType::kRelation); + relationData->newPipelines.push_back(leftSymbol); + relationData->newPipelines.push_back(rightSymbol); + break; + } case ::substrait::proto::Rel::kWindow: { const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( PROTO_LOCATION(relation.window().input()), SymbolType::kRelation); @@ -183,6 +194,16 @@ std::any PipelineVisitor::visitRelation( relationData->continuingPipeline = inputSymbol; break; } + case ::substrait::proto::Rel::kReference: + // TODO -- Add support for references in text plans. + break; + case ::substrait::proto::Rel::kWrite: { + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.write().input()), SymbolType::kRelation); + relationData->continuingPipeline = inputSymbol; + break; + } + case ::substrait::proto::Rel::kDdl: case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; }