From 0a71a2f22816a1d7cd5db96c8b76aac669e4a612 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 19 Jul 2023 12:22:32 -0700 Subject: [PATCH] feat: add output field mapping (#78) features: * updated the symbol table so that multiple symbols can share the same location * the root relation now contains both a relation symbol and a name structure symbol * moved commonly string search functions into a separate file * added EMIT * added aliases * added join types fixes: * fixed root names sort order --- src/substrait/textplan/CMakeLists.txt | 2 + src/substrait/textplan/StringManipulation.cpp | 19 + src/substrait/textplan/StringManipulation.h | 15 + src/substrait/textplan/StructuredSymbolData.h | 17 + src/substrait/textplan/SymbolTable.cpp | 101 ++- src/substrait/textplan/SymbolTable.h | 25 +- src/substrait/textplan/SymbolTablePrinter.cpp | 14 +- .../converter/InitialPlanProtoVisitor.cpp | 249 +++++++- .../converter/InitialPlanProtoVisitor.h | 11 +- .../textplan/converter/PipelineVisitor.cpp | 93 ++- .../textplan/converter/PlanPrinterVisitor.cpp | 159 ++++- .../textplan/converter/PlanPrinterVisitor.h | 8 +- .../converter/ReferenceNormalizer.cpp | 2 +- .../data/q6_first_stage.golden.splan | 73 +++ .../converter/data/q6_first_stage.json | 6 +- .../tests/BinaryToTextPlanConversionTest.cpp | 20 +- .../textplan/converter/tests/CMakeLists.txt | 6 +- src/substrait/textplan/data/tpch-plan01.json | 2 +- src/substrait/textplan/data/tpch-plan13.json | 2 +- src/substrait/textplan/parser/ParseText.h | 2 + .../parser/SubstraitPlanPipelineVisitor.cpp | 23 +- .../parser/SubstraitPlanRelationVisitor.cpp | 581 ++++++++++++++++-- .../parser/SubstraitPlanRelationVisitor.h | 29 + .../textplan/parser/SubstraitPlanVisitor.cpp | 178 ++++-- .../parser/data/provided_sample1.splan | 10 +- .../parser/grammar/SubstraitPlanLexer.g4 | 2 + .../parser/grammar/SubstraitPlanParser.g4 | 35 +- .../parser/tests/TextPlanParserTest.cpp | 202 ++++-- .../textplan/tests/RoundtripTest.cpp | 19 +- 29 files changed, 1596 insertions(+), 309 deletions(-) create mode 100644 src/substrait/textplan/StringManipulation.cpp create mode 100644 src/substrait/textplan/StringManipulation.h create mode 100644 src/substrait/textplan/converter/data/q6_first_stage.golden.splan diff --git a/src/substrait/textplan/CMakeLists.txt b/src/substrait/textplan/CMakeLists.txt index 4db8adff..1e8b7dbb 100644 --- a/src/substrait/textplan/CMakeLists.txt +++ b/src/substrait/textplan/CMakeLists.txt @@ -7,6 +7,8 @@ add_library( symbol_table Location.cpp Location.h + StringManipulation.cpp + StringManipulation.h SymbolTable.cpp SymbolTable.h SymbolTablePrinter.cpp diff --git a/src/substrait/textplan/StringManipulation.cpp b/src/substrait/textplan/StringManipulation.cpp new file mode 100644 index 00000000..eac3c56a --- /dev/null +++ b/src/substrait/textplan/StringManipulation.cpp @@ -0,0 +1,19 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "StringManipulation.h" + +namespace io::substrait::textplan { + +// Yields true if the string 'haystack' starts with the string 'needle'. +bool startsWith(std::string_view haystack, std::string_view needle) { + return haystack.size() > needle.size() && + haystack.substr(0, needle.size()) == needle; +} + +// Returns true if the string 'haystack' ends with the string 'needle'. +bool endsWith(std::string_view haystack, std::string_view needle) { + return haystack.size() > needle.size() && + haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; +} + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/StringManipulation.h b/src/substrait/textplan/StringManipulation.h new file mode 100644 index 00000000..9c24418f --- /dev/null +++ b/src/substrait/textplan/StringManipulation.h @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#pragma once + +#include + +namespace io::substrait::textplan { + +// Yields true if the string 'haystack' starts with the string 'needle'. +bool startsWith(std::string_view haystack, std::string_view needle); + +// Returns true if the string 'haystack' ends with the string 'needle'. +bool endsWith(std::string_view haystack, std::string_view needle); + +} // namespace io::substrait::textplan diff --git a/src/substrait/textplan/StructuredSymbolData.h b/src/substrait/textplan/StructuredSymbolData.h index 039db7dc..676fa211 100644 --- a/src/substrait/textplan/StructuredSymbolData.h +++ b/src/substrait/textplan/StructuredSymbolData.h @@ -37,6 +37,23 @@ struct RelationData { // Column name for each field known to this relation (in field order). Used // to determine what fields are coming in as well and fields are going out. std::vector fieldReferences; + + // Each field reference here was generated within the current relation. + std::vector generatedFieldReferences; + + // Local aliases for field references in this relation. Used to replace the + // normal form symbols would take for this relation's use only. (Later + // references to the symbol would use the alias.) + std::map generatedFieldReferenceAlternativeExpression; + + // If populated, supersedes the combination of fieldReferences and + // generatedFieldReferences for the field symbols exposed by this relation. + std::vector outputFieldReferences; + + // Contains the field reference names seen so far while processing this + // relation along with the id of the first occurrence. Used to detect when + // fully qualified references are necessary. + std::map seenFieldReferenceNames; }; // Used by Schema symbols to keep track of assigned values. diff --git a/src/substrait/textplan/SymbolTable.cpp b/src/substrait/textplan/SymbolTable.cpp index bdab07ff..fdfef0bb 100644 --- a/src/substrait/textplan/SymbolTable.cpp +++ b/src/substrait/textplan/SymbolTable.cpp @@ -8,7 +8,9 @@ #include #include "substrait/common/Exceptions.h" +#include "substrait/textplan/Any.h" #include "substrait/textplan/Location.h" +#include "substrait/textplan/StructuredSymbolData.h" namespace io::substrait::textplan { @@ -123,6 +125,12 @@ void SymbolTable::updateLocation( symbolsByLocation_.insert(std::make_pair(location, index)); } +void SymbolTable::addAlias(const std::string& alias, const SymbolInfo* symbol) { + auto index = findSymbolIndex(*symbol); + symbols_[index]->alias = alias; + symbolsByName_.insert(std::make_pair(alias, index)); +} + const SymbolInfo* SymbolTable::lookupSymbolByName( const std::string& name) const { auto itr = symbolsByName_.find(name); @@ -132,13 +140,33 @@ const SymbolInfo* SymbolTable::lookupSymbolByName( return symbols_[itr->second].get(); } -const SymbolInfo* SymbolTable::lookupSymbolByLocation( +std::vector SymbolTable::lookupSymbolsByLocation( const Location& location) const { - auto itr = symbolsByLocation_.find(location); - if (itr == symbolsByLocation_.end()) { - return nullptr; + std::vector symbols; + auto [begin, end] = symbolsByLocation_.equal_range(location); + for (auto itr = begin; itr != end; ++itr) { + symbols.push_back(symbols_[itr->second].get()); } - return symbols_[itr->second].get(); + return symbols; +} + +const SymbolInfo* SymbolTable::lookupSymbolByLocationAndType( + const Location& location, + SymbolType type) const { + return lookupSymbolByLocationAndTypes(location, {type}); +} + +const SymbolInfo* SymbolTable::lookupSymbolByLocationAndTypes( + const Location& location, + std::unordered_set types) const { + auto [begin, end] = symbolsByLocation_.equal_range(location); + for (auto itr = begin; itr != end; ++itr) { + auto symbol = symbols_[itr->second].get(); + if (types.find(symbol->type) != types.end()) { + return symbol; + } + } + return nullptr; } const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type) @@ -162,4 +190,67 @@ SymbolTableIterator SymbolTable::end() const { return {this, symbols_.size()}; } +std::string SymbolTable::toDebugString() const { + std::stringstream result; + bool textAlreadyWritten = false; + int32_t relationCount = 0; + for (const auto& symbol : symbols_) { + if (symbol->type != SymbolType::kRelation) { + continue; + } + auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); + result << std::left << std::setw(4) << relationCount++; + result << std::left << std::setw(20) << symbol->name << std::endl; + + int32_t fieldNum = 0; + for (const auto& field : relationData->fieldReferences) { + result << " " << std::setw(4) << fieldNum++ << " "; + if (field->schema != nullptr) { + result << field->schema->name << "."; + } + result << field->name; + if (!field->alias.empty()) { + result << " " << field->alias; + } + result << std::endl; + } + + for (const auto& field : relationData->generatedFieldReferences) { + result << " g" << std::setw(4) << fieldNum++ << " "; + if (field->schema != nullptr) { + result << field->schema->name << "."; + } + result << field->name; + if (relationData->generatedFieldReferenceAlternativeExpression.find( + fieldNum) != + relationData->generatedFieldReferenceAlternativeExpression.end()) { + result << " " + << relationData + ->generatedFieldReferenceAlternativeExpression[fieldNum]; + } else if (!field->alias.empty()) { + result << " " << field->alias; + } + result << std::endl; + } + + int32_t outputFieldNum = 0; + for (const auto& field : relationData->outputFieldReferences) { + result << " o" << std::setw(4) << outputFieldNum++ << " "; + if (field->schema != nullptr) { + result << field->schema->name << "."; + } + result << field->name; + if (!field->alias.empty()) { + result << " " << field->alias; + } + result << std::endl; + } + textAlreadyWritten = true; + } + if (textAlreadyWritten) { + result << std::endl; + } + return result.str(); +} + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTable.h b/src/substrait/textplan/SymbolTable.h index a0b4ad58..97e54879 100644 --- a/src/substrait/textplan/SymbolTable.h +++ b/src/substrait/textplan/SymbolTable.h @@ -3,10 +3,11 @@ #pragma once #include +#include #include -#include #include #include +#include #include #include @@ -19,13 +20,14 @@ enum class SymbolType { kFunction = 1, kPlanRelation = 2, kRelation = 3, - kRelationDetail = 4, kSchema = 5, kSchemaColumn = 6, kSource = 7, kSourceDetail = 8, kField = 9, kRoot = 10, + kTable = 11, + kMeasure = 12, kUnknown = -1, }; @@ -75,6 +77,8 @@ const std::string& symbolTypeName(SymbolType type); struct SymbolInfo { std::string name; + std::string alias{}; // If present, use this instead of name. + const SymbolInfo* schema{nullptr}; // The related schema symbol if present. Location location; SymbolType type; std::any subtype; @@ -144,12 +148,23 @@ class SymbolTable { // Changes the location for a specified existing symbol. void updateLocation(const SymbolInfo& symbol, const Location& location); + // Adds an alias to the given symbol. + void addAlias(const std::string& alias, const SymbolInfo* symbol); + [[nodiscard]] const SymbolInfo* lookupSymbolByName( const std::string& name) const; - [[nodiscard]] const SymbolInfo* lookupSymbolByLocation( + [[nodiscard]] std::vector lookupSymbolsByLocation( const Location& location) const; + [[nodiscard]] const SymbolInfo* lookupSymbolByLocationAndType( + const Location& location, + SymbolType type) const; + + [[nodiscard]] const SymbolInfo* lookupSymbolByLocationAndTypes( + const Location& location, + std::unordered_set types) const; + [[nodiscard]] const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type) const; @@ -177,6 +192,8 @@ class SymbolTable { return os; } + [[nodiscard]] std::string toDebugString() const; + private: // Returns the table size if the symbol is not found. size_t findSymbolIndex(const SymbolInfo& symbol); @@ -187,7 +204,7 @@ class SymbolTable { std::vector> symbols_; std::unordered_map symbolsByName_; - std::unordered_map symbolsByLocation_; + std::multimap symbolsByLocation_; }; } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/SymbolTablePrinter.cpp b/src/substrait/textplan/SymbolTablePrinter.cpp index bfa68709..19f96188 100644 --- a/src/substrait/textplan/SymbolTablePrinter.cpp +++ b/src/substrait/textplan/SymbolTablePrinter.cpp @@ -203,9 +203,7 @@ std::string outputSchemaSection(const SymbolTable& symbolTable) { if (info.type != SymbolType::kSchema) { continue; } - - if (info.blob.type() != typeid(const ::substrait::proto::NamedStruct*)) { - // TODO -- Implement schemas for text plans. + if (!info.blob.has_value()) { continue; } @@ -241,10 +239,6 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) { if (hasPreviousText) { text << "\n"; } - if (info.subtype.type() != typeid(SourceType)) { - // TODO -- Implement sources for text plans. - continue; - } auto subtype = ANY_CAST(SourceType, info.subtype); switch (subtype) { case SourceType::kNamedTable: { @@ -300,6 +294,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { } return spaceNames.at(a) < spaceNames.at(b); }; + // Sorted by name if we have one, otherwise by space id. std::set usedSpaces(cmp); // Look at the existing spaces. @@ -352,8 +347,8 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) { functionsToOutput.emplace_back(info.name, functionData->name); } std::sort(functionsToOutput.begin(), functionsToOutput.end()); - for (const auto& item : functionsToOutput) { - text << " function " << item.second << " as " << item.first << ";\n"; + for (auto [shortName, canonicalName] : functionsToOutput) { + text << " function " << canonicalName << " as " << shortName << ";\n"; } text << "}\n"; hasPreviousOutput = true; @@ -446,6 +441,7 @@ void outputFunctionsToBinaryPlan( } // namespace +// TODO -- Update so that errors occurring during printing are captured. std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) { std::stringstream text; bool hasPreviousText = false; diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp index fc0f493b..dfbdf38f 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp @@ -19,6 +19,7 @@ namespace io::substrait::textplan { namespace { +const std::string kIntermediateNodeName{"intermediate"}; const std::string kRootNames{"root.names"}; std::string shortName(std::string str) { @@ -80,6 +81,45 @@ void eraseInputs(::substrait::proto::Rel* relation) { } } +::google::protobuf::RepeatedField getOutputMapping( + const ::substrait::proto::Rel& relation) { + switch (relation.rel_type_case()) { + case ::substrait::proto::Rel::kRead: + return relation.read().common().emit().output_mapping(); + case ::substrait::proto::Rel::kFilter: + return relation.filter().common().emit().output_mapping(); + case ::substrait::proto::Rel::kFetch: + return relation.fetch().common().emit().output_mapping(); + case ::substrait::proto::Rel::kAggregate: + return relation.aggregate().common().emit().output_mapping(); + case ::substrait::proto::Rel::kSort: + return relation.sort().common().emit().output_mapping(); + case ::substrait::proto::Rel::kJoin: + return relation.join().common().emit().output_mapping(); + case ::substrait::proto::Rel::kProject: + return relation.project().common().emit().output_mapping(); + case ::substrait::proto::Rel::kSet: + return relation.set().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExtensionSingle: + return relation.extension_single().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExtensionMulti: + return relation.extension_multi().common().emit().output_mapping(); + case ::substrait::proto::Rel::kExtensionLeaf: + return relation.extension_leaf().common().emit().output_mapping(); + case ::substrait::proto::Rel::kCross: + return relation.cross().common().emit().output_mapping(); + case ::substrait::proto::Rel::kHashJoin: + return relation.hash_join().common().emit().output_mapping(); + case ::substrait::proto::Rel::kMergeJoin: + return relation.merge_join().common().emit().output_mapping(); + case ::substrait::proto::Rel::REL_TYPE_NOT_SET: + break; + } + + // The compiler will prevent us from reaching here. + return {}; +} + } // namespace std::any InitialPlanProtoVisitor::visitExtension( @@ -146,17 +186,24 @@ std::any InitialPlanProtoVisitor::visitRelation( BasePlanProtoVisitor::visitRelation(relation); - auto uniqueName = symbolTable_->getUniqueName(name); + // Create a reduced copy of the relation for use in the symbol table. auto relationData = std::make_shared(); relationData->relation = relation; eraseInputs(&relationData->relation); - updateLocalSchema(relationData, relation); - if (readRelationSources_.find(&relation) != readRelationSources_.end()) { - relationData->source = readRelationSources_[&relation]; + + // Update the relation data for long term use. + updateLocalSchema(relationData, relation, relationData->relation); + if (readRelationSources_.find(currentRelationScope_) != + readRelationSources_.end()) { + relationData->source = readRelationSources_[currentRelationScope_]; } - if (readRelationSchemas_.find(&relation) != readRelationSchemas_.end()) { - relationData->schema = readRelationSchemas_[&relation]; + if (readRelationSchemas_.find(currentRelationScope_) != + readRelationSchemas_.end()) { + relationData->schema = readRelationSchemas_[currentRelationScope_]; } + + // Finally create our entry in the symbol table. + auto uniqueName = symbolTable_->getUniqueName(name); auto symbol = symbolTable_->defineSymbol( uniqueName, PROTO_LOCATION(relation), @@ -240,9 +287,9 @@ std::any InitialPlanProtoVisitor::visitNamedTable( symbolTable_->defineSymbol( name, Location::kUnknownLocation, - SymbolType::kField, + SymbolType::kTable, SourceType::kUnknown, - &table); // Field names are in this scope. + std::nullopt); } return BasePlanProtoVisitor::visitNamedTable(table); } @@ -262,31 +309,58 @@ std::any InitialPlanProtoVisitor::visitExtensionTable( std::any InitialPlanProtoVisitor::visitNamedStruct( const ::substrait::proto::NamedStruct& named) { + int nameNum = 0; for (const auto& name : named.names()) { if (symbolTable_->lookupSymbolByName(name) != nullptr) { continue; } + ::substrait::proto::Type type; + if (nameNum < named.struct_().types().size()) { + type = named.struct_().types(nameNum); + } else { + errorListener_->addError( + "Type number " + std::to_string(nameNum) + + " requested but there are not that many types in the proto."); + return std::nullopt; + } symbolTable_->defineSymbol( name, PROTO_LOCATION(named), SymbolType::kField, SourceType::kUnknown, - &named); // Field names are in this scope. + type); + nameNum++; } return BasePlanProtoVisitor::visitNamedStruct(named); } +void InitialPlanProtoVisitor::addFieldToRelation( + const std::shared_ptr& relationData, + const SymbolInfo* field) { + relationData->fieldReferences.push_back(field); +} + void InitialPlanProtoVisitor::addFieldsToRelation( const std::shared_ptr& relationData, const ::substrait::proto::Rel& relation) { - auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); + auto* symbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation), SymbolType::kRelation); if (symbol == nullptr || symbol->type != SymbolType::kRelation) { return; } auto symbolRelationData = ANY_CAST(std::shared_ptr, symbol->blob); - for (const auto& field : symbolRelationData->fieldReferences) { - relationData->fieldReferences.push_back(field); + if (!symbolRelationData->outputFieldReferences.empty()) { + for (const auto& field : symbolRelationData->outputFieldReferences) { + addFieldToRelation(relationData, field); + } + } else { + for (const auto& field : symbolRelationData->fieldReferences) { + addFieldToRelation(relationData, field); + } + for (const auto& field : symbolRelationData->generatedFieldReferences) { + addFieldToRelation(relationData, field); + } } } @@ -298,22 +372,29 @@ void InitialPlanProtoVisitor::addFieldsToRelation( addFieldsToRelation(relationData, right); } +std::string getSchemaName(const SymbolInfo* field) { + if (field->schema != nullptr) { + return field->schema->name; + } + return ""; +} + void InitialPlanProtoVisitor::updateLocalSchema( const std::shared_ptr& relationData, - const ::substrait::proto::Rel& relation) { + const ::substrait::proto::Rel& relation, + const ::substrait::proto::Rel& internalRelation) { switch (relation.rel_type_case()) { case ::substrait::proto::Rel::RelTypeCase::kRead: if (relation.read().has_base_schema()) { for (const auto& name : relation.read().base_schema().names()) { - auto* symbol = symbolTable_->lookupSymbolByName(name); - if (symbol == nullptr) { - symbol = symbolTable_->defineSymbol( - name, - PROTO_LOCATION(relation.read().base_schema()), - SymbolType::kField, - std::nullopt, - std::nullopt); - } + auto symbol = symbolTable_->defineSymbol( + name, + PROTO_LOCATION(relation.read().base_schema()), + SymbolType::kField, + SourceType::kUnknown, + std::nullopt); + + symbol->schema = readRelationSchemas_[currentRelationScope_]; relationData->fieldReferences.emplace_back(symbol); } } @@ -326,6 +407,26 @@ void InitialPlanProtoVisitor::updateLocalSchema( break; case ::substrait::proto::Rel::RelTypeCase::kAggregate: addFieldsToRelation(relationData, relation.aggregate().input()); + for (const auto& grouping : relation.aggregate().groupings()) { + addGroupingToRelation(relationData, grouping); + } + for (const auto& measure : internalRelation.aggregate().measures()) { + auto uniqueName = symbolTable_->getUniqueName("measurename"); + auto symbol = symbolTable_->defineSymbol( + uniqueName, + PROTO_LOCATION(measure), + SymbolType::kMeasure, + SourceType::kUnknown, + std::nullopt); + symbol->location = PROTO_LOCATION(measure); + relationData->generatedFieldReferences.emplace_back(symbol); + } + // Aggregate relations are different in that they alter the emitted fields + // by default. + relationData->outputFieldReferences.insert( + relationData->outputFieldReferences.end(), + relationData->generatedFieldReferences.begin(), + relationData->generatedFieldReferences.end()); break; case ::substrait::proto::Rel::RelTypeCase::kSort: addFieldsToRelation(relationData, relation.sort().input()); @@ -334,9 +435,64 @@ void InitialPlanProtoVisitor::updateLocalSchema( addFieldsToRelation( relationData, relation.join().left(), relation.join().right()); break; - case ::substrait::proto::Rel::RelTypeCase::kProject: + case ::substrait::proto::Rel::RelTypeCase::kProject: { addFieldsToRelation(relationData, relation.project().input()); + for (const auto& expr : relation.project().expressions()) { + // TODO -- Add support for other kinds of direct references. + if (expr.selection().has_direct_reference() && + expr.selection().direct_reference().has_struct_field()) { + auto mapping = + expr.selection().direct_reference().struct_field().field(); + if (mapping < relationData->fieldReferences.size()) { + auto field = relationData->fieldReferences[mapping]; + relationData->generatedFieldReferences.push_back(field); + auto prevInstance = + relationData->seenFieldReferenceNames.find(field->name); + if (field->alias.empty() && + prevInstance != relationData->seenFieldReferenceNames.end()) { + // Add a version with the schema supplied. + auto schemaName = getSchemaName(field); + if (!schemaName.empty()) { + relationData->generatedFieldReferenceAlternativeExpression + [relationData->generatedFieldReferences.size() - 1] = + schemaName + "." + field->name; + } + // Now update the first occurrence if it hasn't already. + auto schemaNamePrev = getSchemaName( + relationData->generatedFieldReferences[prevInstance->second]); + if (!schemaNamePrev.empty()) { + relationData->generatedFieldReferenceAlternativeExpression + [prevInstance->second] = schemaNamePrev + "." + field->name; + } + } + if (field->alias.empty()) { + relationData->seenFieldReferenceNames.insert(std::make_pair( + field->name, + relationData->generatedFieldReferences.size() - 1)); + } + } else { + ::substrait::proto::Rel sanitizedRelation = *currentRelationScope_; + sanitizedRelation.mutable_project()->clear_input(); + errorListener_->addError( + "Asked to project a field that isn't available - " + + std::to_string(mapping) + " by relation " + + sanitizedRelation.ShortDebugString()); + } + } else { + const auto& uniqueName = + symbolTable_->getUniqueName(kIntermediateNodeName); + auto newSymbol = symbolTable_->defineSymbol( + uniqueName, + PROTO_LOCATION(relation.project()), + SymbolType::kUnknown, + std::nullopt, + std::nullopt); + relationData->generatedFieldReferences.push_back(newSymbol); + symbolTable_->addAlias(uniqueName, newSymbol); + } + } break; + } case ::substrait::proto::Rel::RelTypeCase::kSet: addFieldsToRelation(relationData, relation.set().inputs()); break; @@ -368,8 +524,51 @@ void InitialPlanProtoVisitor::updateLocalSchema( case ::substrait::proto::Rel::REL_TYPE_NOT_SET: break; } - // TODO -- Utilize the data in relation.common().emit() to alter the order - // of the fields that leave this relation. + + // Revamp the output based on the output mapping if present. + auto mapping = getOutputMapping(relation); + if (!mapping.empty()) { + if (!relationData->outputFieldReferences.empty()) { + errorListener_->addError( + "Aggregate relations do not yet support output mapping changes."); + return; + } + for (auto item : mapping) { + auto fieldReferenceSize = relationData->fieldReferences.size(); + if (item < fieldReferenceSize) { + relationData->outputFieldReferences.push_back( + relationData->fieldReferences[item]); + } else if ( + item < + fieldReferenceSize + relationData->generatedFieldReferences.size()) { + relationData->outputFieldReferences.push_back( + relationData->generatedFieldReferences[item - fieldReferenceSize]); + } else { + errorListener_->addError( + "Asked to emit field " + std::to_string(item) + + " which isn't available."); + } + } + } +} + +void InitialPlanProtoVisitor::addGroupingToRelation( + const std::shared_ptr& relationData, + const ::substrait::proto::AggregateRel_Grouping& grouping) { + for (const auto& expr : grouping.grouping_expressions()) { + // TODO -- Add support for groupings made up of complicated expressions. + if (expr.has_selection()) { + auto mapping = expr.selection().direct_reference().struct_field().field(); + // TODO -- Figure out if we need to not add fields we've already seen. + if (mapping >= relationData->fieldReferences.size()) { + errorListener_->addError( + "Grouping attempted to use a field reference not in the input field mapping."); + continue; + } + relationData->generatedFieldReferences.push_back( + relationData->fieldReferences[mapping]); + } + } } } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/converter/InitialPlanProtoVisitor.h b/src/substrait/textplan/converter/InitialPlanProtoVisitor.h index 9fa25dcc..520e8a63 100644 --- a/src/substrait/textplan/converter/InitialPlanProtoVisitor.h +++ b/src/substrait/textplan/converter/InitialPlanProtoVisitor.h @@ -59,7 +59,12 @@ class InitialPlanProtoVisitor : public BasePlanProtoVisitor { void updateLocalSchema( const std::shared_ptr& relationData, - const ::substrait::proto::Rel& relation); + const ::substrait::proto::Rel& relation, + const ::substrait::proto::Rel& internalRelation); + + static void addFieldToRelation( + const std::shared_ptr& relationData, + const SymbolInfo* field); void addFieldsToRelation( const std::shared_ptr& relationData, @@ -79,6 +84,10 @@ class InitialPlanProtoVisitor : public BasePlanProtoVisitor { } }; + void addGroupingToRelation( + const std::shared_ptr& relationData, + const ::substrait::proto::AggregateRel_Grouping& grouping); + std::shared_ptr symbolTable_; std::shared_ptr errorListener_; diff --git a/src/substrait/textplan/converter/PipelineVisitor.cpp b/src/substrait/textplan/converter/PipelineVisitor.cpp index 69909f06..d63a6f56 100644 --- a/src/substrait/textplan/converter/PipelineVisitor.cpp +++ b/src/substrait/textplan/converter/PipelineVisitor.cpp @@ -4,81 +4,76 @@ #include "substrait/textplan/Any.h" #include "substrait/textplan/StructuredSymbolData.h" +#include "substrait/textplan/SymbolTable.h" namespace io::substrait::textplan { -std::shared_ptr PipelineVisitor::getRelationData( - const google::protobuf::Message& relation) { - auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); - if (symbol == nullptr) { - return nullptr; - } - return ANY_CAST(std::shared_ptr, symbol->blob); -} - std::any PipelineVisitor::visitRelation( const ::substrait::proto::Rel& relation) { - auto relationData = getRelationData(relation); + auto symbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation), SymbolType::kRelation); + auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); switch (relation.rel_type_case()) { case ::substrait::proto::Rel::RelTypeCase::kRead: // No relations beyond this one. break; case ::substrait::proto::Rel::RelTypeCase::kFilter: { - const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.filter().input())); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.filter().input()), SymbolType::kRelation); relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kFetch: { - const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.fetch().input())); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.fetch().input()), SymbolType::kRelation); relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kAggregate: { - const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.aggregate().input())); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.aggregate().input()), SymbolType::kRelation); relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kSort: { - const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.sort().input())); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.sort().input()), SymbolType::kRelation); relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kJoin: { - const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.join().left())); - const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.join().right())); + const auto* leftSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.join().left()), SymbolType::kRelation); + const auto* rightSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.join().right()), SymbolType::kRelation); relationData->newPipelines.push_back(leftSymbol); relationData->newPipelines.push_back(rightSymbol); break; } case ::substrait::proto::Rel::RelTypeCase::kProject: { - const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.project().input())); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.project().input()), SymbolType::kRelation); relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kSet: for (const auto& rel : relation.set().inputs()) { - const auto* inputSymbol = - symbolTable_->lookupSymbolByLocation(Location(&rel)); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(rel), SymbolType::kRelation); relationData->newPipelines.push_back(inputSymbol); } break; case ::substrait::proto::Rel::RelTypeCase::kExtensionSingle: { - const auto* inputSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.extension_single().input())); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.extension_single().input()), + SymbolType::kRelation); relationData->continuingPipeline = inputSymbol; break; } case ::substrait::proto::Rel::RelTypeCase::kExtensionMulti: for (const auto& rel : relation.extension_multi().inputs()) { - const auto* inputSymbol = - symbolTable_->lookupSymbolByLocation(Location(&rel)); + const auto* inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(rel), SymbolType::kRelation); relationData->newPipelines.push_back(inputSymbol); } break; @@ -86,28 +81,28 @@ std::any PipelineVisitor::visitRelation( // No children. break; case ::substrait::proto::Rel::RelTypeCase::kCross: { - const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.cross().left())); - const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.cross().right())); + const auto* leftSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.cross().left()), SymbolType::kRelation); + const auto* rightSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.cross().right()), SymbolType::kRelation); relationData->newPipelines.push_back(leftSymbol); relationData->newPipelines.push_back(rightSymbol); break; } case ::substrait::proto::Rel::RelTypeCase::kHashJoin: { - const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.hash_join().left())); - const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.hash_join().right())); + const auto* leftSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.hash_join().left()), SymbolType::kRelation); + const auto* rightSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.hash_join().right()), SymbolType::kRelation); relationData->newPipelines.push_back(leftSymbol); relationData->newPipelines.push_back(rightSymbol); break; } case ::substrait::proto::Rel::RelTypeCase::kMergeJoin: { - const auto* leftSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.merge_join().left())); - const auto* rightSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.merge_join().right())); + const auto* leftSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.merge_join().left()), SymbolType::kRelation); + const auto* rightSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.merge_join().right()), SymbolType::kRelation); relationData->newPipelines.push_back(leftSymbol); relationData->newPipelines.push_back(rightSymbol); break; @@ -121,18 +116,20 @@ std::any PipelineVisitor::visitRelation( std::any PipelineVisitor::visitPlanRelation( const ::substrait::proto::PlanRel& relation) { - auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); - auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); + auto symbols = + symbolTable_->lookupSymbolsByLocation(PROTO_LOCATION(relation)); + // A symbol is guaranteed as we previously visited the parse tree. + auto relationData = ANY_CAST(std::shared_ptr, symbols[0]->blob); switch (relation.rel_type_case()) { case ::substrait::proto::PlanRel::kRel: { - const auto& relSymbol = - symbolTable_->lookupSymbolByLocation(Location(&relation.rel())); + const auto& relSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.rel()), SymbolType::kRelation); relationData->newPipelines.push_back(relSymbol); break; } case ::substrait::proto::PlanRel::kRoot: { - const auto& inputSymbol = symbolTable_->lookupSymbolByLocation( - Location(&relation.root().input())); + const auto& inputSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.root().input()), SymbolType::kRelation); relationData->newPipelines.push_back(inputSymbol); break; } diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp index 89577819..3243769d 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.cpp +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.cpp @@ -65,6 +65,38 @@ std::string visitEnumArgument(const std::string& str) { return text.str(); } +int32_t expressionCount(const ::substrait::proto::Rel& relation) { + switch (relation.rel_type_case()) { + case ::substrait::proto::Rel::kProject: + return relation.project().expressions().size(); + default: + // No support for any other types besides project at the moment. + break; + } + return 0; +} + +const ::substrait::proto::Expression* getExpressionByNumber( + const ::substrait::proto::Rel& relation, + int num) { + switch (relation.rel_type_case()) { + case ::substrait::proto::Rel::kProject: + return &relation.project().expressions(num); + default: + // No support for any other types besides project at the moment. + break; + } + return nullptr; +} + +bool isDirectFieldReference(const ::substrait::proto::Expression& expr) { + if (expr.selection().reference_type_case() == + ::substrait::proto::Expression::FieldReference::kDirectReference) { + return expr.selection().direct_reference().has_struct_field(); + } + return false; +} + } // namespace std::string PlanPrinterVisitor::printRelation(const SymbolInfo& symbol) { @@ -100,18 +132,40 @@ std::string PlanPrinterVisitor::typeToText( return ANY_CAST(std::string, visitType(type)); } -std::string PlanPrinterVisitor::lookupFieldReference(uint32_t field_reference) { - if (*currentScope_ != SymbolInfo::kUnknown) { - auto relationData = - ANY_CAST(std::shared_ptr, currentScope_->blob); - if (field_reference < relationData->fieldReferences.size()) { - return relationData->fieldReferences[field_reference]->name; - } +// TODO -- Refactor this to return the symbol for later display decisions. +std::string PlanPrinterVisitor::lookupFieldReference( + uint32_t field_reference, + bool needFullyQualified) { + if (*currentScope_ == SymbolInfo::kUnknown) { + errorListener_->addError( + "Field number " + std::to_string(field_reference) + + " mysteriously requested outside of a relation."); + return "field#" + std::to_string(field_reference); + } + auto relationData = + ANY_CAST(std::shared_ptr, currentScope_->blob); + auto fieldReferencesSize = relationData->fieldReferences.size(); + const SymbolInfo* symbol{nullptr}; + if (field_reference < fieldReferencesSize) { + symbol = relationData->fieldReferences[field_reference]; + } else if ( + field_reference < + fieldReferencesSize + relationData->generatedFieldReferences.size()) { + symbol = + relationData + ->generatedFieldReferences[field_reference - fieldReferencesSize]; + } else { + errorListener_->addError( + "Encountered field reference out of range: " + + std::to_string(field_reference)); + return "field#" + std::to_string(field_reference); } - errorListener_->addError( - "Field number " + std::to_string(field_reference) + - " referenced but not defined."); - return "field#" + std::to_string(field_reference); + if (!symbol->alias.empty()) { + return symbol->alias; + } else if (needFullyQualified && symbol->schema != nullptr) { + return symbol->schema->name + "." + symbol->name; + } + return symbol->name; } std::string PlanPrinterVisitor::lookupFunctionReference( @@ -357,8 +411,9 @@ std::any PlanPrinterVisitor::visitFieldReference( // TODO -- Move this logic into visitDirectReference and visitMaskedReference. switch (ref.reference_type_case()) { case ::substrait::proto::Expression::FieldReference::kDirectReference: + // TODO -- Figure out when fully qualified names aren't needed. return lookupFieldReference( - ref.direct_reference().struct_field().field()); + ref.direct_reference().struct_field().field(), true); case ::substrait::proto::Expression::FieldReference::kMaskedReference: errorListener_->addError( "Masked reference not yet supported: " + ref.ShortDebugString()); @@ -421,15 +476,17 @@ std::any PlanPrinterVisitor::visitScalarFunction( text << ANY_CAST(std::string, visitExpression(arg)); hasPreviousText = true; } - // TODO -- Determine if the output type can be automatically determined. - // text << "->" << visitType(function.output_type()); - if (!hasPreviousText) { errorListener_->addError( "Function encountered without any arguments: " + function.ShortDebugString()); } + text << ")"; + + // TODO -- Determine if the output type can be automatically determined. + text << "->" << typeToText(function.output_type()); + return text.str(); } @@ -577,6 +634,18 @@ std::any PlanPrinterVisitor::visitFileOrFiles( return std::string("FORF_NOT_YET_IMPLEMENTED"); } +std::any PlanPrinterVisitor::visitRelationCommon( + const ::substrait::proto::RelCommon& common) { + std::stringstream text; + if (common.emit().output_mapping_size() > 0) { + text << "\n"; + } + for (const auto& mapping : common.emit().output_mapping()) { + text << " emit " << lookupFieldReference(mapping, true) << ";\n"; + } + return text.str(); +} + std::any PlanPrinterVisitor::visitAggregateFunction( const ::substrait::proto::AggregateFunction& function) { std::stringstream text; @@ -627,6 +696,7 @@ std::any PlanPrinterVisitor::visitAggregateFunction( } text << "->" << ANY_CAST(std::string, visitType(function.output_type())); // TODO -- Emit the requested sort behavior here. + // TODO -- Consider removing the AGGREGATION_PHASE_ prefix. text << "@" << ::substrait::proto::AggregationPhase_Name(function.phase()); return text.str(); } @@ -657,8 +727,8 @@ std::any PlanPrinterVisitor::visitRelation( // Mark the current scope for any operations within this relation. auto previousScope = currentScope_; auto resetCurrentScope = finally([&]() { currentScope_ = previousScope; }); - const SymbolInfo* symbol = - symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation)); + const SymbolInfo* symbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation), SymbolType::kRelation); if (symbol != nullptr) { currentScope_ = symbol; } @@ -689,15 +759,16 @@ std::any PlanPrinterVisitor::visitReadRelation( case ::substrait::proto::ReadRel::READ_TYPE_NOT_SET: return ""; } - const auto* symbol = - symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(*msg)); - if (symbol != nullptr) { - text << " source " << symbol->name << ";\n"; + + auto source = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(*msg), SymbolType::kSource); + if (source != nullptr) { + text << " source " << source->name << ";\n"; } if (relation.has_base_schema()) { - const auto* schemaSymbol = symbolTable_->lookupSymbolByLocation( - PROTO_LOCATION(relation.base_schema())); + const auto* schemaSymbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(relation.base_schema()), SymbolType::kSchema); if (schemaSymbol != nullptr) { text << " base_schema " << schemaSymbol->name << ";\n"; } @@ -753,12 +824,19 @@ std::any PlanPrinterVisitor::visitAggregateRelation( } for (const auto& measure : relation.measures()) { if (!measure.has_measure()) { + errorListener_->addError( + "Encountered aggregate measure without a measure function."); continue; } text << " measure {\n"; text << " measure " - << ANY_CAST(std::string, visitAggregateFunction(measure.measure())) - << ";\n"; + << ANY_CAST(std::string, visitAggregateFunction(measure.measure())); + auto symbol = symbolTable_->lookupSymbolByLocationAndType( + PROTO_LOCATION(measure), SymbolType::kMeasure); + if (symbol != nullptr) { + text << " NAMED " << symbol->name; + } + text << ";\n"; if (measure.has_filter()) { text << " filter " << ANY_CAST(std::string, visitExpression(measure.filter())) << ";\n"; @@ -824,16 +902,37 @@ std::any PlanPrinterVisitor::visitSortRelation( std::any PlanPrinterVisitor::visitProjectRelation( const ::substrait::proto::ProjectRel& relation) { std::stringstream text; + auto relationData = + ANY_CAST(std::shared_ptr, currentScope_->blob); + int exprNum = 0; for (const auto& expr : relation.expressions()) { - text << " expression " << ANY_CAST(std::string, visitExpression(expr)) - << ";\n"; + text << " expression "; + if (relationData->generatedFieldReferenceAlternativeExpression.find( + exprNum) != + relationData->generatedFieldReferenceAlternativeExpression.end()) { + text << relationData + ->generatedFieldReferenceAlternativeExpression[exprNum]; + } else if ( + exprNum < relationData->generatedFieldReferences.size() && + !relationData->generatedFieldReferences[exprNum]->alias.empty()) { + text << ANY_CAST(std::string, visitExpression(expr)); + text << " NAMED " + << relationData->generatedFieldReferences[exprNum]->alias; + } else { + text << ANY_CAST(std::string, visitExpression(expr)); + } + + text << ";\n"; + exprNum++; } + text << ANY_CAST(std::string, visitRelationCommon(relation.common())); return text.str(); } std::any PlanPrinterVisitor::visitJoinRelation( const ::substrait::proto::JoinRel& relation) { std::stringstream text; + // TODO -- Consider removing the JOIN_TYPE_ prefix. text << " type " << ::substrait::proto::JoinRel_JoinType_Name(relation.type()) << ";\n"; if (relation.has_expression()) { @@ -849,6 +948,12 @@ std::any PlanPrinterVisitor::visitJoinRelation( return text.str(); } +std::any PlanPrinterVisitor::visitCrossRelation( + const ::substrait::proto::CrossRel& relation) { + // There are no custom details in a cross relation. + return std::string{""}; +} + } // namespace io::substrait::textplan #pragma clang diagnostic pop diff --git a/src/substrait/textplan/converter/PlanPrinterVisitor.h b/src/substrait/textplan/converter/PlanPrinterVisitor.h index 2be16b09..f817af58 100644 --- a/src/substrait/textplan/converter/PlanPrinterVisitor.h +++ b/src/substrait/textplan/converter/PlanPrinterVisitor.h @@ -34,7 +34,9 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor { std::string typeToText(const ::substrait::proto::Type& type); private: - std::string lookupFieldReference(uint32_t field_reference); + std::string lookupFieldReference( + uint32_t field_reference, + bool needFullyQualified); std::string lookupFunctionReference(uint32_t function_reference); std::any visitSelect( @@ -87,6 +89,8 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor { std::any visitReferenceSegment( const ::substrait::proto::Expression_ReferenceSegment& segment) override; + std::any visitRelationCommon( + const ::substrait::proto::RelCommon& common) override; std::any visitAggregateFunction( const ::substrait::proto::AggregateFunction& function) override; std::any visitExpression( @@ -111,6 +115,8 @@ class PlanPrinterVisitor : public BasePlanProtoVisitor { const ::substrait::proto::ProjectRel& relation) override; std::any visitJoinRelation( const ::substrait::proto::JoinRel& relation) override; + std::any visitCrossRelation( + const ::substrait::proto::CrossRel& relation) override; std::shared_ptr symbolTable_; std::shared_ptr errorListener_; diff --git a/src/substrait/textplan/converter/ReferenceNormalizer.cpp b/src/substrait/textplan/converter/ReferenceNormalizer.cpp index f2c7d849..ec667c1f 100644 --- a/src/substrait/textplan/converter/ReferenceNormalizer.cpp +++ b/src/substrait/textplan/converter/ReferenceNormalizer.cpp @@ -26,7 +26,7 @@ bool compareExtensionFunctions( }; // Now let the default tuple compare do the rest of the work. - return ord(a) > ord(b); + return ord(a) < ord(b); } void normalizeFunctionsForExpression( diff --git a/src/substrait/textplan/converter/data/q6_first_stage.golden.splan b/src/substrait/textplan/converter/data/q6_first_stage.golden.splan new file mode 100644 index 00000000..ffb3b3a3 --- /dev/null +++ b/src/substrait/textplan/converter/data/q6_first_stage.golden.splan @@ -0,0 +1,73 @@ +pipelines { + read -> filter -> project -> aggregate -> root; +} + +read relation read { + source local; + base_schema schema; + filter and( + and( + and( + and( + and( + and( + and( + is_not_null(schema.l_shipdate_new)->bool?, + is_not_null(schema.l_discount)->bool?)->bool?, + is_not_null(schema.l_quantity)->bool?)->bool?, + gte(schema.l_shipdate_new, 8766_fp64)->bool?)->bool?, + lt(schema.l_shipdate_new, 9131_fp64)->bool?)->bool?, + gte(schema.l_discount, 0.05_fp64)->bool?)->bool?, + lte(schema.l_discount, 0.07_fp64)->bool?)->bool?, + lt(schema.l_quantity, 24_fp64)->bool?)->bool?; +} + +filter relation filter { + filter and( + and( + and( + and( + gte(schema.l_shipdate_new, 8766_fp64)->bool?, + lt(schema.l_shipdate_new, 9131_fp64)->bool?)->bool?, + gte(schema.l_discount, 0.05_fp64)->bool?)->bool?, + lte(schema.l_discount, 0.07_fp64)->bool?)->bool?, + lt(schema.l_quantity, 24_fp64)->bool?)->bool?; +} + +project relation project { + expression schema.l_extendedprice; + expression schema.l_discount; + + emit schema.l_extendedprice; + emit schema.l_discount; +} + +aggregate relation aggregate { + measure { + measure sum( + multiply(schema.l_extendedprice, schema.l_discount)->fp64?)->fp64?@AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE NAMED measurename; + } +} + +schema schema { + l_quantity fp64?; + l_extendedprice fp64?; + l_discount fp64?; + l_shipdate_new fp64?; +} + +source local_files local { + items = [ + {uri_file: "/mock_lineitem.orc" start: 0 length: 3719 orc: {}} + ] +} + +extension_space { + function and:bool_bool as and; + function gte:fp64_fp64 as gte; + function is_not_null:fp64 as is_not_null; + function lt:fp64_fp64 as lt; + function lte:fp64_fp64 as lte; + function multiply:opt_fp64_fp64 as multiply; + function sum:opt_fp64 as sum; +} diff --git a/src/substrait/textplan/converter/data/q6_first_stage.json b/src/substrait/textplan/converter/data/q6_first_stage.json index 1985ad74..6f8a2010 100644 --- a/src/substrait/textplan/converter/data/q6_first_stage.json +++ b/src/substrait/textplan/converter/data/q6_first_stage.json @@ -62,7 +62,9 @@ "input": { "project": { "common": { - "direct": {} + "emit": { + "outputMapping": [4, 5] + } }, "input": { "filter": { @@ -109,7 +111,7 @@ } ], "type_variation_reference": 0, - "nullability": "NULLABILITY_UNSPECIFIED" + "nullability": "NULLABILITY_REQUIRED" } }, "filter": { diff --git a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp index f15796d1..cde3c92a 100644 --- a/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp +++ b/src/substrait/textplan/converter/tests/BinaryToTextPlanConversionTest.cpp @@ -196,6 +196,8 @@ std::vector getTestCases() { "count", "named", "#2", + "cost", + "count", "read", "root"}), WhenSerialized(EqSquashingWhitespace( @@ -321,8 +323,8 @@ std::vector getTestCases() { join relation join { type JOIN_TYPE_UNSPECIFIED; - expression product_id; - post_join count; + expression schema.product_id; + post_join schema.count; } read relation read3 { @@ -332,7 +334,7 @@ std::vector getTestCases() { join relation join2 { type JOIN_TYPE_UNSPECIFIED; - expression order_id; + expression schema3.order_id; } schema schema { @@ -415,7 +417,7 @@ std::vector getTestCases() { } filter relation filter { - filter functionref#4(field#2, 0.07_fp64); + filter functionref#4(field#2, 0.07_fp64)->bool?; })"))), }, { @@ -468,7 +470,7 @@ std::vector getTestCases() { } filter relation filter { - filter functionref#4(field#2, 0.07_fp64); + filter functionref#4(field#2, 0.07_fp64)->bool?; })"))), }, { @@ -624,13 +626,15 @@ INSTANTIATE_TEST_SUITE_P( class BinaryToTextPlanConversionTest : public ::testing::Test {}; -TEST_F(BinaryToTextPlanConversionTest, loadFromJSON) { +TEST_F(BinaryToTextPlanConversionTest, FullSample) { std::string json = readFromFile("data/q6_first_stage.json"); auto planOrError = loadFromJson(json); ASSERT_TRUE(planOrError.ok()); auto plan = *planOrError; EXPECT_THAT(plan.extensions_size(), ::testing::Eq(7)); + std::string expectedOutput = readFromFile("data/q6_first_stage.golden.splan"); + auto result = parseBinaryPlan(plan); auto symbols = result.getSymbolTable().getSymbols(); ASSERT_THAT( @@ -663,7 +667,9 @@ TEST_F(BinaryToTextPlanConversionTest, loadFromJSON) { SymbolType::kRelation, SymbolType::kSource, SymbolType::kSchema, - }))); + }), + WhenSerialized(EqSquashingWhitespace(expectedOutput)))) + << result.getSymbolTable().toDebugString(); } } // namespace diff --git a/src/substrait/textplan/converter/tests/CMakeLists.txt b/src/substrait/textplan/converter/tests/CMakeLists.txt index e120fe5f..8d285f11 100644 --- a/src/substrait/textplan/converter/tests/CMakeLists.txt +++ b/src/substrait/textplan/converter/tests/CMakeLists.txt @@ -23,7 +23,11 @@ add_custom_command( "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data" COMMAND ${CMAKE_COMMAND} -E copy "${TEXTPLAN_SOURCE_DIR}/data/q6_first_stage.json" - "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json") + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.json" + COMMAND + ${CMAKE_COMMAND} -E copy + "${TEXTPLAN_SOURCE_DIR}/data/q6_first_stage.golden.splan" + "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/tests/data/q6_first_stage.golden.splan") message( STATUS diff --git a/src/substrait/textplan/data/tpch-plan01.json b/src/substrait/textplan/data/tpch-plan01.json index bb975373..877b4b73 100644 --- a/src/substrait/textplan/data/tpch-plan01.json +++ b/src/substrait/textplan/data/tpch-plan01.json @@ -56,7 +56,7 @@ "extensionFunction": { "extensionUriReference": 3, "functionAnchor": 7, - "name": "count:" + "name": "count:any" } }], "relations": [{ diff --git a/src/substrait/textplan/data/tpch-plan13.json b/src/substrait/textplan/data/tpch-plan13.json index 5c1d750a..1a0dd534 100644 --- a/src/substrait/textplan/data/tpch-plan13.json +++ b/src/substrait/textplan/data/tpch-plan13.json @@ -47,7 +47,7 @@ "extensionFunction": { "extensionUriReference": 4, "functionAnchor": 5, - "name": "count:" + "name": "count:any" } }], "relations": [{ diff --git a/src/substrait/textplan/parser/ParseText.h b/src/substrait/textplan/parser/ParseText.h index 61a45e39..caaf7b3b 100644 --- a/src/substrait/textplan/parser/ParseText.h +++ b/src/substrait/textplan/parser/ParseText.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "substrait/textplan/ParseResult.h" #include "substrait/textplan/SymbolTable.h" #include "substrait/textplan/parser/SubstraitPlanVisitor.h" diff --git a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp index 806d23b1..5e0e1806 100644 --- a/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanPipelineVisitor.cpp @@ -95,25 +95,24 @@ std::any SubstraitPlanPipelineVisitor::visitPipeline( // a -> b -> c -> d const SymbolInfo* leftSymbol = &SymbolInfo::kUnknown; if (ctx->pipeline() != nullptr) { - leftSymbol = - symbolTable_->lookupSymbolByLocation(PARSER_LOCATION(ctx->pipeline())); + leftSymbol = symbolTable_->lookupSymbolByLocationAndType( + PARSER_LOCATION(ctx->pipeline()), SymbolType::kRelation); } const SymbolInfo* rightSymbol = &SymbolInfo::kUnknown; if (dynamic_cast(ctx->parent)->getRuleIndex() == SubstraitPlanParser::RulePipeline) { - rightSymbol = - symbolTable_->lookupSymbolByLocation(PARSER_LOCATION(ctx->parent)); + rightSymbol = symbolTable_->lookupSymbolByLocationAndTypes( + PARSER_LOCATION(ctx->parent), + {SymbolType::kRelation, SymbolType::kRoot}); } const SymbolInfo* rightmostSymbol = rightSymbol; if (*rightSymbol != SymbolInfo::kUnknown) { - if (rightSymbol->blob.type() != typeid(std::shared_ptr)) { - errorListener_->addError( - ctx->getStart(), "No relation definition present for this symbol."); - } - auto rightRelationData = - ANY_CAST(std::shared_ptr, rightSymbol->blob); - if (rightRelationData->pipelineStart != nullptr) { - rightmostSymbol = rightRelationData->pipelineStart; + if (rightSymbol->blob.type() == typeid(std::shared_ptr)) { + auto rightRelationData = + ANY_CAST(std::shared_ptr, rightSymbol->blob); + if (rightRelationData->pipelineStart != nullptr) { + rightmostSymbol = rightRelationData->pipelineStart; + } } } diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp index abdb98b2..a8273f12 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.cpp @@ -20,6 +20,7 @@ #include "substrait/textplan/Any.h" #include "substrait/textplan/Finally.h" #include "substrait/textplan/Location.h" +#include "substrait/textplan/StringManipulation.h" #include "substrait/textplan/StructuredSymbolData.h" #include "substrait/textplan/SymbolTable.h" @@ -27,9 +28,12 @@ namespace io::substrait::textplan { namespace { -std::string kAggregationPhasePrefix = "aggregationphase"; -std::string kAggregationInvocationPrefix = "aggregationinvocation"; -std::string kSortDirectionPrefix = "sortdirection"; +const std::string kAggregationPhasePrefix = "aggregationphase"; +const std::string kAggregationInvocationPrefix = "aggregationinvocation"; +const std::string kJoinTypePrefix = "jointype"; +const std::string kSortDirectionPrefix = "sortdirection"; + +const std::string kIntermediateNodeName = "intermediate"; enum RelationFilterBehavior { kDefault = 0, @@ -45,18 +49,6 @@ std::string toLower(const std::string& str) { return s; } -// Yields true if the string 'haystack' starts with the string 'needle'. -bool startsWith(std::string_view haystack, std::string_view needle) { - return haystack.size() > needle.size() && - haystack.substr(0, needle.size()) == needle; -} - -// Returns true if the string 'haystack' ends with the string 'needle'. -bool endsWith(std::string_view haystack, std::string_view needle) { - return haystack.size() > needle.size() && - haystack.substr(haystack.size() - needle.size(), needle.size()) == needle; -} - void setNullable(::substrait::proto::Type* type) { switch (type->kind_case()) { case ::substrait::proto::Type::kBool: @@ -217,6 +209,47 @@ void setRelationType( } } +::substrait::proto::RelCommon* findCommonRelation( + RelationType relationType, + ::substrait::proto::Rel* relation) { + switch (relationType) { + case RelationType::kRead: + return relation->mutable_read()->mutable_common(); + case RelationType::kProject: + return relation->mutable_project()->mutable_common(); + case RelationType::kJoin: + return relation->mutable_join()->mutable_common(); + case RelationType::kCross: + return relation->mutable_cross()->mutable_common(); + case RelationType::kFetch: + return relation->mutable_fetch()->mutable_common(); + case RelationType::kAggregate: + return relation->mutable_aggregate()->mutable_common(); + case RelationType::kSort: + return relation->mutable_sort()->mutable_common(); + case RelationType::kFilter: + return relation->mutable_filter()->mutable_common(); + case RelationType::kSet: + return relation->mutable_set()->mutable_common(); + case RelationType::kExtensionLeaf: + return relation->mutable_extension_leaf()->mutable_common(); + case RelationType::kExtensionMulti: + return relation->mutable_extension_multi()->mutable_common(); + case RelationType::kExtensionSingle: + return relation->mutable_extension_single()->mutable_common(); + case RelationType::kHashJoin: + return relation->mutable_hash_join()->mutable_common(); + case RelationType::kMergeJoin: + return relation->mutable_merge_join()->mutable_common(); + case RelationType::kExchange: + case RelationType::kDdl: + case RelationType::kWrite: + case RelationType::kUnknown: + break; + } + return nullptr; +} + std::string normalizeProtoEnum(std::string_view text, std::string_view prefix) { std::string result{text}; // Remove non-alphabetic characters. @@ -238,6 +271,51 @@ std::string normalizeProtoEnum(std::string_view text, std::string_view prefix) { return result; } +void addInputFieldsToSchema( + RelationType relationType, + std::shared_ptr& relationData) { + if (relationData->continuingPipeline != nullptr) { + auto continuingRelationData = ANY_CAST( + std::shared_ptr, relationData->continuingPipeline->blob); + if (!continuingRelationData->outputFieldReferences.empty()) { + // There is an emit sequence so use that. + for (auto field : continuingRelationData->outputFieldReferences) { + relationData->fieldReferences.push_back(field); + } + } else { + // There was no emit so just access all the field references. + for (auto field : continuingRelationData->fieldReferences) { + relationData->fieldReferences.push_back(field); + } + for (auto field : continuingRelationData->generatedFieldReferences) { + relationData->fieldReferences.push_back(field); + } + } + } + + for (auto pipeline : relationData->newPipelines) { + auto pipelineRelationData = + ANY_CAST(std::shared_ptr, pipeline->blob); + if (!pipelineRelationData->outputFieldReferences.empty()) { + for (auto field : pipelineRelationData->outputFieldReferences) { + relationData->fieldReferences.push_back(field); + } + } else { + for (auto field : pipelineRelationData->fieldReferences) { + relationData->fieldReferences.push_back(field); + } + for (auto field : pipelineRelationData->generatedFieldReferences) { + relationData->fieldReferences.push_back(field); + } + } + } +} + +bool isRelationEmitDetail(SubstraitPlanParser::Relation_detailContext* ctx) { + return dynamic_cast(ctx) != + nullptr; +} + } // namespace std::any SubstraitPlanRelationVisitor::aggregateResult( @@ -252,15 +330,14 @@ std::any SubstraitPlanRelationVisitor::aggregateResult( std::any SubstraitPlanRelationVisitor::visitRelation( SubstraitPlanParser::RelationContext* ctx) { - // Create the relation before visiting our children, so they can update it. - auto* symbol = symbolTable_->lookupSymbolByLocation(Location(ctx)); + // First find the relation created in a previous step. + auto* symbol = symbolTable_->lookupSymbolByLocationAndType( + Location(ctx), SymbolType::kRelation); if (symbol == nullptr) { // This error has been previously dealt with thus we can safely skip it. return defaultResult(); } - if (symbol->type == SymbolType::kRoot) { - return defaultResult(); - } + // Create the relation data before visiting children, so they can update it. auto relationData = ANY_CAST(std::shared_ptr, symbol->blob); ::substrait::proto::Rel relation; @@ -276,8 +353,95 @@ std::any SubstraitPlanRelationVisitor::visitRelation( finally([&]() { currentRelationScope_ = previousScope; }); currentRelationScope_ = symbol; - visitChildren(ctx); + addInputFieldsToSchema(relationType, relationData); + + // Visit everything but the emit details to gather necessary information. + for (auto detail : ctx->relation_detail()) { + if (!isRelationEmitDetail(detail)) { + visitRelationDetail(detail); + } + } + + addExpressionsToSchema(relationData); + + // Now visit the emit details. + for (auto detail : ctx->relation_detail()) { + if (isRelationEmitDetail(detail)) { + visitRelationDetail(detail); + } + } + + // Aggregate relations are different in that they alter the emitted fields + // by default. + if (relationType == RelationType::kAggregate) { + relationData->outputFieldReferences.insert( + relationData->outputFieldReferences.end(), + relationData->generatedFieldReferences.begin(), + relationData->generatedFieldReferences.end()); + } + + applyOutputMappingToSchema(ctx->getStart(), relationType, relationData); + + // Emit one empty grouping for an aggregation relation not specifying any. + if (relationType == RelationType::kAggregate && + relationData->relation.aggregate().groupings_size() == 0) { + relationData->relation.mutable_aggregate()->add_groupings(); + } + return defaultResult(); +} +std::any SubstraitPlanRelationVisitor::visitRelationDetail( + SubstraitPlanParser::Relation_detailContext* ctx) { + if (auto* commonCtx = + dynamic_cast(ctx)) { + return visitRelationCommon(commonCtx); + } else if ( + auto* usesSchemaCtx = + dynamic_cast(ctx)) { + return visitRelationUsesSchema(usesSchemaCtx); + } else if ( + auto* filterCtx = + dynamic_cast(ctx)) { + return visitRelationFilter(filterCtx); + } else if ( + auto* exprCtx = + dynamic_cast(ctx)) { + return visitRelationExpression(exprCtx); + } else if ( + auto* advExtensionCtx = + dynamic_cast( + ctx)) { + return visitRelationAdvancedExtension(advExtensionCtx); + } else if ( + auto* sourceRefCtx = + dynamic_cast( + ctx)) { + return visitRelationSourceReference(sourceRefCtx); + } else if ( + auto* groupingCtx = + dynamic_cast(ctx)) { + return visitRelationGrouping(groupingCtx); + } else if ( + auto* measureCtx = + dynamic_cast(ctx)) { + return visitRelationMeasure(measureCtx); + } else if ( + auto* sortCtx = + dynamic_cast(ctx)) { + return visitRelationSort(sortCtx); + } else if ( + auto* countCtx = + dynamic_cast(ctx)) { + return visitRelationCount(countCtx); + } else if ( + auto* joinTypeCtx = + dynamic_cast(ctx)) { + return visitRelationJoinType(joinTypeCtx); + } else if ( + auto* emitCtx = + dynamic_cast(ctx)) { + return visitRelationEmit(emitCtx); + } return defaultResult(); } @@ -313,11 +477,11 @@ std::any SubstraitPlanRelationVisitor::visitRelationFilter( visitRelation_filter_behavior(ctx->relation_filter_behavior())); } - auto* parentSymbol = symbolTable_->lookupSymbolByLocation( - Location(dynamic_cast(ctx->parent))); + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + PARSER_LOCATION(ctx->parent), SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); - auto result = SubstraitPlanRelationVisitor::visitChildren(ctx); + auto result = visitChildren(ctx); auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); switch (parentRelationType) { case RelationType::kRead: @@ -407,8 +571,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationFilter( std::any SubstraitPlanRelationVisitor::visitRelationUsesSchema( SubstraitPlanParser::RelationUsesSchemaContext* ctx) { - auto* parentSymbol = symbolTable_->lookupSymbolByLocation( - Location(dynamic_cast(ctx->parent))); + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); @@ -426,10 +591,14 @@ std::any SubstraitPlanRelationVisitor::visitRelationUsesSchema( if (sym.location != symbol->location) { continue; } + parentRelationData->outputFieldReferences.push_back(&sym); schema->add_names(sym.name); auto typeProto = ANY_CAST(::substrait::proto::Type, sym.blob); if (typeProto.kind_case() != ::substrait::proto::Type::KIND_NOT_SET) { *schema->mutable_struct_()->add_types() = typeProto; + // If the schema contains any types, the struct is required. + schema->mutable_struct_()->set_nullability( + ::substrait::proto::Type_Nullability_NULLABILITY_REQUIRED); } } } @@ -443,11 +612,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationUsesSchema( std::any SubstraitPlanRelationVisitor::visitRelationExpression( SubstraitPlanParser::RelationExpressionContext* ctx) { - auto* parentSymbol = symbolTable_->lookupSymbolByLocation( - Location(dynamic_cast(ctx->parent))); + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); - auto result = SubstraitPlanRelationVisitor::visitChildren(ctx); + auto result = visitChildren(ctx); auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); switch (parentRelationType) { case RelationType::kJoin: @@ -476,11 +646,12 @@ std::any SubstraitPlanRelationVisitor::visitRelationExpression( std::any SubstraitPlanRelationVisitor::visitRelationGrouping( SubstraitPlanParser::RelationGroupingContext* ctx) { - auto* parentSymbol = symbolTable_->lookupSymbolByLocation( - Location(dynamic_cast(ctx->parent))); + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); - auto result = SubstraitPlanRelationVisitor::visitChildren(ctx); + auto result = visitChildren(ctx); auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); switch (parentRelationType) { case RelationType::kAggregate: { @@ -494,6 +665,13 @@ std::any SubstraitPlanRelationVisitor::visitRelationGrouping( *newExpr = ANY_CAST(::substrait::proto::Expression, result); if (newExpr->has_selection()) { newExpr->mutable_selection()->mutable_root_reference(); + if (newExpr->selection().direct_reference().has_struct_field()) { + parentRelationData->generatedFieldReferences.push_back( + parentRelationData->fieldReferences[newExpr->selection() + .direct_reference() + .struct_field() + .field()]); + } } break; } @@ -550,8 +728,9 @@ std::any SubstraitPlanRelationVisitor::visitRelationMeasure( } // Add it to our relation. - auto* parentSymbol = symbolTable_->lookupSymbolByLocation( - Location(dynamic_cast(ctx->parent))); + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); @@ -569,6 +748,75 @@ std::any SubstraitPlanRelationVisitor::visitRelationMeasure( return defaultResult(); } +std::any SubstraitPlanRelationVisitor::visitRelationJoinType( + SubstraitPlanParser::RelationJoinTypeContext* ctx) { + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); + if (parentRelationType == RelationType::kJoin) { + std::string text = + normalizeProtoEnum(ctx->id()->getText(), kJoinTypePrefix); + ::substrait::proto::JoinRel_JoinType joinType; + if (text == "unspecified") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_UNSPECIFIED; + } else if (text == "inner") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_INNER; + } else if (text == "outer") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_OUTER; + } else if (text == "left") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_LEFT; + } else if (text == "right") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_RIGHT; + } else if (text == "semi") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_SEMI; + } else if (text == "anti") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_ANTI; + } else if (text == "single") { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_SINGLE; + } else { + joinType = ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_UNSPECIFIED; + } + if (joinType == + ::substrait::proto::JoinRel_JoinType_JOIN_TYPE_UNSPECIFIED) { + this->errorListener_->addError( + ctx->getStart(), + "Unsupported join type direction: " + ctx->id()->getText()); + } + parentRelationData->relation.mutable_join()->set_type(joinType); + + // TODO -- Add support for HashJoin/MergeJoin which have different enums. + } else { + errorListener_->addError( + ctx->getStart(), + "Join types are not supported for this relation type."); + return defaultResult(); + } + return defaultResult(); +} + +std::any SubstraitPlanRelationVisitor::visitRelationEmit( + SubstraitPlanParser::RelationEmitContext* ctx) { + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto result = visitChildren(ctx); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); + auto common = + findCommonRelation(parentRelationType, &parentRelationData->relation); + if (common == nullptr) { + errorListener_->addError( + ctx->getStart(), "Emits do not make sense for this kind of relation."); + return defaultResult(); + } + common->mutable_emit()->add_output_mapping(ANY_CAST(int32_t, result)); + return defaultResult(); +} + int32_t SubstraitPlanRelationVisitor::visitAggregationInvocation( SubstraitPlanParser::IdContext* ctx) { std::string text = @@ -635,17 +883,34 @@ std::any SubstraitPlanRelationVisitor::visitMeasure_detail( ::substrait::proto::Type, visitLiteral_complex_type(ctx->literal_complex_type())); } - if (ctx->id() != nullptr) { + if (ctx->id(0) != nullptr) { measure.mutable_measure()->set_phase( static_cast<::substrait::proto::AggregationPhase>( - visitAggregationPhase(ctx->id()))); + visitAggregationPhase(ctx->id(0)))); } } else { errorListener_->addError( - ctx->id()->getStart(), + ctx->id(0)->getStart(), "Expected an expression utilizing a function here."); } - + // If we have a NAMED clause, add a symbol reference. + if (ctx->id().size() > 1) { + auto symbol = symbolTable_->defineSymbol( + ctx->id(1)->getText(), + PROTO_LOCATION(measure), + SymbolType::kMeasure, + std::nullopt, + std::nullopt); + + // Add it to our generated field mapping. + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location( + dynamic_cast(ctx->parent->parent)), + SymbolType::kRelation); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol->blob); + parentRelationData->generatedFieldReferences.push_back(symbol); + } return measure; } case SubstraitPlanParser::FILTER: @@ -656,7 +921,7 @@ std::any SubstraitPlanRelationVisitor::visitMeasure_detail( measure.mutable_measure()->set_invocation( static_cast< ::substrait::proto::AggregateFunction_AggregationInvocation>( - visitAggregationInvocation(ctx->id()))); + visitAggregationInvocation(ctx->id(0)))); return measure; case SubstraitPlanParser::SORT: *measure.mutable_measure()->add_sorts() = ANY_CAST( @@ -670,16 +935,43 @@ std::any SubstraitPlanRelationVisitor::visitMeasure_detail( std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( SubstraitPlanParser::RelationSourceReferenceContext* ctx) { - auto* parentSymbol = symbolTable_->lookupSymbolByLocation( - Location(dynamic_cast(ctx->parent))); + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); - if (parentRelationType == RelationType::kRead) { - auto sourceName = ctx->source_reference()->id()->getText(); - auto* symbol = symbolTable_->lookupSymbolByName(sourceName); - if (symbol != nullptr) { + if (parentRelationType != RelationType::kRead) { + errorListener_->addError( + ctx->getStart(), + "Source references are not defined for this kind of relation."); + return defaultResult(); + } + + auto sourceName = ctx->source_reference()->id()->getText(); + auto* symbol = symbolTable_->lookupSymbolByName(sourceName); + if (symbol == nullptr) { + return defaultResult(); + } + switch (ANY_CAST(SourceType, symbol->subtype)) { + case SourceType::kLocalFiles: { + auto* source = + parentRelationData->relation.mutable_read()->mutable_local_files(); + for (const auto& sym : *symbolTable_) { + if (sym.type != SymbolType::kSourceDetail) { + continue; + } + if (sym.location != symbol->location) { + continue; + } + *source->add_items() = *ANY_CAST( + std::shared_ptr<::substrait::proto::ReadRel_LocalFiles_FileOrFiles>, + sym.blob); + } + break; + } + case SourceType::kNamedTable: { auto* source = parentRelationData->relation.mutable_read()->mutable_named_table(); for (const auto& sym : *symbolTable_) { @@ -691,19 +983,26 @@ std::any SubstraitPlanRelationVisitor::visitRelationSourceReference( } source->add_names(sym.name); } + break; } - } else { - errorListener_->addError( - ctx->getStart(), - "Source references are not defined for this kind of relation."); + case SourceType::kVirtualTable: + // TODO -- Implement. + break; + case SourceType::kExtensionTable: + // TODO -- Implement. + break; + case SourceType::kUnknown: + break; } + return defaultResult(); } std::any SubstraitPlanRelationVisitor::visitRelationSort( SubstraitPlanParser::RelationSortContext* ctx) { - auto* parentSymbol = symbolTable_->lookupSymbolByLocation( - Location(dynamic_cast(ctx->parent))); + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); auto parentRelationData = ANY_CAST(std::shared_ptr, parentSymbol->blob); auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); @@ -721,6 +1020,31 @@ std::any SubstraitPlanRelationVisitor::visitRelationSort( return defaultResult(); } +std::any SubstraitPlanRelationVisitor::visitRelationCount( + SubstraitPlanParser::RelationCountContext* ctx) { + auto* parentSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(dynamic_cast(ctx->parent)), + SymbolType::kRelation); + auto parentRelationData = + ANY_CAST(std::shared_ptr, parentSymbol->blob); + auto parentRelationType = ANY_CAST(RelationType, parentSymbol->subtype); + switch (parentRelationType) { + case RelationType::kFetch: { + ::substrait::proto::Type type; + type.mutable_i64()->set_nullability( + ::substrait::proto::Type_Nullability_NULLABILITY_REQUIRED); + auto number = visitNumber(ctx->NUMBER(), type); + parentRelationData->relation.mutable_fetch()->set_count(number.i64()); + break; + } + default: + errorListener_->addError( + ctx->getStart(), "Count only applies to fetch relations."); + break; + } + return defaultResult(); +} + std::any SubstraitPlanRelationVisitor::visitExpression( SubstraitPlanParser::ExpressionContext* ctx) { if (auto* funcUseCtx = @@ -804,6 +1128,12 @@ std::any SubstraitPlanRelationVisitor::visitExpressionFunctionUse( auto newExpr = ANY_CAST(::substrait::proto::Expression, result); *expr.mutable_scalar_function()->add_arguments()->mutable_value() = newExpr; } + if (ctx->literal_complex_type() != nullptr) { + auto literalType = ANY_CAST( + ::substrait::proto::Type, + visitLiteral_complex_type(ctx->literal_complex_type())); + *expr.mutable_scalar_function()->mutable_output_type() = literalType; + } return expr; } @@ -821,23 +1151,18 @@ std::any SubstraitPlanRelationVisitor::visitExpressionColumn( ANY_CAST(std::shared_ptr, currentRelationScope_->blob); std::string symbolName = ctx->getText(); - auto currentFieldNumber = std::find_if( - relationData->fieldReferences.begin(), - relationData->fieldReferences.end(), - [&](auto ref) { return (ref->name == symbolName); }); + int32_t fieldReference = + findFieldReferenceByName(ctx->getStart(), relationData, symbolName); ::substrait::proto::Expression expr; - if (currentFieldNumber != relationData->fieldReferences.end()) { - int32_t fieldReference = static_cast( - (currentFieldNumber - relationData->fieldReferences.begin()) & - std::numeric_limits::max()); + if (fieldReference != -1) { expr.mutable_selection() ->mutable_direct_reference() ->mutable_struct_field() ->set_field(fieldReference); + // TODO -- Update the following when non-direct references are implemented. + expr.mutable_selection()->mutable_root_reference(); } - - visitChildren(ctx); return expr; } @@ -926,7 +1251,10 @@ std::any SubstraitPlanRelationVisitor::visitStruct_literal( std::any SubstraitPlanRelationVisitor::visitColumn_name( SubstraitPlanParser::Column_nameContext* ctx) { - return visitChildren(ctx); + auto relationData = + ANY_CAST(std::shared_ptr, currentRelationScope_->blob); + return findFieldReferenceByName( + ctx->getStart(), relationData, ctx->getText()); } ::substrait::proto::Expression_Literal @@ -1562,4 +1890,133 @@ int32_t SubstraitPlanRelationVisitor::visitSortDirection( return ::substrait::proto::SortField::SORT_DIRECTION_UNSPECIFIED; } +void SubstraitPlanRelationVisitor::addExpressionsToSchema( + std::shared_ptr& relationData) { + const auto& relation = relationData->relation; + switch (relation.rel_type_case()) { + case ::substrait::proto::Rel::kProject: + for (const auto& expr : relation.project().expressions()) { + if (expr.selection().direct_reference().has_struct_field()) { + if (expr.selection().direct_reference().struct_field().field() < + relationData->fieldReferences.size()) { + relationData->generatedFieldReferences.push_back( + relationData->fieldReferences[expr.selection() + .direct_reference() + .struct_field() + .field()]); + } + } else { + const auto& uniqueName = + symbolTable_->getUniqueName(kIntermediateNodeName); + auto newSymbol = symbolTable_->defineSymbol( + uniqueName, + PROTO_LOCATION(expr), + SymbolType::kUnknown, + std::nullopt, + std::nullopt); + relationData->generatedFieldReferences.push_back(newSymbol); + } + } + break; + default: + // Only project and aggregate relations affect the output mapping. + break; + } +} + +std::string SubstraitPlanRelationVisitor::fullyQualifiedReference( + const SymbolInfo* fieldReference) { + for (const auto& symbol : symbolTable_->getSymbols()) { + if (symbol->type == SymbolType::kSchema && + symbol->location == fieldReference->location) { + auto fqn = symbol->name + "." + fieldReference->name; + return fqn; + } + } + // Shouldn't happen, but return no schema if we can't find one. + return fieldReference->name; +} + +int SubstraitPlanRelationVisitor::findFieldReferenceByName( + antlr4::Token* token, + std::shared_ptr& relationData, + const std::string& name) { + auto fieldReferencesSize = relationData->fieldReferences.size(); + + auto generatedField = std::find_if( + relationData->generatedFieldReferences.rbegin(), + relationData->generatedFieldReferences.rend(), + [&](auto ref) { + return (!ref->alias.empty() && ref->alias == name || ref->name == name); + }); + if (generatedField != relationData->generatedFieldReferences.rend()) { + auto fieldPlacement = + generatedField - relationData->generatedFieldReferences.rbegin(); + return static_cast( + (fieldReferencesSize + relationData->generatedFieldReferences.size() - + fieldPlacement - 1) & + std::numeric_limits::max()); + } + + auto field = std::find_if( + relationData->fieldReferences.rbegin(), + relationData->fieldReferences.rend(), + [&](auto ref) { + return ( + !ref->alias.empty() && ref->alias == name || ref->name == name || + fullyQualifiedReference(ref) == name); + }); + + if (field != relationData->fieldReferences.rend()) { + auto fieldPlacement = field - relationData->fieldReferences.rbegin(); + return static_cast( + (fieldReferencesSize - fieldPlacement - 1) & + std::numeric_limits::max()); + } + + errorListener_->addError(token, "Reference " + name + " does not exist."); + return -1; +} + +void SubstraitPlanRelationVisitor::applyOutputMappingToSchema( + antlr4::Token* token, + RelationType relationType, + std::shared_ptr& relationData) { + auto common = findCommonRelation(relationType, &relationData->relation); + if (common == nullptr) { + return; + } + if (common->emit().output_mapping_size() == 0) { + common->mutable_direct(); + } else { + if (!relationData->outputFieldReferences.empty()) { + // TODO -- Add support for aggregate relations. + errorListener_->addError( + token, "Aggregate relations do not yet support emit sections."); + return; + } + for (auto mapping : common->emit().output_mapping()) { + auto fieldReferencesSize = relationData->fieldReferences.size(); + if (mapping < fieldReferencesSize) { + relationData->outputFieldReferences.push_back( + relationData->fieldReferences[mapping]); + } else if ( + mapping < + fieldReferencesSize + relationData->generatedFieldReferences.size()) { + relationData->outputFieldReferences.push_back( + relationData + ->generatedFieldReferences[mapping - fieldReferencesSize]); + } else { + errorListener_->addError( + token, + "Field #" + std::to_string(mapping) + " requested but only " + + std::to_string( + fieldReferencesSize + + relationData->generatedFieldReferences.size()) + + " are available."); + } + } + } +} + } // namespace io::substrait::textplan diff --git a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h index 95a0f335..8844819f 100644 --- a/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h +++ b/src/substrait/textplan/parser/SubstraitPlanRelationVisitor.h @@ -18,6 +18,8 @@ class Type_Struct; namespace io::substrait::textplan { +class RelationData; + class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { public: SubstraitPlanRelationVisitor( @@ -38,6 +40,10 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { std::any visitRelation(SubstraitPlanParser::RelationContext* ctx) override; + // visitRelationDetail is a new method delegating to the methods below. + std::any visitRelationDetail( + SubstraitPlanParser::Relation_detailContext* ctx); + std::any visitRelation_filter_behavior( SubstraitPlanParser::Relation_filter_behaviorContext* ctx) override; @@ -56,6 +62,12 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { std::any visitRelationMeasure( SubstraitPlanParser::RelationMeasureContext* ctx) override; + std::any visitRelationJoinType( + SubstraitPlanParser::RelationJoinTypeContext* ctx) override; + + std::any visitRelationEmit( + SubstraitPlanParser::RelationEmitContext* ctx) override; + int32_t visitAggregationInvocation(SubstraitPlanParser::IdContext* ctx); int32_t visitAggregationPhase(SubstraitPlanParser::IdContext* ctx); @@ -69,6 +81,9 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { std::any visitRelationSort( SubstraitPlanParser::RelationSortContext* ctx) override; + std::any visitRelationCount( + SubstraitPlanParser::RelationCountContext* ctx) override; + // visitExpression is a new method delegating to the methods below. std::any visitExpression(SubstraitPlanParser::ExpressionContext* ctx); @@ -167,6 +182,20 @@ class SubstraitPlanRelationVisitor : public SubstraitPlanTypeVisitor { const antlr4::tree::TerminalNode* node, const std::string& str); + void addExpressionsToSchema(std::shared_ptr& relationData); + + void applyOutputMappingToSchema( + antlr4::Token* token, + RelationType relationType, + std::shared_ptr& relationData); + + std::string fullyQualifiedReference(const SymbolInfo* fieldReference); + + int findFieldReferenceByName( + antlr4::Token* token, + std::shared_ptr& relationData, + const std::string& name); + const SymbolInfo* currentRelationScope_{nullptr}; // Not owned. }; diff --git a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp index f57d52eb..e3d47c5a 100644 --- a/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp +++ b/src/substrait/textplan/parser/SubstraitPlanVisitor.cpp @@ -8,13 +8,16 @@ #include "substrait/textplan/Any.h" #include "substrait/textplan/Finally.h" #include "substrait/textplan/Location.h" +#include "substrait/textplan/StringManipulation.h" #include "substrait/textplan/StructuredSymbolData.h" #include "substrait/textplan/SymbolTable.h" -#include "substrait/type/Type.h" namespace io::substrait::textplan { +namespace { + const std::string kRootName{"root"}; +const std::string kRootNames{"root.names"}; // Removes leading and trailing quotation marks. std::string extractFromString(std::string s) { @@ -30,12 +33,70 @@ std::string extractFromString(std::string s) { return s; } +uint64_t parseUnsignedInteger(const std::string& text) { + return std::stoul(text); +} + +std::string parseString(std::string text) { + // First remove the surrounding quote marks. + std::string str; + if (startsWith(text, "```")) { + str = text.substr(3, text.length() - 6); + } else if (startsWith(text, "``")) { + str = text.substr(2, text.length() - 4); + } else if (text[0] == '"' || text[0] == '`') { + str = text.substr(1, text.length() - 2); + } else { + str = text; + } + + // Perform escapes if necessary. + std::string resultStr; + if (startsWith(text, "`")) { + // Don't perform escapes on raw strings. + resultStr = str; + } else { + // TODO -- Escape the text as in SubstraitPlanRelationVisitor::escapeText. + resultStr = str; + } + + return resultStr; +} + +} // namespace + // NOLINTBEGIN(readability-identifier-naming) // NOLINTBEGIN(readability-convert-member-functions-to-static) std::any SubstraitPlanVisitor::visitPlan( SubstraitPlanParser::PlanContext* ctx) { - return visitChildren(ctx); + // First visit the schema, source, and extension space definitions. + for (auto detail : ctx->plan_detail()) { + if (detail->schema_definition() != nullptr || + detail->source_definition() != nullptr || + detail->extensionspace() != nullptr) { + visitPlan_detail(detail); + } + } + // Then visit the pipelines. + for (auto detail : ctx->plan_detail()) { + if (detail->pipelines() != nullptr) { + visitPlan_detail(detail); + } + } + // Next visit the relations. + for (auto detail : ctx->plan_detail()) { + if (detail->relation() != nullptr) { + visitPlan_detail(detail); + } + } + // And finally visit the root. + for (auto detail : ctx->plan_detail()) { + if (detail->root_relation() != nullptr) { + visitPlan_detail(detail); + } + } + return defaultResult(); } std::any SubstraitPlanVisitor::visitPlan_detail( @@ -73,7 +134,8 @@ std::any SubstraitPlanVisitor::visitExtensionspace( // Update the contained functions to belong in this space. for (auto func : ctx->function()) { - auto* funcSymbol = symbolTable_->lookupSymbolByLocation(Location(func)); + auto* funcSymbol = symbolTable_->lookupSymbolByLocationAndType( + Location(func), SymbolType::kFunction); auto functionData = ANY_CAST(std::shared_ptr, funcSymbol->blob); functionData->extensionUriReference = thisSpace; @@ -131,20 +193,21 @@ std::any SubstraitPlanVisitor::visitSignature( std::any SubstraitPlanVisitor::visitSchema_definition( SubstraitPlanParser::Schema_definitionContext* ctx) { - symbolTable_->defineSymbol( + auto schemaSymbol = symbolTable_->defineSymbol( ctx->id()->getText(), Location(ctx), SymbolType::kSchema, defaultResult(), defaultResult()); - // Mark all of the schema items so we can find the ones related to this + // Mark all the schema items so that we can find the ones related to this // schema. for (const auto& item : ctx->schema_item()) { auto symbol = ANY_CAST(SymbolInfo*, visitSchema_item(item)); if (symbol == nullptr) { continue; } + symbol->schema = schemaSymbol; symbol->location = Location(ctx); } @@ -153,33 +216,49 @@ std::any SubstraitPlanVisitor::visitSchema_definition( std::any SubstraitPlanVisitor::visitSchema_item( SubstraitPlanParser::Schema_itemContext* ctx) { - return symbolTable_->defineSymbol( - ctx->id()->getText(), + auto symbol = symbolTable_->defineSymbol( + ctx->id(0)->getText(), Location(ctx), SymbolType::kSchemaColumn, defaultResult(), visitLiteral_complex_type(ctx->literal_complex_type())); + return symbol; } std::any SubstraitPlanVisitor::visitRoot_relation( SubstraitPlanParser::Root_relationContext* ctx) { + if (symbolTable_->lookupSymbolByName(kRootNames) != nullptr) { + errorListener_->addError( + ctx->getStart(), "A root relation was already defined."); + return nullptr; + } auto prevRoot = symbolTable_->lookupSymbolByName(kRootName); if (prevRoot != nullptr) { - if (prevRoot->type == SymbolType::kRoot) { - errorListener_->addError( - ctx->getStart(), "A root relation was already defined."); - } else { - errorListener_->addError( - ctx->getStart(), "A relation named root was already defined."); - } + errorListener_->addError( + ctx->getStart(), "A relation named root was already defined."); return nullptr; } + + // First creation the relation information for this node. + auto relationData = std::make_shared(); + symbolTable_->defineSymbol( + kRootName, + Location(ctx), + SymbolType::kRelation, + SymbolType::kRoot, + relationData); + + // Now create the name related information. std::vector names; for (const auto& id : ctx->id()) { names.push_back(id->getText()); } symbolTable_->defineSymbol( - kRootName, Location(ctx), SymbolType::kRoot, SourceType::kUnknown, names); + kRootNames, + Location(ctx), + SymbolType::kRoot, + SourceType::kUnknown, + names); return nullptr; } @@ -251,13 +330,30 @@ std::any SubstraitPlanVisitor::visitRelation_type( return RelationType::kUnknown; } +SourceType getSourceType(SubstraitPlanParser::Read_typeContext* ctx) { + if (dynamic_cast(ctx) != nullptr) { + return SourceType::kLocalFiles; + } else if ( + dynamic_cast(ctx) != nullptr) { + return SourceType::kVirtualTable; + } else if ( + dynamic_cast(ctx) != nullptr) { + return SourceType::kNamedTable; + } else if ( + dynamic_cast(ctx) != + nullptr) { + return SourceType::kExtensionTable; + } + return SourceType::kUnknown; +} + std::any SubstraitPlanVisitor::visitSource_definition( SubstraitPlanParser::Source_definitionContext* ctx) { symbolTable_->defineSymbol( ctx->read_type()->children[1]->getText(), Location(ctx), SymbolType::kSource, - defaultResult(), + getSourceType(ctx->read_type()), defaultResult()); return visitChildren(ctx); } @@ -310,19 +406,6 @@ std::any SubstraitPlanVisitor::visitExpressionConstant( std::any SubstraitPlanVisitor::visitExpressionColumn( SubstraitPlanParser::ExpressionColumnContext* ctx) { - auto relationData = - ANY_CAST(std::shared_ptr, currentRelationScope_->blob); - std::string column_name = ctx->column_name()->getText(); - auto symbol = symbolTable_->lookupSymbolByName(column_name); - if (symbol == nullptr) { - symbol = symbolTable_->defineSymbol( - column_name, - Location(ctx), - SymbolType::kField, - std::nullopt, - std::nullopt); - relationData->fieldReferences.push_back(symbol); - } return visitChildren(ctx); } @@ -399,12 +482,36 @@ std::any SubstraitPlanVisitor::visitRelationJoinType( std::any SubstraitPlanVisitor::visitFile_location( SubstraitPlanParser::File_locationContext* ctx) { + auto symbol = + symbolTable_->lookupSymbolByName(ctx->parent->parent->getText()); + auto item = ANY_CAST( + std::shared_ptr<::substrait::proto::ReadRel_LocalFiles_FileOrFiles>, + symbol->blob); + if (ctx->URI_FILE() != nullptr) { + item->set_uri_file(parseString(ctx->STRING()->getText())); + } + // symbol->blob.swap(item); return visitChildren(ctx); } std::any SubstraitPlanVisitor::visitFile_detail( SubstraitPlanParser::File_detailContext* ctx) { - return visitChildren(ctx); + auto symbol = symbolTable_->lookupSymbolByName(ctx->parent->getText()); + auto item = ANY_CAST( + std::shared_ptr<::substrait::proto::ReadRel_LocalFiles_FileOrFiles>, + symbol->blob); + if (ctx->PARTITION_INDEX() != nullptr) { + item->set_partition_index(parseUnsignedInteger(ctx->NUMBER()->getText())); + } else if (ctx->START() != nullptr) { + item->set_start(parseUnsignedInteger(ctx->NUMBER()->getText())); + } else if (ctx->LENGTH() != nullptr) { + item->set_length(parseUnsignedInteger(ctx->NUMBER()->getText())); + } else if (ctx->ORC() != nullptr) { + item->mutable_orc(); + } else { + return visitChildren(ctx); + } + return defaultResult(); } std::any SubstraitPlanVisitor::visitFile( @@ -415,37 +522,36 @@ std::any SubstraitPlanVisitor::visitFile( std::any SubstraitPlanVisitor::visitLocal_files_detail( SubstraitPlanParser::Local_files_detailContext* ctx) { for (const auto& f : ctx->file()) { + ::substrait::proto::ReadRel_LocalFiles_FileOrFiles item; symbolTable_->defineSymbol( - f->getText(), + f->getText(), // We use all the details to create a unique name. PARSER_LOCATION(ctx->parent->parent), // The source we belong to. SymbolType::kSourceDetail, defaultResult(), - defaultResult()); + std::make_shared<::substrait::proto::ReadRel_LocalFiles_FileOrFiles>( + item)); + visitFile(f); } return nullptr; } std::any SubstraitPlanVisitor::visitLocalFiles( SubstraitPlanParser::LocalFilesContext* ctx) { - // TODO -- Once we switch over to SourceData update our parent's subtype. return visitChildren(ctx); } std::any SubstraitPlanVisitor::visitVirtualTable( SubstraitPlanParser::VirtualTableContext* ctx) { - // TODO -- Once we switch over to SourceData update our parent's subtype. return visitChildren(ctx); } std::any SubstraitPlanVisitor::visitNamedTable( SubstraitPlanParser::NamedTableContext* ctx) { - // TODO -- Once we switch over to SourceData update our parent's subtype. return visitChildren(ctx); } std::any SubstraitPlanVisitor::visitExtensionTable( SubstraitPlanParser::ExtensionTableContext* ctx) { - // TODO -- Once we switch over to SourceData update our parent's subtype. return visitChildren(ctx); } diff --git a/src/substrait/textplan/parser/data/provided_sample1.splan b/src/substrait/textplan/parser/data/provided_sample1.splan index ed3a0dd3..0d887505 100644 --- a/src/substrait/textplan/parser/data/provided_sample1.splan +++ b/src/substrait/textplan/parser/data/provided_sample1.splan @@ -2,6 +2,11 @@ pipelines { read -> project -> root; } +read relation read { + base_schema schema; + source named; +} + project relation project { expression r_regionkey; expression r_name; @@ -11,11 +16,6 @@ project relation project { expression concat(r_name, r_name); } -read relation read { - base_schema schema; - source named; -} - schema schema { r_regionkey i32; r_name string; diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 index 1ad91127..355a38c7 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanLexer.g4 @@ -26,6 +26,7 @@ tokens { SPACES } EXTENSION_SPACE: 'EXTENSION_SPACE' -> mode(EXTENSIONS); FUNCTION: 'FUNCTION'; AS: 'AS'; +NAMED: 'NAMED'; SCHEMA: 'SCHEMA'; RELATION: 'RELATION'; PIPELINES: 'PIPELINES'; @@ -43,6 +44,7 @@ SORT: 'SORT'; BY: 'BY'; COUNT: 'COUNT'; TYPE: 'TYPE'; +EMIT: 'EMIT'; VIRTUAL_TABLE: 'VIRTUAL_TABLE'; LOCAL_FILES: 'LOCAL_FILES'; diff --git a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 index 145d7262..3613ce8a 100644 --- a/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 +++ b/src/substrait/textplan/parser/grammar/SubstraitPlanParser.g4 @@ -64,28 +64,29 @@ relation_filter_behavior // TODO -- Can the type be determined automatically from the function definition? // TODO -- Consider moving the run phase to an optional third detail line. measure_detail - : MEASURE expression (ARROW literal_complex_type)? (ATSIGN id)? SEMICOLON + : MEASURE expression (ARROW literal_complex_type)? (ATSIGN id)? (NAMED id)? SEMICOLON | FILTER expression SEMICOLON | INVOCATION id SEMICOLON | sort_field ; relation_detail - : COMMON SEMICOLON # relationCommon - | BASE_SCHEMA id SEMICOLON # relationUsesSchema - | relation_filter_behavior? FILTER expression SEMICOLON # relationFilter - | EXPRESSION expression SEMICOLON (AS id)? # relationExpression - | ADVANCED_EXTENSION SEMICOLON # relationAdvancedExtension - | source_reference SEMICOLON # relationSourceReference - | GROUPING expression SEMICOLON # relationGrouping - | MEASURE LEFTBRACE measure_detail* RIGHTBRACE # relationMeasure - | sort_field # relationSort - | COUNT NUMBER SEMICOLON # relationCount - | TYPE id SEMICOLON # relationJoinType + : COMMON SEMICOLON # relationCommon + | BASE_SCHEMA id SEMICOLON # relationUsesSchema + | relation_filter_behavior? FILTER expression SEMICOLON # relationFilter + | EXPRESSION expression (NAMED id)? SEMICOLON # relationExpression + | ADVANCED_EXTENSION SEMICOLON # relationAdvancedExtension + | source_reference SEMICOLON # relationSourceReference + | GROUPING expression SEMICOLON # relationGrouping + | MEASURE LEFTBRACE measure_detail* RIGHTBRACE # relationMeasure + | sort_field # relationSort + | COUNT NUMBER SEMICOLON # relationCount + | TYPE id SEMICOLON # relationJoinType + | EMIT column_name SEMICOLON # relationEmit ; expression - : id LEFTPAREN (expression COMMA?)* RIGHTPAREN # expressionFunctionUse + : id LEFTPAREN (expression COMMA?)* RIGHTPAREN (ARROW literal_complex_type)? # expressionFunctionUse | constant # expressionConstant | column_name # expressionColumn | expression AS literal_complex_type # expressionCast @@ -131,7 +132,7 @@ struct_literal ; column_name - : id + : (id PERIOD)? id ; source_reference @@ -172,7 +173,7 @@ schema_definition ; schema_item - : id literal_complex_type SEMICOLON + : id literal_complex_type (NAMED id)? SEMICOLON ; source_definition @@ -206,11 +207,11 @@ signature : id ; -// List keywords here to make them not reserved. id : simple_id (UNDERSCORE+ simple_id)* ; +// List keywords here to make them not reserved. simple_id : IDENTIFIER | FILTER @@ -223,4 +224,6 @@ simple_id | GROUPING | COUNT | TYPE + | EMIT + | NAMED ; diff --git a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp index e3fc2f68..307835f5 100644 --- a/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp +++ b/src/substrait/textplan/parser/tests/TextPlanParserTest.cpp @@ -200,37 +200,57 @@ std::vector getTestCases() { }, { "test5-project-relation", - R"(extension_space blah.yaml { + R"(pipelines { + read -> myproject -> root; + } + + extension_space blah.yaml { function add:i8 as add; function subtract:i8 AS subtract; function concat:str AS concat; } + read relation read { + base_schema schema; + } + project relation myproject { expression r_regionkey; expression r_name; expression r_comment; - expression add(r_regionkey, 1_i8); - expression subtract(r_regionkey, 1_i8); - expression concat(r_name, r_name); + expression add(r_regionkey, 1_i8)->i8; + expression subtract(r_regionkey, 1_i8)->i8; + expression concat(r_name, r_name)->str; + } + + schema schema { + r_regionkey i32; + r_name string; + r_comment string?; })", AllOf( - HasSymbolsWithTypes({"myproject"}, {SymbolType::kRelation}), + HasSymbolsWithTypes( + {"read", "myproject", "root"}, {SymbolType::kRelation}), + HasErrors({}), WhenSerialized(EqSquashingWhitespace( - R"(project relation myproject { - expression r_regionkey; - expression r_name; - expression r_comment; - expression add(r_regionkey, 1_i8); - expression subtract(r_regionkey, 1_i8); - expression concat(r_name, r_name); - } + R"(pipelines { + read -> myproject -> root; + } + + project relation myproject { + expression schema.r_regionkey; + expression schema.r_name; + expression schema.r_comment; + expression add(schema.r_regionkey, 1_i8)->i8; + expression subtract(schema.r_regionkey, 1_i8)->i8; + expression concat(schema.r_name, schema.r_name)->string; + } - extension_space blah.yaml { - function add:i8 as add; - function concat:str as concat; - function subtract:i8 as subtract; - })")), + extension_space blah.yaml { + function add:i8 as add; + function concat:str as concat; + function subtract:i8 as subtract; + })")), AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(extension_uris { extension_uri_anchor: 1 uri: "blah.yaml" @@ -250,6 +270,19 @@ std::vector getTestCases() { name: "concat:str" } } relations { root { input { project { + common { direct { } } + input { + read { common { direct { } } + base_schema { + names: "r_regionkey" + names: "r_name" + names: "r_comment" + struct { types { i32 { + nullability: NULLABILITY_REQUIRED } } + types { string { nullability: NULLABILITY_REQUIRED } } + types { string { nullability: NULLABILITY_NULLABLE } } + nullability: NULLABILITY_REQUIRED } } } + } expressions { selection { direct_reference { @@ -257,6 +290,7 @@ std::vector getTestCases() { field: 0 } } + root_reference: { } } } expressions { @@ -265,7 +299,7 @@ std::vector getTestCases() { struct_field { field: 1 } - } + } root_reference: { } } } expressions { @@ -274,22 +308,28 @@ std::vector getTestCases() { struct_field { field: 2 } - } + } root_reference: { } } } expressions { scalar_function { function_reference: 0 arguments { value { selection { - direct_reference { struct_field { } } } } } - arguments { value { literal { i8: 1 } } } } } + direct_reference { struct_field { } } root_reference: { } } } } + arguments { value { literal { i8: 1 } } } + output_type { + i8 { nullability: NULLABILITY_REQUIRED} } } } expressions { scalar_function { function_reference: 1 arguments { value { selection { - direct_reference { struct_field { } } } } } - arguments { value { literal { i8: 1 } } } } } + direct_reference { struct_field { } } root_reference: { } } } } + arguments { value { literal { i8: 1 } } } + output_type { + i8 { nullability: NULLABILITY_REQUIRED } } } } expressions { scalar_function { function_reference: 2 arguments { value { selection { - direct_reference { struct_field { field: 1 } } } } } + direct_reference { struct_field { field: 1 } } root_reference: { } } } } arguments { value { selection { direct_reference { - struct_field { field: 1 } } } } } } } + struct_field { field: 1 } } root_reference: { } } } } + output_type { + string { nullability: NULLABILITY_REQUIRED } } } } } } } })"))), }, { @@ -535,6 +575,7 @@ std::vector getTestCases() { })", AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { + common { direct { } } expressions { literal { timestamp: 946684800 } } expressions { literal { timestamp: 946684800 } } expressions { literal { date: 18616 } } @@ -590,6 +631,7 @@ std::vector getTestCases() { })", AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { + common { direct { } } expressions { literal { list { values { string: "a" } values { string: "b" } @@ -623,20 +665,17 @@ std::vector getTestCases() { R"(project relation literalexamples { expression null_list; expression null_list>; - expression null_list>?; expression null_list?>; })", AsBinaryPlan((EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { + common { direct { } } expressions { literal { null { list { type { string { nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } expressions { literal { null { list { type { list { type { string { nullability: NULLABILITY_NULLABLE } } nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } - expressions { literal { null { list { type { - list { type { string { nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_REQUIRED } } - nullability: NULLABILITY_NULLABLE } } } } expressions { literal { null { list { type { list { type { string { nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_REQUIRED } } nullability: NULLABILITY_NULLABLE } } } } @@ -650,6 +689,7 @@ std::vector getTestCases() { })", AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { + common { direct { } } expressions { literal { map { key_values { key { i16: 42 } value { string: "life" } } key_values { key { i16: 32 } value { string: "everything" } } @@ -676,6 +716,7 @@ std::vector getTestCases() { })", AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { + common { direct { } } expressions { literal { struct { fields { string: "a" } fields { struct { @@ -776,7 +817,7 @@ std::vector getTestCases() { expression {123}_map; })", HasErrors({ - "2:38 → extraneous input '?' expecting ';'", + "2:38 → extraneous input '?' expecting {'NAMED', ';'}", "3:26 → Unable to recognize requested type.", "4:26 → Unable to recognize requested type.", "5:26 → Maps require both a key and a value type.", @@ -801,6 +842,7 @@ std::vector getTestCases() { HasErrors({}), AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(relations { root { input { project { + common { direct { } } expressions { cast { type { i32 { nullability: NULLABILITY_REQUIRED } } input { literal { i8: 123 } } } } @@ -836,8 +878,14 @@ std::vector getTestCases() { "test13-bad-functions", R"(extension_space blah.yaml { function sum: as sum; + function sum as sum; + function sum; })", - HasErrors({"2:12 → Functions should have an associated type."}), + HasErrors({ + "3:25 → missing ':' at 'as'", + "4:24 → missing ':' at ';'", + "2:12 → Functions should have an associated type.", + }), }, { "test14-three-node-pipeline-with-fields", @@ -917,10 +965,13 @@ std::vector getTestCases() { root: { input: { join: { + common { direct { } } left: { join: { + common { direct { } } left: { read: { + common { direct { } } base_schema { names: "order_id" names: "product_id" @@ -928,19 +979,23 @@ std::vector getTestCases() { struct { types { i32 { nullability: NULLABILITY_REQUIRED } } types { i32 { nullability: NULLABILITY_REQUIRED } } - types { i64 { nullability: NULLABILITY_REQUIRED } } } + types { i64 { nullability: NULLABILITY_REQUIRED } } + nullability: NULLABILITY_REQUIRED } } named_table { names: "#1" } } } right: { read: { + common { direct { } } base_schema { names: "product_id" names: "cost" struct { types { i32 { nullability: NULLABILITY_REQUIRED } } - types { fp32 { nullability: NULLABILITY_REQUIRED } } } + types { fp32 { nullability: NULLABILITY_REQUIRED } } + nullability: NULLABILITY_REQUIRED + } } named_table { names: "#2" } } @@ -952,6 +1007,7 @@ std::vector getTestCases() { field: 1 } } + root_reference: { } } } post_join_filter: { @@ -961,18 +1017,21 @@ std::vector getTestCases() { field: 2 } } + root_reference: { } } } } } right: { read: { + common { direct { } } base_schema { names: "company" names: "order_id" struct { types { string { nullability: NULLABILITY_REQUIRED } } types { i32 { nullability: NULLABILITY_REQUIRED } } + nullability: NULLABILITY_REQUIRED } } named_table { names: "#3" } @@ -985,6 +1044,7 @@ std::vector getTestCases() { field: 6 } } + root_reference: { } } } } @@ -999,9 +1059,9 @@ std::vector getTestCases() { // TODO -- Replace this error message with something user-friendly. HasErrors({ "1:0 → extraneous input 'relation' expecting {, " - "'EXTENSION_SPACE', 'SCHEMA', 'PIPELINES', 'FILTER', " - "'GROUPING', 'MEASURE', 'SORT', 'COUNT', 'TYPE', 'SOURCE', " - "'ROOT', 'NULL', IDENTIFIER}", + "'EXTENSION_SPACE', 'NAMED', 'SCHEMA', 'PIPELINES', 'FILTER', " + "'GROUPING', 'MEASURE', 'SORT', 'COUNT', 'TYPE', 'EMIT', " + "'SOURCE', 'ROOT', 'NULL', IDENTIFIER}", "1:24 → mismatched input '{' expecting 'RELATION'", "1:9 → Unrecognized relation type: notyperelation", }), @@ -1015,7 +1075,7 @@ std::vector getTestCases() { { "test17-pipelines-with-relations", R"(pipelines { - root -> project -> read; + read -> project -> root; } read relation read { @@ -1025,6 +1085,12 @@ std::vector getTestCases() { project relation project { expression r_regionkey; + } + + schema schemaone { + r_regionkey i32; + r_name string; + r_comment string?; })", AllOf( HasSymbolsWithTypes( @@ -1035,7 +1101,7 @@ std::vector getTestCases() { { "test18-root-and-read", R"(pipelines { - root -> read; + read -> root; } read relation read { @@ -1050,9 +1116,65 @@ std::vector getTestCases() { })", AsBinaryPlan(EqualsProto<::substrait::proto::Plan>( R"(relations: { - root { names: "apple" } + root { names: "apple" input { read { common { direct { } } } } } })")), }, + { + "test19-emit", + R"(pipelines { + read -> project -> root; + } + + read relation read { + base_schema schemaone; + } + + project relation project { + expression r_region_key; + + emit r_region_key; + } + + schema schemaone { + r_region_key i32; + })", + AllOf( + HasSymbolsWithTypes( + {"read", "project", "root"}, {SymbolType::kRelation}), + HasErrors({}), + AsBinaryPlan(EqualsProto<::substrait::proto::Plan>(R"( + relations { + root { + input { + project { + common { + emit { + output_mapping: 1 + } + } + input { + read: { + common { direct { } } + base_schema { + names: "r_region_key" + struct { types { i32 { + nullability: NULLABILITY_REQUIRED } } + nullability: NULLABILITY_REQUIRED } } } + } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference: { } + } + } + } + } + } + })"))), + }, }; return cases; } diff --git a/src/substrait/textplan/tests/RoundtripTest.cpp b/src/substrait/textplan/tests/RoundtripTest.cpp index 4cf2e3ea..3b22071c 100644 --- a/src/substrait/textplan/tests/RoundtripTest.cpp +++ b/src/substrait/textplan/tests/RoundtripTest.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -25,6 +26,11 @@ using ::testing::AllOf; namespace io::substrait::textplan { namespace { +bool endsWith(const std::string& haystack, const std::string& needle) { + return haystack.size() > needle.size() && + haystack.substr(haystack.size() - needle.size()) == needle; +} + std::string addLineNumbers(const std::string& text) { std::stringstream input{text}; std::stringstream result; @@ -53,7 +59,8 @@ std::vector getTestCases() { for (auto const& dirEntry : std::filesystem::recursive_directory_iterator{testDataPath}) { std::string pathName = dirEntry.path(); - if (pathName.substr(pathName.length() - 5) == ".json") { + if (endsWith(pathName, ".json") && + !endsWith(pathName, "q6_first_stage.json")) { filenames.push_back(pathName); } } @@ -71,15 +78,16 @@ TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { auto plan = *planOrErrors; auto textResult = parseBinaryPlan(plan); - auto symbols = textResult.getSymbolTable().getSymbols(); + auto textSymbols = textResult.getSymbolTable().getSymbols(); std::string outputText = SymbolTablePrinter::outputToText(textResult.getSymbolTable()); ASSERT_THAT(textResult, AllOf(ParsesOk(), HasErrors({}))) << std::endl - << "Intermediate result:" << std::endl - << addLineNumbers(outputText) << std::endl; + << "Initial result:" << std::endl + << addLineNumbers(outputText) << std::endl + << textResult.getSymbolTable().toDebugString() << std::endl; auto stream = loadTextString(outputText); auto result = parseStream(stream); @@ -93,7 +101,8 @@ TEST_P(RoundTripBinaryToTextFixture, RoundTrip) { ParsesOk(), HasErrors({}), AsBinaryPlan(EqualsProto(normalizedPlan)))) << std::endl << "Intermediate result:" << std::endl - << addLineNumbers(outputText); + << addLineNumbers(outputText) << std::endl + << result.getSymbolTable().toDebugString() << std::endl; } INSTANTIATE_TEST_SUITE_P(