Skip to content

Commit

Permalink
feat: add output field mapping (#78)
Browse files Browse the repository at this point in the history
features:
* updated the symbol table so that multiple symbols can share the same
location
* the root relation now contains both a relation symbol and a name
structure symbol
   * moved commonly string search functions into a separate file
   * added EMIT
   * added aliases
   * added join types

fixes:
   * fixed root names sort order
  • Loading branch information
EpsilonPrime committed Jul 19, 2023
1 parent f4e9fd1 commit 0a71a2f
Show file tree
Hide file tree
Showing 29 changed files with 1,596 additions and 309 deletions.
2 changes: 2 additions & 0 deletions src/substrait/textplan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ add_library(
symbol_table
Location.cpp
Location.h
StringManipulation.cpp
StringManipulation.h
SymbolTable.cpp
SymbolTable.h
SymbolTablePrinter.cpp
Expand Down
19 changes: 19 additions & 0 deletions src/substrait/textplan/StringManipulation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "StringManipulation.h"

namespace io::substrait::textplan {

// Yields true if the string 'haystack' starts with the string 'needle'.
bool startsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(0, needle.size()) == needle;
}

// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle) {
return haystack.size() > needle.size() &&
haystack.substr(haystack.size() - needle.size(), needle.size()) == needle;
}

} // namespace io::substrait::textplan
15 changes: 15 additions & 0 deletions src/substrait/textplan/StringManipulation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <string_view>

namespace io::substrait::textplan {

// Yields true if the string 'haystack' starts with the string 'needle'.
bool startsWith(std::string_view haystack, std::string_view needle);

// Returns true if the string 'haystack' ends with the string 'needle'.
bool endsWith(std::string_view haystack, std::string_view needle);

} // namespace io::substrait::textplan
17 changes: 17 additions & 0 deletions src/substrait/textplan/StructuredSymbolData.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ struct RelationData {
// Column name for each field known to this relation (in field order). Used
// to determine what fields are coming in as well and fields are going out.
std::vector<const SymbolInfo*> fieldReferences;

// Each field reference here was generated within the current relation.
std::vector<const SymbolInfo*> generatedFieldReferences;

// Local aliases for field references in this relation. Used to replace the
// normal form symbols would take for this relation's use only. (Later
// references to the symbol would use the alias.)
std::map<size_t, std::string> generatedFieldReferenceAlternativeExpression;

// If populated, supersedes the combination of fieldReferences and
// generatedFieldReferences for the field symbols exposed by this relation.
std::vector<const SymbolInfo*> outputFieldReferences;

// Contains the field reference names seen so far while processing this
// relation along with the id of the first occurrence. Used to detect when
// fully qualified references are necessary.
std::map<std::string, size_t> seenFieldReferenceNames;
};

// Used by Schema symbols to keep track of assigned values.
Expand Down
101 changes: 96 additions & 5 deletions src/substrait/textplan/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
#include <string>

#include "substrait/common/Exceptions.h"
#include "substrait/textplan/Any.h"
#include "substrait/textplan/Location.h"
#include "substrait/textplan/StructuredSymbolData.h"

