Skip to content

Commit

Permalink
feat: extension spaces, functions, aggregation relations to binary (#67)
Browse files Browse the repository at this point in the history
* Features
   * Extension spaces and functions are now emitted to the binary plan
* Details such as invocation, measures, and groupings are now implemented
for aggregation relations, as are sorts for sort relations
* Fixes
   * Type conversion to text standardized for schemas and extensions
* Changes lookupSymbolByLocation to return a pointer, matching
lookupSymbolByName.
  • Loading branch information
EpsilonPrime committed Jun 12, 2023
1 parent 98abe4b commit fee0829
Show file tree
Hide file tree
Showing 26 changed files with 1,220 additions and 608 deletions.
3 changes: 1 addition & 2 deletions src/substrait/textplan/ParseResult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

#include "substrait/textplan/ParseResult.h"

#include <iosfwd>
#include <iostream>
#include <sstream>

namespace io::substrait::textplan {

Expand Down
6 changes: 3 additions & 3 deletions src/substrait/textplan/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ const SymbolInfo* SymbolTable::lookupSymbolByName(
return symbols_[itr->second].get();
}

const SymbolInfo& SymbolTable::lookupSymbolByLocation(
const SymbolInfo* SymbolTable::lookupSymbolByLocation(
const Location& location) const {
auto itr = symbolsByLocation_.find(location);
if (itr == symbolsByLocation_.end()) {
return SymbolInfo::kUnknown;
return nullptr;
}
return *symbols_[itr->second];
return symbols_[itr->second].get();
}

const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type)
Expand Down
2 changes: 1 addition & 1 deletion src/substrait/textplan/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class SymbolTable {
[[nodiscard]] const SymbolInfo* lookupSymbolByName(
const std::string& name) const;

[[nodiscard]] const SymbolInfo& lookupSymbolByLocation(
[[nodiscard]] const SymbolInfo* lookupSymbolByLocation(
const Location& location) const;

[[nodiscard]] const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type)
Expand Down
164 changes: 87 additions & 77 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,83 +74,9 @@ void localFileToText(
}

std::string typeToText(const ::substrait::proto::Type& type) {
switch (type.kind_case()) {
case ::substrait::proto::Type::kBool:
if (type.bool_().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "bool?";
}
return "bool";
case ::substrait::proto::Type::kI8:
if (type.i8().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "i8?";
}
return "i8";
case ::substrait::proto::Type::kI16:
if (type.i16().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "i16?";
}
return "i16";
case ::substrait::proto::Type::kI32:
if (type.i32().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "i32?";
}
return "i32";
case ::substrait::proto::Type::kI64:
if (type.i64().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "i64?";
}
return "i64";
case ::substrait::proto::Type::kFp32:
if (type.fp32().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "fp32?";
}
return "fp32";
case ::substrait::proto::Type::kFp64:
if (type.fp64().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "fp64?";
}
return "fp64";
case ::substrait::proto::Type::kString:
if (type.string().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "string?";
}
return "string";
case ::substrait::proto::Type::kDecimal:
if (type.string().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "decimal?";
}
return "decimal";
case ::substrait::proto::Type::kVarchar:
if (type.varchar().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "varchar?";
}
return "varchar";
case ::substrait::proto::Type::kFixedChar:
if (type.fixed_char().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "fixedchar?";
}
return "fixedchar";
case ::substrait::proto::Type::kDate:
if (type.date().nullability() ==
::substrait::proto::Type::NULLABILITY_NULLABLE) {
return "date?";
}
return "date";
case ::substrait::proto::Type::KIND_NOT_SET:
default:
return "UNSUPPORTED_TYPE";
}
SymbolTable symbolTable;
PlanPrinterVisitor visitor(symbolTable);
return visitor.typeToText(type);
};

std::string relationToText(
Expand Down Expand Up @@ -386,6 +312,88 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
return text.str();
}

// Appends one extension URI entry to |plan| for every extension space symbol
// found in |symbolTable|, visiting the symbols in the table's iteration order.
// The entry's URI is the symbol's name and its anchor is the anchor reference
// stored in the symbol's ExtensionSpaceData blob.
void outputExtensionSpacesToBinaryPlan(
    const SymbolTable& symbolTable,
    ::substrait::proto::Plan* plan) {
  for (const SymbolInfo& symbol : symbolTable) {
    if (symbol.type == SymbolType::kExtensionSpace) {
      // Unpack the blob before touching the plan so that a failed cast does
      // not leave a half-filled entry behind.
      const auto spaceData =
          ANY_CAST(std::shared_ptr<ExtensionSpaceData>, symbol.blob);
      auto* declaration = plan->add_extension_uris();
      declaration->set_uri(symbol.name);
      declaration->set_extension_uri_anchor(spaceData->anchorReference);
    }
  }
}

void outputFunctionsToBinaryPlan(
const SymbolTable& symbolTable,
::substrait::proto::Plan* plan) {
std::map<uint32_t, std::string> spaceNames;
std::set<uint32_t> usedSpaces;

// Look at the existing spaces.
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kExtensionSpace) {
continue;
}

auto extensionData =
ANY_CAST(std::shared_ptr<ExtensionSpaceData>, info.blob);
spaceNames.insert(
std::make_pair(extensionData->anchorReference, info.name));
}

// Find any spaces that are used but undefined.
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kFunction) {
continue;
}

auto extension = ANY_CAST(std::shared_ptr<FunctionData>, info.blob);
if (extension->extensionUriReference.has_value()) {
usedSpaces.insert(extension->extensionUriReference.value());
}
}

