Skip to content

Commit

Permalink
Add support for new relation types.
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime committed Feb 14, 2024
1 parent bd35324 commit 7942c8e
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/substrait/textplan/converter/BasePlanProtoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,44 @@ std::any BasePlanProtoVisitor::visitCrossRelation(
return std::nullopt;
}

std::any BasePlanProtoVisitor::visitReferenceRelation(
const ::substrait::proto::ReferenceRel& relation) {
return std::nullopt;
}

std::any BasePlanProtoVisitor::visitWriteRelation(
const ::substrait::proto::WriteRel& relation) {
// TODO -- Add support for write_type.
if (relation.has_table_schema()) {
visitNamedStruct(relation.table_schema());
}
if (relation.has_input()) {
visitRelation(relation.input());
}
if (relation.has_common()) {
visitRelationCommon(relation.common());
}
return std::nullopt;
}

std::any BasePlanProtoVisitor::visitDdlRelation(
const ::substrait::proto::DdlRel& relation) {
// TODO -- Add support for write_type.
if (relation.has_table_schema()) {
visitNamedStruct(relation.table_schema());
}
if (relation.has_table_defaults()) {
visitExpressionLiteralStruct(relation.table_defaults());
}
if (relation.has_view_definition()) {
visitRelation(relation.view_definition());
}
if (relation.has_common()) {
visitRelationCommon(relation.common());
}
return std::nullopt;
}

std::any BasePlanProtoVisitor::visitHashJoinRelation(
const ::substrait::proto::HashJoinRel& relation) {
if (relation.has_common()) {
Expand Down Expand Up @@ -1025,6 +1063,26 @@ std::any BasePlanProtoVisitor::visitMergeJoinRelation(
return std::nullopt;
}

std::any BasePlanProtoVisitor::visitNestedLoopJoinRelation(
const ::substrait::proto::NestedLoopJoinRel& relation) {
if (relation.has_common()) {
visitRelationCommon(relation.common());
}
if (relation.has_left()) {
visitRelation(relation.left());
}
if (relation.has_right()) {
visitRelation(relation.right());
}
if (relation.has_expression()) {
visitExpression(relation.expression());
}
if (relation.has_advanced_extension()) {
visitAdvancedExtension(relation.advanced_extension());
};
return std::nullopt;
}

std::any BasePlanProtoVisitor::visitWindowRelation(
const ::substrait::proto::ConsistentPartitionWindowRel& relation) {
if (relation.has_common()) {
Expand Down Expand Up @@ -1103,10 +1161,18 @@ std::any BasePlanProtoVisitor::visitRelation(
return visitExtensionLeafRelation(relation.extension_leaf());
case ::substrait::proto::Rel::RelTypeCase::kCross:
return visitCrossRelation(relation.cross());
case ::substrait::proto::Rel::RelTypeCase::kReference:
return visitReferenceRelation(relation.reference());
case ::substrait::proto::Rel::RelTypeCase::kWrite:
return visitWriteRelation(relation.write());
case ::substrait::proto::Rel::RelTypeCase::kDdl:
return visitDdlRelation(relation.ddl());
case ::substrait::proto::Rel::RelTypeCase::kHashJoin:
return visitHashJoinRelation(relation.hash_join());
case ::substrait::proto::Rel::RelTypeCase::kMergeJoin:
return visitMergeJoinRelation(relation.merge_join());
case ::substrait::proto::Rel::RelTypeCase::kNestedLoopJoin:
return visitNestedLoopJoinRelation(relation.nested_loop_join());
case ::substrait::proto::Rel::kWindow:
return visitWindowRelation(relation.window());
case ::substrait::proto::Rel::kExchange:
Expand Down
10 changes: 10 additions & 0 deletions src/substrait/textplan/converter/BasePlanProtoVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,20 @@ class BasePlanProtoVisitor {
const ::substrait::proto::ExtensionLeafRel& relation);
virtual std::any visitCrossRelation(
const ::substrait::proto::CrossRel& relation);

virtual std::any visitReferenceRelation(
const ::substrait::proto::ReferenceRel& relation);
virtual std::any visitWriteRelation(
const ::substrait::proto::WriteRel& relation);
virtual std::any visitDdlRelation(
const ::substrait::proto::DdlRel& relation);

virtual std::any visitHashJoinRelation(
const ::substrait::proto::HashJoinRel& relation);
virtual std::any visitMergeJoinRelation(
const ::substrait::proto::MergeJoinRel& relation);
virtual std::any visitNestedLoopJoinRelation(
const ::substrait::proto::NestedLoopJoinRel& relation);
virtual std::any visitWindowRelation(
const ::substrait::proto::ConsistentPartitionWindowRel& relation);
virtual std::any visitExchangeRelation(
Expand Down
21 changes: 21 additions & 0 deletions src/substrait/textplan/converter/PipelineVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ std::any PipelineVisitor::visitRelation(
relationData->newPipelines.push_back(rightSymbol);
break;
}
case ::substrait::proto::Rel::kNestedLoopJoin: {
const auto* leftSymbol = symbolTable_->lookupSymbolByLocationAndType(
PROTO_LOCATION(relation.nested_loop_join().left()),
SymbolType::kRelation);
const auto* rightSymbol = symbolTable_->lookupSymbolByLocationAndType(
PROTO_LOCATION(relation.nested_loop_join().right()),
SymbolType::kRelation);
relationData->newPipelines.push_back(leftSymbol);
relationData->newPipelines.push_back(rightSymbol);
break;
}
case ::substrait::proto::Rel::kWindow: {
const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType(
PROTO_LOCATION(relation.window().input()), SymbolType::kRelation);
Expand All @@ -183,6 +194,16 @@ std::any PipelineVisitor::visitRelation(
relationData->continuingPipeline = inputSymbol;
break;
}
case ::substrait::proto::Rel::kReference:
// TODO -- Add support for references in text plans.
break;
case ::substrait::proto::Rel::kWrite: {
const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType(
PROTO_LOCATION(relation.write().input()), SymbolType::kRelation);
relationData->continuingPipeline = inputSymbol;
break;
}
case ::substrait::proto::Rel::kDdl:
case ::substrait::proto::Rel::REL_TYPE_NOT_SET:
break;
}
Expand Down

0 comments on commit 7942c8e

Please sign in to comment.