namespace io::substrait::textplan {

Expand Down Expand Up @@ -123,6 +125,12 @@ void SymbolTable::updateLocation(
symbolsByLocation_.insert(std::make_pair(location, index));
}

void SymbolTable::addAlias(const std::string& alias, const SymbolInfo* symbol) {
auto index = findSymbolIndex(*symbol);
symbols_[index]->alias = alias;
symbolsByName_.insert(std::make_pair(alias, index));
}

const SymbolInfo* SymbolTable::lookupSymbolByName(
const std::string& name) const {
auto itr = symbolsByName_.find(name);
Expand All @@ -132,13 +140,33 @@ const SymbolInfo* SymbolTable::lookupSymbolByName(
return symbols_[itr->second].get();
}

const SymbolInfo* SymbolTable::lookupSymbolByLocation(
std::vector<const SymbolInfo*> SymbolTable::lookupSymbolsByLocation(
const Location& location) const {
auto itr = symbolsByLocation_.find(location);
if (itr == symbolsByLocation_.end()) {
return nullptr;
std::vector<const SymbolInfo*> symbols;
auto [begin, end] = symbolsByLocation_.equal_range(location);
for (auto itr = begin; itr != end; ++itr) {
symbols.push_back(symbols_[itr->second].get());
}
return symbols_[itr->second].get();
return symbols;
}

const SymbolInfo* SymbolTable::lookupSymbolByLocationAndType(
const Location& location,
SymbolType type) const {
return lookupSymbolByLocationAndTypes(location, {type});
}

const SymbolInfo* SymbolTable::lookupSymbolByLocationAndTypes(
const Location& location,
std::unordered_set<SymbolType> types) const {
auto [begin, end] = symbolsByLocation_.equal_range(location);
for (auto itr = begin; itr != end; ++itr) {
auto symbol = symbols_[itr->second].get();
if (types.find(symbol->type) != types.end()) {
return symbol;
}
}
return nullptr;
}

const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type)
Expand All @@ -162,4 +190,67 @@ SymbolTableIterator SymbolTable::end() const {
return {this, symbols_.size()};
}

std::string SymbolTable::toDebugString() const {
std::stringstream result;
bool textAlreadyWritten = false;
int32_t relationCount = 0;
for (const auto& symbol : symbols_) {
if (symbol->type != SymbolType::kRelation) {
continue;
}
auto relationData = ANY_CAST(std::shared_ptr<RelationData>, symbol->blob);
result << std::left << std::setw(4) << relationCount++;
result << std::left << std::setw(20) << symbol->name << std::endl;

int32_t fieldNum = 0;
for (const auto& field : relationData->fieldReferences) {
result << " " << std::setw(4) << fieldNum++ << " ";
if (field->schema != nullptr) {
result << field->schema->name << ".";
}
result << field->name;
if (!field->alias.empty()) {
result << " " << field->alias;
}
result << std::endl;
}

for (const auto& field : relationData->generatedFieldReferences) {
result << " g" << std::setw(4) << fieldNum++ << " ";
if (field->schema != nullptr) {
result << field->schema->name << ".";
}
result << field->name;
if (relationData->generatedFieldReferenceAlternativeExpression.find(
fieldNum) !=
relationData->generatedFieldReferenceAlternativeExpression.end()) {
result << " "
<< relationData
->generatedFieldReferenceAlternativeExpression[fieldNum];
} else if (!field->alias.empty()) {
result << " " << field->alias;
}
result << std::endl;
}

int32_t outputFieldNum = 0;
for (const auto& field : relationData->outputFieldReferences) {
result << " o" << std::setw(4) << outputFieldNum++ << " ";
if (field->schema != nullptr) {
result << field->schema->name << ".";
}
result << field->name;
if (!field->alias.empty()) {
result << " " << field->alias;
}
result << std::endl;
}
textAlreadyWritten = true;
}
if (textAlreadyWritten) {
result << std::endl;
}
return result.str();
}

} // namespace io::substrait::textplan
25 changes: 21 additions & 4 deletions src/substrait/textplan/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
#pragma once

#include <any>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

Expand All @@ -19,13 +20,14 @@ enum class SymbolType {
kFunction = 1,
kPlanRelation = 2,
kRelation = 3,
kRelationDetail = 4,
kSchema = 5,
kSchemaColumn = 6,
kSource = 7,
kSourceDetail = 8,
kField = 9,
kRoot = 10,
kTable = 11,
kMeasure = 12,

kUnknown = -1,
};
Expand Down Expand Up @@ -75,6 +77,8 @@ const std::string& symbolTypeName(SymbolType type);

struct SymbolInfo {
std::string name;
std::string alias{}; // If present, use this instead of name.
const SymbolInfo* schema{nullptr}; // The related schema symbol if present.
Location location;
SymbolType type;
std::any subtype;
Expand Down Expand Up @@ -144,12 +148,23 @@ class SymbolTable {
// Changes the location for a specified existing symbol.
void updateLocation(const SymbolInfo& symbol, const Location& location);

// Adds an alias to the given symbol.
void addAlias(const std::string& alias, const SymbolInfo* symbol);

[[nodiscard]] const SymbolInfo* lookupSymbolByName(
const std::string& name) const;

[[nodiscard]] const SymbolInfo* lookupSymbolByLocation(
[[nodiscard]] std::vector<const SymbolInfo*> lookupSymbolsByLocation(
const Location& location) const;

[[nodiscard]] const SymbolInfo* lookupSymbolByLocationAndType(
const Location& location,
SymbolType type) const;

[[nodiscard]] const SymbolInfo* lookupSymbolByLocationAndTypes(
const Location& location,
std::unordered_set<SymbolType> types) const;

[[nodiscard]] const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type)
const;

Expand Down Expand Up @@ -177,6 +192,8 @@ class SymbolTable {
return os;
}

[[nodiscard]] std::string toDebugString() const;

private:
// Returns the table size if the symbol is not found.
size_t findSymbolIndex(const SymbolInfo& symbol);
Expand All @@ -187,7 +204,7 @@ class SymbolTable {

std::vector<std::shared_ptr<SymbolInfo>> symbols_;
std::unordered_map<std::string, size_t> symbolsByName_;
std::unordered_map<Location, size_t> symbolsByLocation_;
std::multimap<Location, size_t> symbolsByLocation_;
};

} // namespace io::substrait::textplan
14 changes: 5 additions & 9 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ std::string outputSchemaSection(const SymbolTable& symbolTable) {
if (info.type != SymbolType::kSchema) {
continue;
}

if (info.blob.type() != typeid(const ::substrait::proto::NamedStruct*)) {
// TODO -- Implement schemas for text plans.
if (!info.blob.has_value()) {
continue;
}

Expand Down Expand Up @@ -241,10 +239,6 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) {
if (hasPreviousText) {
text << "\n";
}
if (info.subtype.type() != typeid(SourceType)) {
// TODO -- Implement sources for text plans.
continue;
}
auto subtype = ANY_CAST(SourceType, info.subtype);
switch (subtype) {
case SourceType::kNamedTable: {
Expand Down Expand Up @@ -300,6 +294,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
}
return spaceNames.at(a) < spaceNames.at(b);
};
// Sorted by name if we have one, otherwise by space id.
std::set<uint32_t, decltype(cmp)> usedSpaces(cmp);

// Look at the existing spaces.
Expand Down Expand Up @@ -352,8 +347,8 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (const auto& item : functionsToOutput) {
text << " function " << item.second << " as " << item.first << ";\n";
for (auto [shortName, canonicalName] : functionsToOutput) {
text << " function " << canonicalName << " as " << shortName << ";\n";
}
text << "}\n";
hasPreviousOutput = true;
Expand Down Expand Up @@ -446,6 +441,7 @@ void outputFunctionsToBinaryPlan(

} // namespace

// TODO -- Update so that errors occurring during printing are captured.
std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) {
std::stringstream text;
bool hasPreviousText = false;
Expand Down
Loading

0 comments on commit 0a71a2f

Please sign in to comment.