diff --git a/src/substrait/textplan/Location.cpp b/src/substrait/textplan/Location.cpp index 7c2fc751..6b52228d 100644 --- a/src/substrait/textplan/Location.cpp +++ b/src/substrait/textplan/Location.cpp @@ -52,7 +52,7 @@ bool std::less<::io::substrait::textplan::Location>::operator()( const ::io::substrait::textplan::Location& rhs) const noexcept { if (std::holds_alternative(lhs.loc_)) { if (!std::holds_alternative(rhs.loc_)) { - // This alternative is always less than the remaining choices. + // This alternative is always less than the other location types. return true; } return std::get(lhs.loc_) < @@ -60,8 +60,8 @@ bool std::less<::io::substrait::textplan::Location>::operator()( } else if (std::holds_alternative( lhs.loc_)) { if (!std::holds_alternative(rhs.loc_)) { - // This alternative is always less than the remaining (zero) choices. - return true; + // This alternative is always more than the other location types. + return false; } return std::get(lhs.loc_) < std::get(rhs.loc_); diff --git a/src/substrait/textplan/Location.h b/src/substrait/textplan/Location.h index d9174652..e072eb82 100644 --- a/src/substrait/textplan/Location.h +++ b/src/substrait/textplan/Location.h @@ -31,6 +31,10 @@ class Location { protected: friend bool operator==(const Location& c1, const Location& c2); + friend bool operator!=(const Location& c1, const Location& c2) { + return !(c1 == c2); + } + private: friend std::hash; friend std::less; diff --git a/src/substrait/textplan/StructuredSymbolData.h b/src/substrait/textplan/StructuredSymbolData.h index 6532a17b..039db7dc 100644 --- a/src/substrait/textplan/StructuredSymbolData.h +++ b/src/substrait/textplan/StructuredSymbolData.h @@ -6,7 +6,6 @@ #include #include "substrait/proto/algebra.pb.h" -#include "substrait/textplan/Location.h" namespace io::substrait::textplan { diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index 5abbd890..e97b2d81 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "substrait/textplan/converter/LoadBinary.h" #include "substrait/textplan/converter/ParseBinary.h" @@ -9,6 +10,9 @@ namespace io::substrait::textplan { +using ::protobuf_matchers::EqualsProto; +using ::protobuf_matchers::Partially; +using ::testing::AllOf; using ::testing::Eq; namespace { @@ -158,7 +162,13 @@ std::vector getTestCases() { items = [ {uri_file: "/mock_lineitem.orc" start: 0 length: 3719 orc: {}} ] - })"))), + })")), + AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( + R"(relations { root { input { read { + local_files + { items { uri_file: "/mock_lineitem.orc" length: 3719 orc { } } } + } + } } })"))), }, { "read named table", diff --git a/src/substrait/textplan/converter/tests/CMakeLists.txt b/src/substrait/textplan/converter/tests/CMakeLists.txt index a88cd268..e120fe5f 100644 --- a/src/substrait/textplan/converter/tests/CMakeLists.txt +++ b/src/substrait/textplan/converter/tests/CMakeLists.txt @@ -8,6 +8,7 @@ add_test_case( substrait_textplan_converter substrait_common parse_result_matchers + protobuf-matchers gmock gtest gtest_main) diff --git a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp index 473f1cd5..7cdbe3d2 100644 --- a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp @@ -36,7 +36,7 @@ bool continuingPipelineContains( void SubstraitPlanPipelineVisitor::updateRelationSymbol( SubstraitPlanParser::PipelineContext* ctx, const std::string& relationName) { - const auto& symbol = symbolTable_->lookupSymbolByName(relationName); + const auto* symbol = symbolTable_->lookupSymbolByName(relationName); if (symbol == nullptr) { // This is a reference to a missing relation so we create a stub for it. auto relationData = std::make_shared(); @@ -78,7 +78,7 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline( } // Refetch our symbol table entry to make sure we have the latest version. - auto symbol = symbolTable_->lookupSymbolByName(relationName); + auto* symbol = symbolTable_->lookupSymbolByName(relationName); auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); // Check for accidental cross-pipeline use. diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index 37069261..794ba660 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -3,6 +3,7 @@ #include "substrait/textplan/parser/SubstraitPlanRelationVisitor.h" #include +#include #include #include #include @@ -14,8 +15,10 @@ #include "substrait/proto/algebra.pb.h" #include "substrait/proto/type.pb.h" #include "substrait/textplan/Any.h" +#include "substrait/textplan/Finally.h" #include "substrait/textplan/Location.h" #include "substrait/textplan/StructuredSymbolData.h" +#include "substrait/textplan/SymbolTable.h" #include "substrait/type/Type.h" namespace io::substrait::textplan { @@ -146,85 +149,99 @@ void setNullable(::substrait::proto::Type* type) { } } -} // namespace - -std::any SubstraitPlanRelationVisitor::aggregateResult( - std::any aggregate, - std::any nextResult) { - if (!nextResult.has_value()) { - // No point returning an unspecified result over whatever we already have. - return aggregate; - } - return nextResult; -} - -std::any SubstraitPlanRelationVisitor::visitRelation( - SubstraitPlanParser::RelationContext* ctx) { - // Create the relation before visiting our children, so they can update it. - auto symbol = symbolTable_->lookupSymbolByLocation(Location(ctx)); - if (symbol == SymbolInfo::kUnknown) { - // This error has been previously dealt with thus we can safely skip it. - return defaultResult(); - } - auto relationData = ANY_CAST(std::shared_ptr, symbol.blob); - ::substrait::proto::Rel relation; - - // Validate that we have the right details for our type. - auto relationType = ANY_CAST(RelationType, symbol.subtype); +void setRelationType( + RelationType relationType, + ::substrait::proto::Rel* relation) { switch (relationType) { case RelationType::kRead: - relation.mutable_read()->clear_common(); + relation->mutable_read()->clear_common(); break; case RelationType::kProject: - relation.mutable_project()->clear_common(); + relation->mutable_project()->clear_common(); break; case RelationType::kJoin: - relation.mutable_join()->clear_common(); + relation->mutable_join()->clear_common(); break; case RelationType::kCross: - relation.mutable_cross()->clear_common(); + relation->mutable_cross()->clear_common(); break; case RelationType::kFetch: - relation.mutable_fetch()->clear_common(); + relation->mutable_fetch()->clear_common(); break; case RelationType::kAggregate: - relation.mutable_aggregate()->clear_common(); + relation->mutable_aggregate()->clear_common(); break; case RelationType::kSort: - relation.mutable_sort()->clear_common(); + relation->mutable_sort()->clear_common(); break; case RelationType::kFilter: - relation.mutable_filter()->clear_common(); + relation->mutable_filter()->clear_common(); break; case RelationType::kSet: - relation.mutable_set()->clear_common(); + relation->mutable_set()->clear_common(); break; case RelationType::kExchange: case RelationType::kDdl: case RelationType::kWrite: break; case RelationType::kHashJoin: - relation.mutable_hash_join()->clear_common(); + relation->mutable_hash_join()->clear_common(); break; case RelationType::kMergeJoin: - relation.mutable_merge_join()->clear_common(); + relation->mutable_merge_join()->clear_common(); break; case RelationType::kExtensionLeaf: - relation.mutable_extension_leaf()->clear_common(); + relation->mutable_extension_leaf()->clear_common(); break; case RelationType::kExtensionSingle: - relation.mutable_extension_single()->clear_common(); + relation->mutable_extension_single()->clear_common(); break; case RelationType::kExtensionMulti: - relation.mutable_extension_multi()->clear_common(); + relation->mutable_extension_multi()->clear_common(); break; case RelationType::kUnknown: break; } +} + +} // namespace + +std::any SubstraitPlanRelationVisitor::aggregateResult( + std::any aggregate, + std::any nextResult) { + if (!nextResult.has_value()) { + // No point returning an unspecified result over whatever we already have. + return aggregate; + } + return nextResult; +} + +std::any SubstraitPlanRelationVisitor::visitRelation( + SubstraitPlanParser::RelationContext* ctx) { + // Create the relation before visiting our children, so they can update it. + auto symbol = symbolTable_->lookupSymbolByLocation(Location(ctx)); + if (symbol == SymbolInfo::kUnknown) { + // This error has been previously dealt with thus we can safely skip it. + return defaultResult(); + } + auto relationData = ANY_CAST(std::shared_ptr, symbol.blob); + ::substrait::proto::Rel relation; + + auto relationType = ANY_CAST(RelationType, symbol.subtype); + setRelationType(relationType, &relation); relationData->relation = relation; + symbolTable_->updateLocation(symbol, PROTO_LOCATION(relationData->relation)); - return visitChildren(ctx); + // Mark the current scope for any operations within this relation. + auto previousScope = currentRelationScope_; + auto resetCurrentScope = + finally([&]() { currentRelationScope_ = previousScope; }); + currentRelationScope_ = &symbol; + + visitChildren(ctx); + + return defaultResult(); } std::any SubstraitPlanRelationVisitor::visitRelation_filter_behavior( @@ -346,6 +363,44 @@ std::any SubstraitPlanRelationVisitor::visitRelationFilter( return defaultResult(); } +std::any SubstraitPlanRelationVisitor::visitRelationUsesSchema( + SubstraitPlanParser::RelationUsesSchemaContext* ctx) { + auto parentSymbol = symbolTable_->lookupSymbolByLocation( + Location(dynamic_cast(ctx->parent))); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol.blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol.subtype); + + if (parentRelationType == RelationType::kRead) { + auto schemaName = ctx->id()->getText(); + auto* symbol = symbolTable_->lookupSymbolByName(schemaName); + if (symbol != nullptr) { + auto* schema = + parentRelationData->relation.mutable_read()->mutable_base_schema(); + for (const auto& sym : *symbolTable_) { + if (sym.type != SymbolType::kSchemaColumn) { + continue; + } + if (sym.location != symbol->location) { + continue; + } + schema->add_names(sym.name); + auto typeText = ANY_CAST(std::string, sym.blob); + // TODO -- Use the location of the schema item for errors. + auto typeProto = textToTypeProto(ctx->getStart(), typeText); + if (typeProto.kind_case() != ::substrait::proto::Type::KIND_NOT_SET) { + *schema->mutable_struct_()->add_types() = typeProto; + } + } + } + } else { + errorListener_->addError( + ctx->getStart(), + "Schema references are not defined for this kind of relation."); + } + return defaultResult(); +} + std::any SubstraitPlanRelationVisitor::visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) { auto parentSymbol = symbolTable_->lookupSymbolByLocation( @@ -401,6 +456,38 @@ std::any SubstraitPlanRelationVisitor::visitExpression( return defaultResult(); } +std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( + SubstraitPlanParser::RelationSourceReferenceContext* ctx) { + auto parentSymbol = symbolTable_->lookupSymbolByLocation( + Location(dynamic_cast(ctx->parent))); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol.blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol.subtype); + + if (parentRelationType == RelationType::kRead) { + auto sourceName = ctx->source_reference()->id()->getText(); + auto* symbol = symbolTable_->lookupSymbolByName(sourceName); + if (symbol != nullptr) { + auto* source = + parentRelationData->relation.mutable_read()->mutable_named_table(); + for (const auto& sym : *symbolTable_) { + if (sym.type != SymbolType::kSourceDetail) { + continue; + } + if (sym.location != symbol->location) { + continue; + } + source->add_names(sym.name); + } + } + } else { + errorListener_->addError( + ctx->getStart(), + "Source references are not defined for this kind of relation."); + } + return defaultResult(); +} + std::any SubstraitPlanRelationVisitor::visitExpressionFunctionUse( SubstraitPlanParser::ExpressionFunctionUseContext* ctx) { ::substrait::proto::Expression expr; @@ -435,7 +522,26 @@ std::any SubstraitPlanRelationVisitor::visitExpressionConstant( std::any SubstraitPlanRelationVisitor::visitExpressionColumn( SubstraitPlanParser::ExpressionColumnContext* ctx) { + auto relationData = + ANY_CAST(std::shared_ptr, currentRelationScope_->blob); + + std::string symbolName = ctx->getText(); + auto currentFieldNumber = std::find_if( + relationData->fieldReferences.begin(), + relationData->fieldReferences.end(), + [&](auto ref) { return (ref->name == symbolName); }); + ::substrait::proto::Expression expr; + if (currentFieldNumber != relationData->fieldReferences.end()) { + int32_t fieldReference = static_cast( + (currentFieldNumber - relationData->fieldReferences.begin()) & + std::numeric_limits::max()); + expr.mutable_selection() + ->mutable_direct_reference() + ->mutable_struct_field() + ->set_field(fieldReference); + } + visitChildren(ctx); return expr; } @@ -495,26 +601,12 @@ std::any SubstraitPlanRelationVisitor::visitLiteral_specifier( std::any SubstraitPlanRelationVisitor::visitLiteral_basic_type( SubstraitPlanParser::Literal_basic_typeContext* ctx) { - std::shared_ptr decodedType; - try { - decodedType = Type::decode(ctx->getText()); - } catch (...) { - errorListener_->addError(ctx->getStart(), "Failed to decode type."); - return ::substrait::proto::Type{}; - } - return typeToProto(ctx->getStart(), *decodedType); + return textToTypeProto(ctx->getStart(), ctx->getText()); } std::any SubstraitPlanRelationVisitor::visitLiteral_complex_type( SubstraitPlanParser::Literal_complex_typeContext* ctx) { - std::shared_ptr decodedType; - try { - decodedType = Type::decode(ctx->getText()); - } catch (...) { - errorListener_->addError(ctx->getStart(), "Failed to decode type."); - return ::substrait::proto::Type{}; - } - return typeToProto(ctx->getStart(), *decodedType); + return textToTypeProto(ctx->getStart(), ctx->getText()); } std::any SubstraitPlanRelationVisitor::visitMap_literal( @@ -1151,6 +1243,19 @@ ::substrait::proto::Expression_Literal SubstraitPlanRelationVisitor::visitTime( return literal; } +::substrait::proto::Type SubstraitPlanRelationVisitor::textToTypeProto( + const antlr4::Token* token, + const std::string& typeText) { + std::shared_ptr decodedType; + try { + decodedType = Type::decode(typeText); + } catch (...) { + errorListener_->addError(token, "Failed to decode type."); + return ::substrait::proto::Type{}; + } + return typeToProto(token, *decodedType); +} + ::substrait::proto::Type SubstraitPlanRelationVisitor::typeToProto( const antlr4::Token* token, const ParameterizedType& decodedType) { diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h index b87d3df1..17c78d66 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h @@ -45,9 +45,15 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanParserBaseVisitor { std::any visitRelationFilter( SubstraitPlanParser::RelationFilterContext* ctx) override; + std::any visitRelationUsesSchema( + SubstraitPlanParser::RelationUsesSchemaContext* ctx) override; + std::any visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) override; + std::any visitRelationSourceReference( + SubstraitPlanParser::RelationSourceReferenceContext* ctx) override; + // visitExpression is a new method delegating to the methods below. std::any visitExpression(SubstraitPlanParser::ExpressionContext* ctx); @@ -146,12 +152,18 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanParserBaseVisitor { const antlr4::tree::TerminalNode* node, const std::string& str); + ::substrait::proto::Type textToTypeProto( + const antlr4::Token* token, + const std::string& typeText); + ::substrait::proto::Type typeToProto( const antlr4::Token* token, const ParameterizedType& decodedType); std::shared_ptr symbolTable_; std::shared_ptr errorListener_; + + const SymbolInfo* currentRelationScope_{nullptr}; // Not owned. }; } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp index 836f8bb7..30dfb3a6 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp @@ -6,6 +6,7 @@ #include "SubstraitPlanParser/SubstraitPlanParser.h" #include "substrait/textplan/Any.h" +#include "substrait/textplan/Finally.h" #include "substrait/textplan/Location.h" #include "substrait/textplan/StructuredSymbolData.h" @@ -152,15 +153,22 @@ std::any SubstraitPlanVisitor::visitRelation( errorListener_->addError( ctx->getStart(), "Relation named " + relationName + " already defined."); + } else { + auto relationData = std::make_shared(); + symbol = symbolTable_->defineSymbol( + relationName, + Location(ctx), + SymbolType::kRelation, + relType, + relationData); } - auto relationData = std::make_shared(); - symbolTable_->defineSymbol( - relationName, - Location(ctx), - SymbolType::kRelation, - relType, - relationData); + // Mark the current scope for any operations within this relation. + auto previousScope = currentRelationScope_; + auto resetCurrentScope = + finally([&]() { currentRelationScope_ = previousScope; }); + currentRelationScope_ = symbol; + visitRelation_ref(ctx->relation_ref()); for (const auto detail : ctx->relation_detail()) { visit(detail); @@ -267,6 +275,19 @@ std::any SubstraitPlanVisitor::visitExpressionConstant( std::any SubstraitPlanVisitor::visitExpressionColumn( SubstraitPlanParser::ExpressionColumnContext* ctx) { + auto relationData = + ANY_CAST(std::shared_ptr, currentRelationScope_->blob); + std::string column_name = ctx->column_name()->getText(); + auto symbol = symbolTable_->lookupSymbolByName(column_name); + if (symbol == nullptr) { + symbol = symbolTable_->defineSymbol( + column_name, + Location(ctx), + SymbolType::kField, + std::nullopt, + std::nullopt); + relationData->fieldReferences.push_back(symbol); + } return visitChildren(ctx); } @@ -295,11 +316,6 @@ std::any SubstraitPlanVisitor::visitRelationFilter( return visitChildren(ctx); } -std::any SubstraitPlanVisitor::visitRelationProjection( - SubstraitPlanParser::RelationProjectionContext* ctx) { - return visitChildren(ctx); -} - std::any SubstraitPlanVisitor::visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) { return visitChildren(ctx); @@ -335,10 +351,10 @@ std::any SubstraitPlanVisitor::visitLocal_files_detail( for (const auto& f : ctx->file()) { symbolTable_->defineSymbol( f->getText(), - Location(ctx), + PARSER_LOCATION(ctx->parent->parent), // The source we belong to. SymbolType::kSourceDetail, defaultResult(), - ctx->parent->parent); + defaultResult()); } return nullptr; } @@ -373,10 +389,10 @@ std::any SubstraitPlanVisitor::visitNamed_table_detail( std::string str = s->getText(); symbolTable_->defineSymbol( extractFromString(str), - Location(ctx), + PARSER_LOCATION(ctx->parent->parent), // The source we belong to. SymbolType::kSourceDetail, defaultResult(), - ctx->parent->parent); + defaultResult()); } return nullptr; } diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.h b/src/substrait/textplan/parser/SubstraitPlanVisitor.h index 488b5113..c973024d 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.h @@ -81,8 +81,6 @@ class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { SubstraitPlanParser::Relation_filter_behaviorContext* ctx) override; std::any visitRelationFilter( SubstraitPlanParser::RelationFilterContext* ctx) override; - std::any visitRelationProjection( - SubstraitPlanParser::RelationProjectionContext* ctx) override; std::any visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) override; std::any visitRelationAdvancedExtension( @@ -114,6 +112,8 @@ class SubstraitPlanVisitor : public SubstraitPlanParserVisitor { std::shared_ptr symbolTable_; std::shared_ptr errorListener_; + const SymbolInfo* currentRelationScope_{nullptr}; // Not owned. + int numFunctionsSeen_{0}; }; diff --git a/src/substrait/textplan/parser/data/provided_sample1.splan b/src/substrait/textplan/parser/data/provided_sample1.splan index a4705e82..80b78216 100644 --- a/src/substrait/textplan/parser/data/provided_sample1.splan +++ b/src/substrait/textplan/parser/data/provided_sample1.splan @@ -17,9 +17,9 @@ read relation read { } schema schema { - r_regionkey UNKNOWN; - r_name UNKNOWN; - r_comment UNKNOWN; + r_regionkey i32; + r_name string; + r_comment string; } source named_table named { diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 index 890883fe..851029ee 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 @@ -57,13 +57,12 @@ relation_filter_behavior ; relation_detail - : COMMON SEMICOLON # relationCommon - | BASE_SCHEMA id SEMICOLON # relationUsesSchema + : COMMON SEMICOLON # relationCommon + | BASE_SCHEMA id SEMICOLON # relationUsesSchema | relation_filter_behavior? FILTER expression SEMICOLON # relationFilter - | PROJECTION SEMICOLON # relationProjection - | EXPRESSION expression SEMICOLON # relationExpression - | ADVANCED_EXTENSION SEMICOLON # relationAdvancedExtension - | source_reference SEMICOLON # relationSourceReference + | EXPRESSION expression SEMICOLON (AS id)? # relationExpression + | ADVANCED_EXTENSION SEMICOLON # relationAdvancedExtension + | source_reference SEMICOLON # relationSourceReference ; expression diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp index ac4cc026..901753c4 100644 --- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp +++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp @@ -13,6 +13,7 @@ #include "substrait/textplan/tests/ParseResultMatchers.h" using ::protobuf_matchers::EqualsProto; +using ::protobuf_matchers::IgnoringFieldPaths; using ::protobuf_matchers::Partially; using ::testing::AllOf; @@ -217,19 +218,62 @@ std::vector getTestCases() { HasSymbolsWithTypes({"myproject"}, {SymbolType::kRelation}), WhenSerialized(EqSquashingWhitespace( R"(project relation myproject { - expression EXPR-NOT-YET-IMPLEMENTED; - expression EXPR-NOT-YET-IMPLEMENTED; - expression EXPR-NOT-YET-IMPLEMENTED; - expression add(EXPR-NOT-YET-IMPLEMENTED, 1_i8); - expression subtract(EXPR-NOT-YET-IMPLEMENTED, 1_i8); - expression concat(EXPR-NOT-YET-IMPLEMENTED, EXPR-NOT-YET-IMPLEMENTED); + expression r_regionkey; + expression r_name; + expression r_comment; + expression add(r_regionkey, 1_i8); + expression subtract(r_regionkey, 1_i8); + expression concat(r_name, r_name); } extension_space blah.yaml { function add:i8 as add; function subtract:i8 as subtract; function concat:str as concat; - })"))), + })")), + AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( + R"(relations { root { input { project { + expressions { + selection { + direct_reference { + struct_field { + field: 0 + } + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 1 + } + } + } + } + expressions { + selection { + direct_reference { + struct_field { + field: 2 + } + } + } + } + expressions { scalar_function { + function_reference: 1 arguments { value { selection { + direct_reference { struct_field { } } } } } + arguments { value { literal { i8: 1 } } } } } + expressions { scalar_function { + function_reference: 2 arguments { value { selection { + direct_reference { struct_field { } } } } } + arguments { value { literal { i8: 1 } } } } } + expressions { scalar_function { + function_reference: 3 arguments { value { selection { + direct_reference { struct_field { field: 1 } } } } } + arguments { value { selection { direct_reference { + struct_field { field: 1 } } } } } } } + } } } })"))), }, { "test6-read-relation", @@ -703,8 +747,6 @@ std::vector getTestCases() { R"(project relation literalexamples { expression 123_i8 AS i32; expression 123_i8 AS i32 AS i64; - // TODO -- Add casts of non-constant types when supported. - // expression r_address AS fixedchar<23>; })", AllOf( HasErrors({}), @@ -773,6 +815,7 @@ std::vector getTestCases() { type UNSPECIFIED; expression order_id; } + schema schema { order_id i32; product_id i32; @@ -806,8 +849,12 @@ std::vector getTestCases() { "#3", ] })", - AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( - R"(relations: { + AsBinaryPlan(IgnoringFieldPaths( + {"relations[0].root.input.join.left.join.expression.selection.direct_reference.struct_field.field", + "relations[0].root.input.join.left.join.post_join_filter.selection.direct_reference.struct_field.field", + "relations[0].root.input.join.expression.selection.direct_reference.struct_field.field"}, + EqualsProto<::substrait::proto::Plan>( + R"(relations: { root: { input: { join: { @@ -815,28 +862,76 @@ std::vector getTestCases() { join: { left: { read: { + base_schema { + names: "order_id" + names: "product_id" + names: "count" + struct { + types { i32 { } } + types { i32 { } } + types { i64 { } } } + } + named_table { names: "#1" } } } right: { read: { + base_schema { + names: "product_id" + names: "cost" + struct { + types { i32 { } } + types { fp32 { } } } + } + named_table { names: "#2" } } } expression: { + selection: { + direct_reference: { + struct_field: { + field: 1 + } + } + } } post_join_filter: { + selection: { + direct_reference: { + struct_field: { + field: 2 + } + } + } } } } right: { read: { + base_schema { + names: "company" + names: "order_id" + struct { + types { string { } } + types { i32 { } } + } } - } - expression: { + named_table { names: "#3" } } } + expression: { + selection: { + direct_reference: { + struct_field: { + field: 6 + } + } + } + } + } } } - })")), + })"))), }, { "test15-relation-without-type",