Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add output field mapping #78

Merged
merged 10 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/substrait/textplan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ add_library(
symbol_table
Location.cpp
Location.h
StringManipulation.cpp
StringManipulation.h
SymbolTable.cpp
SymbolTable.h
SymbolTablePrinter.cpp
Expand Down
19 changes: 19 additions & 0 deletions src/substrait/textplan/StringManipulation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* SPDX-License-Identifier: Apache-2.0 */

#include "StringManipulation.h"

namespace io::substrait::textplan {

// Returns true if the string 'haystack' starts with the string 'needle'.
// A string is considered to start with itself, and an empty 'needle' matches
// every 'haystack' (including an empty one).
bool startsWith(std::string_view haystack, std::string_view needle) {
  // '>=' rather than '>' so that a haystack exactly equal to the needle
  // still counts as a match.
  return haystack.size() >= needle.size() &&
      haystack.substr(0, needle.size()) == needle;
}

// Returns true if the string 'haystack' ends with the string 'needle'.
// A string is considered to end with itself, and an empty 'needle' matches
// every 'haystack' (including an empty one).
bool endsWith(std::string_view haystack, std::string_view needle) {
  // '>=' rather than '>' so that a haystack exactly equal to the needle
  // still counts as a match.  The size check also guarantees that the
  // subtraction below cannot underflow.
  return haystack.size() >= needle.size() &&
      haystack.substr(haystack.size() - needle.size(), needle.size()) == needle;
}

} // namespace io::substrait::textplan
15 changes: 15 additions & 0 deletions src/substrait/textplan/StringManipulation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/* SPDX-License-Identifier: Apache-2.0 */

#pragma once

#include <string_view>

namespace io::substrait::textplan {

// Returns true if the string 'haystack' starts with the string 'needle'.
// Both arguments are non-owning views taken by value.
bool startsWith(std::string_view haystack, std::string_view needle);

// Returns true if the string 'haystack' ends with the string 'needle'.
// Both arguments are non-owning views taken by value.
bool endsWith(std::string_view haystack, std::string_view needle);

} // namespace io::substrait::textplan
14 changes: 14 additions & 0 deletions src/substrait/textplan/StructuredSymbolData.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ struct RelationData {
// Column name for each field known to this relation (in field order). Used
// to determine which fields are coming in as well as which fields are going
// out.
std::vector<const SymbolInfo*> fieldReferences;

// Each field reference here was generated within the current relation.
std::vector<const SymbolInfo*> generatedFieldReferences;

// Local aliases for field references in this relation.
std::map<size_t, std::string> generatedFieldReferenceAlternativeExpression;
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved

// If populated, supersedes the combination of fieldReferences and
// generatedFieldReferences for the field symbols exposed by this relation.
std::vector<const SymbolInfo*> outputFieldReferences;

// Contains the field reference names seen so far along with the id of the
// first occurrence.
EpsilonPrime marked this conversation as resolved.
Show resolved Hide resolved
std::map<std::string, size_t> seenFieldReferenceNames;
};

// Used by Schema symbols to keep track of assigned values.
Expand Down
101 changes: 96 additions & 5 deletions src/substrait/textplan/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
#include <string>

#include "substrait/common/Exceptions.h"
#include "substrait/textplan/Any.h"
#include "substrait/textplan/Location.h"
#include "substrait/textplan/StructuredSymbolData.h"

namespace io::substrait::textplan {

Expand Down Expand Up @@ -123,6 +125,12 @@ void SymbolTable::updateLocation(
symbolsByLocation_.insert(std::make_pair(location, index));
}

// Records 'alias' on the given symbol and makes the symbol reachable through
// the by-name index under that alias.
//
// The call is a no-op when 'symbol' is null or is not part of this table.  If
// a by-name entry for 'alias' already exists it is left untouched (insert
// does not overwrite), although the symbol's alias field is still updated.
void SymbolTable::addAlias(const std::string& alias, const SymbolInfo* symbol) {
  if (symbol == nullptr) {
    return;
  }
  auto index = findSymbolIndex(*symbol);
  // findSymbolIndex reports "not found" by returning the table size; using
  // that value as an index into symbols_ would be out of bounds.
  if (index >= symbols_.size()) {
    return;
  }
  symbols_[index]->alias = alias;
  symbolsByName_.insert(std::make_pair(alias, index));
}

const SymbolInfo* SymbolTable::lookupSymbolByName(
const std::string& name) const {
auto itr = symbolsByName_.find(name);
Expand All @@ -132,13 +140,33 @@ const SymbolInfo* SymbolTable::lookupSymbolByName(
return symbols_[itr->second].get();
}

const SymbolInfo* SymbolTable::lookupSymbolByLocation(
std::vector<const SymbolInfo*> SymbolTable::lookupSymbolsByLocation(
const Location& location) const {
auto itr = symbolsByLocation_.find(location);
if (itr == symbolsByLocation_.end()) {
return nullptr;
std::vector<const SymbolInfo*> symbols;
auto [begin, end] = symbolsByLocation_.equal_range(location);
for (auto itr = begin; itr != end; ++itr) {
symbols.push_back(symbols_[itr->second].get());
}
return symbols_[itr->second].get();
return symbols;
}

// Single-type convenience form of lookupSymbolByLocationAndTypes: returns
// whatever that lookup yields for a set containing only 'type'.
const SymbolInfo* SymbolTable::lookupSymbolByLocationAndType(
    const Location& location,
    SymbolType type) const {
  const std::unordered_set<SymbolType> onlyThisType{type};
  return lookupSymbolByLocationAndTypes(location, onlyThisType);
}

// Scans the symbols registered at 'location' (in the order the location index
// yields them) and returns the first one whose type is a member of 'types'.
// Returns nullptr when no symbol at that location has one of those types.
const SymbolInfo* SymbolTable::lookupSymbolByLocationAndTypes(
    const Location& location,
    std::unordered_set<SymbolType> types) const {
  const auto candidates = symbolsByLocation_.equal_range(location);
  for (auto entry = candidates.first; entry != candidates.second; ++entry) {
    const SymbolInfo* candidate = symbols_[entry->second].get();
    if (types.count(candidate->type) > 0) {
      return candidate;
    }
  }
  return nullptr;
}

const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type)
Expand All @@ -162,4 +190,67 @@ SymbolTableIterator SymbolTable::end() const {
return {this, symbols_.size()};
}

std::string SymbolTable::toDebugString() const {
std::stringstream result;
bool textAlreadyWritten = false;
int32_t relationCount = 0;
for (const auto& symbol : symbols_) {
if (symbol->type != SymbolType::kRelation) {
continue;
}
auto relationData = ANY_CAST(std::shared_ptr<RelationData>, symbol->blob);
result << std::left << std::setw(4) << relationCount++;
result << std::left << std::setw(20) << symbol->name << std::endl;

int32_t fieldNum = 0;
for (const auto& field : relationData->fieldReferences) {
result << " " << std::setw(4) << fieldNum++ << " ";
if (field->schema != nullptr) {
result << field->schema->name << ".";
}
result << field->name;
if (!field->alias.empty()) {
result << " " << field->alias;
}
result << std::endl;
}

for (const auto& field : relationData->generatedFieldReferences) {
result << " g" << std::setw(4) << fieldNum++ << " ";
if (field->schema != nullptr) {
result << field->schema->name << ".";
}
result << field->name;
if (relationData->generatedFieldReferenceAlternativeExpression.find(
fieldNum) !=
relationData->generatedFieldReferenceAlternativeExpression.end()) {
result << " "
<< relationData
->generatedFieldReferenceAlternativeExpression[fieldNum];
} else if (!field->alias.empty()) {
result << " " << field->alias;
}
result << std::endl;
}

int32_t outputFieldNum = 0;
for (const auto& field : relationData->outputFieldReferences) {
result << " o" << std::setw(4) << outputFieldNum++ << " ";
if (field->schema != nullptr) {
result << field->schema->name << ".";
}
result << field->name;
if (!field->alias.empty()) {
result << " " << field->alias;
}
result << std::endl;
}
textAlreadyWritten = true;
}
if (textAlreadyWritten) {
result << std::endl;
}
return result.str();
}

} // namespace io::substrait::textplan
25 changes: 21 additions & 4 deletions src/substrait/textplan/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
#pragma once

#include <any>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

Expand All @@ -19,13 +20,14 @@ enum class SymbolType {
kFunction = 1,
kPlanRelation = 2,
kRelation = 3,
kRelationDetail = 4,
kSchema = 5,
kSchemaColumn = 6,
kSource = 7,
kSourceDetail = 8,
kField = 9,
kRoot = 10,
kTable = 11,
kMeasure = 12,

kUnknown = -1,
};
Expand Down Expand Up @@ -75,6 +77,8 @@ const std::string& symbolTypeName(SymbolType type);

struct SymbolInfo {
std::string name;
std::string alias{}; // If present, use this instead of name.
const SymbolInfo* schema{nullptr}; // The related schema symbol if present.
Location location;
SymbolType type;
std::any subtype;
Expand Down Expand Up @@ -144,12 +148,23 @@ class SymbolTable {
// Changes the location for a specified existing symbol.
void updateLocation(const SymbolInfo& symbol, const Location& location);

// Adds an alias to the given symbol.
void addAlias(const std::string& alias, const SymbolInfo* symbol);

[[nodiscard]] const SymbolInfo* lookupSymbolByName(
const std::string& name) const;

[[nodiscard]] const SymbolInfo* lookupSymbolByLocation(
[[nodiscard]] std::vector<const SymbolInfo*> lookupSymbolsByLocation(
const Location& location) const;

[[nodiscard]] const SymbolInfo* lookupSymbolByLocationAndType(
const Location& location,
SymbolType type) const;

[[nodiscard]] const SymbolInfo* lookupSymbolByLocationAndTypes(
const Location& location,
std::unordered_set<SymbolType> types) const;

[[nodiscard]] const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type)
const;

Expand Down Expand Up @@ -177,6 +192,8 @@ class SymbolTable {
return os;
}

[[nodiscard]] std::string toDebugString() const;

private:
// Returns the table size if the symbol is not found.
size_t findSymbolIndex(const SymbolInfo& symbol);
Expand All @@ -187,7 +204,7 @@ class SymbolTable {

std::vector<std::shared_ptr<SymbolInfo>> symbols_;
std::unordered_map<std::string, size_t> symbolsByName_;
std::unordered_map<Location, size_t> symbolsByLocation_;
std::multimap<Location, size_t> symbolsByLocation_;
};

} // namespace io::substrait::textplan
14 changes: 5 additions & 9 deletions src/substrait/textplan/SymbolTablePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ std::string outputSchemaSection(const SymbolTable& symbolTable) {
if (info.type != SymbolType::kSchema) {
continue;
}

if (info.blob.type() != typeid(const ::substrait::proto::NamedStruct*)) {
// TODO -- Implement schemas for text plans.
if (!info.blob.has_value()) {
continue;
}

Expand Down Expand Up @@ -241,10 +239,6 @@ std::string outputSourcesSection(const SymbolTable& symbolTable) {
if (hasPreviousText) {
text << "\n";
}
if (info.subtype.type() != typeid(SourceType)) {
// TODO -- Implement sources for text plans.
continue;
}
auto subtype = ANY_CAST(SourceType, info.subtype);
switch (subtype) {
case SourceType::kNamedTable: {
Expand Down Expand Up @@ -300,6 +294,7 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
}
return spaceNames.at(a) < spaceNames.at(b);
};
// Sorted by name if we have one, otherwise by space id.
std::set<uint32_t, decltype(cmp)> usedSpaces(cmp);

// Look at the existing spaces.
Expand Down Expand Up @@ -352,8 +347,8 @@ std::string outputFunctionsSection(const SymbolTable& symbolTable) {
functionsToOutput.emplace_back(info.name, functionData->name);
}
std::sort(functionsToOutput.begin(), functionsToOutput.end());
for (const auto& item : functionsToOutput) {
text << " function " << item.second << " as " << item.first << ";\n";
for (auto [shortName, canonicalName] : functionsToOutput) {
text << " function " << canonicalName << " as " << shortName << ";\n";
}
text << "}\n";
hasPreviousOutput = true;
Expand Down Expand Up @@ -446,6 +441,7 @@ void outputFunctionsToBinaryPlan(

} // namespace

// TODO -- Update so that errors occurring during printing are captured.
std::string SymbolTablePrinter::outputToText(const SymbolTable& symbolTable) {
std::stringstream text;
bool hasPreviousText = false;
Expand Down
Loading