// Output the extensions by space in the order they were encountered.
for (const uint32_t space : usedSpaces) {
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kFunction) {
continue;
}

auto functionData = ANY_CAST(std::shared_ptr<FunctionData>, info.blob);
if (functionData->extensionUriReference != space) {
continue;
}

auto func = plan->add_extensions()->mutable_extension_function();
func->set_function_anchor(functionData->anchor);
func->set_name(functionData->name);

if (spaceNames.find(space) != spaceNames.end()) {
func->set_extension_uri_reference(space);
}
}
}

for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kFunction) {
continue;
}

auto functionData = ANY_CAST(std::shared_ptr<FunctionData>, info.blob);
if (!functionData->extensionUriReference.has_value()) {
auto func = plan->add_extensions()->mutable_extension_function();
func->set_function_anchor(functionData->anchor);
func->set_name(functionData->name);
}
}
}

} // namespace

std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) {
Expand Down Expand Up @@ -631,6 +639,8 @@ void SymbolTablePrinter::addInputsToRelation(
::substrait::proto::Plan SymbolTablePrinter::outputToBinaryPlan(
const SymbolTable& symbolTable) {
::substrait::proto::Plan plan;
outputExtensionSpacesToBinaryPlan(symbolTable, &plan);
outputFunctionsToBinaryPlan(symbolTable, &plan);
for (const SymbolInfo& info : symbolTable) {
if (info.type != SymbolType::kRelation) {
continue;
Expand Down
9 changes: 7 additions & 2 deletions src/substrait/textplan/converter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ set(TEXTPLAN_SRCS
add_library(substrait_textplan_converter ${TEXTPLAN_SRCS})

target_link_libraries(
substrait_textplan_converter substrait_common substrait_expression
substrait_proto symbol_table error_listener)
substrait_textplan_converter
substrait_common
substrait_expression
substrait_proto
symbol_table
error_listener
date::date)

if(${SUBSTRAIT_CPP_BUILD_TESTING})
add_subdirectory(tests)
Expand Down
6 changes: 3 additions & 3 deletions src/substrait/textplan/converter/InitialPlanProtoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,12 @@ std::any InitialPlanProtoVisitor::visitNamedStruct(
void InitialPlanProtoVisitor::addFieldsToRelation(
const std::shared_ptr<RelationData>& relationData,
const ::substrait::proto::Rel& relation) {
auto symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation));
if (symbol == SymbolInfo::kUnknown || symbol.type != SymbolType::kRelation) {
auto* symbol = symbolTable_->lookupSymbolByLocation(PROTO_LOCATION(relation));
if (symbol == nullptr || symbol->type != SymbolType::kRelation) {
return;
}
auto symbolRelationData =
ANY_CAST(std::shared_ptr<RelationData>, symbol.blob);
ANY_CAST(std::shared_ptr<RelationData>, symbol->blob);
for (const auto& field : symbolRelationData->fieldReferences) {
relationData->fieldReferences.push_back(field);
}
Expand Down
Loading

0 comments on commit fee0829

Please sign in to comment.