Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for all subquery types #89

Merged
merged 17 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/substrait/textplan/ParseResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class ParseResult {
return errors;
}

// Appends every message in `errors`, in order, to the end of the syntax
// error list.  An empty `errors` leaves the list unchanged.
void addErrors(const std::vector<std::string>& errors) {
  syntaxErrors_.insert(
      syntaxErrors_.end(), std::begin(errors), std::end(errors));
}

// Add the capability for ::testing::PrintToString to print ParseResult.
friend std::ostream& operator<<(std::ostream& os, const ParseResult& result);

Expand Down
6 changes: 6 additions & 0 deletions src/substrait/textplan/StructuredSymbolData.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ struct RelationData {
const SymbolInfo* continuingPipeline{nullptr};
// The next nodes in the pipelines that this node starts.
std::vector<const SymbolInfo*> newPipelines;
// Expressions in this relation consume subqueries with these symbols.
std::vector<const SymbolInfo*> subQueryPipelines;

// The information corresponding to the relation without any references to
// other relations or inputs.
Expand All @@ -46,6 +48,10 @@ struct RelationData {
// references to the symbol would use the alias.)
std::map<size_t, std::string> generatedFieldReferenceAlternativeExpression;

// Temporary storage for global aliases for expressions. Used during the
// construction of a relation.
std::map<size_t, std::string> generatedFieldReferenceAliases;

// If populated, supersedes the combination of fieldReferences and
// generatedFieldReferences for the field symbols exposed by this relation.
std::vector<const SymbolInfo*> outputFieldReferences;
Expand Down
4 changes: 4 additions & 0 deletions src/substrait/textplan/SubstraitErrorListener.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class SubstraitErrorListener {
addError(-1, -1, msg);
};

void addErrorInstances(const std::vector<ErrorInstance>& errors) {
errors_.insert(errors_.end(), errors.begin(), errors.end());
}

const std::vector<ErrorInstance>& getErrors() {
return errors_;
};
Expand Down
42 changes: 39 additions & 3 deletions src/substrait/textplan/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ const SymbolInfo SymbolInfo::kUnknown = {
std::nullopt};

bool operator==(const SymbolInfo& left, const SymbolInfo& right) {
return (left.name == right.name) && (left.location == right.location) &&
return (left.name == right.name) &&
(left.sourceLocation == right.sourceLocation) &&
(left.type == right.type);
}

Expand Down Expand Up @@ -118,13 +119,31 @@ size_t SymbolTable::findSymbolIndex(const SymbolInfo& symbol) {
return symbols_.size();
}

void SymbolTable::updateLocation(
void SymbolTable::addPermanentLocation(
    const SymbolInfo& symbol,
    const Location& location) {
  // Records `location` as the permanent location of `symbol` and registers the
  // symbol's index under that location in symbolsByLocation_.
  const auto index = findSymbolIndex(symbol);
  if (index >= symbols_.size()) {
    // findSymbolIndex falls back to symbols_.size() when it has no match;
    // indexing symbols_ with that value would read past the end, so an
    // unknown symbol is ignored instead.
    return;
  }
  symbols_[index]->permanentLocation = location;
  symbolsByLocation_.insert(std::make_pair(location, index));
}

void SymbolTable::setParentQueryLocation(
const io::substrait::textplan::SymbolInfo& symbol,
const io::substrait::textplan::Location& location) {
auto index = findSymbolIndex(symbol);
symbols_[index]->parentQueryLocation = location;

int highestIndex = -1;
for (const auto& sym : symbols_) {
if (sym->parentQueryLocation == location) {
if (sym->parentQueryIndex > highestIndex) {
highestIndex = sym->parentQueryIndex;
}
}
}
symbols_[index]->parentQueryIndex = highestIndex + 1;
}

void SymbolTable::addAlias(const std::string& alias, const SymbolInfo* symbol) {
auto index = findSymbolIndex(*symbol);
symbols_[index]->alias = alias;
Expand Down Expand Up @@ -169,6 +188,19 @@ const SymbolInfo* SymbolTable::lookupSymbolByLocationAndTypes(
return nullptr;
}

const SymbolInfo* SymbolTable::lookupSymbolByParentQueryAndType(
    const Location& location,
    int index,
    SymbolType type) const {
  // Walk the table in order and hand back the first entry that matches on
  // parent query location, parent query index, and symbol type.
  for (const auto& candidate : symbols_) {
    if (!(candidate->parentQueryLocation == location)) {
      continue;
    }
    if (!(candidate->parentQueryIndex == index)) {
      continue;
    }
    if (!(candidate->type == type)) {
      continue;
    }
    return candidate.get();
  }
  // Nothing in the table matched all three criteria.
  return nullptr;
}

const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type)
const {
int count = 0;
Expand Down Expand Up @@ -200,7 +232,11 @@ std::string SymbolTable::toDebugString() const {
}
auto relationData = ANY_CAST(std::shared_ptr<RelationData>, symbol->blob);
result << std::left << std::setw(4) << relationCount++;
result << std::left << std::setw(20) << symbol->name << std::endl;
result << std::left << std::setw(20) << symbol->name;
if (!relationData->subQueryPipelines.empty()) {
result << " SQC=" << relationData->subQueryPipelines.size();
}
result << std::endl;

int32_t fieldNum = 0;
for (const auto& field : relationData->fieldReferences) {
Expand Down
28 changes: 21 additions & 7 deletions src/substrait/textplan/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class SymbolType {
};

enum class RelationType {
// Logical plans
// Logical
kUnknown = 0,
kRead = 1,
kProject = 2,
Expand All @@ -45,11 +45,11 @@ enum class RelationType {
kFilter = 8,
kSet = 9,

// Physical plans
// Physical
kHashJoin = 31,
kMergeJoin = 32,

// Write relations, currently unreachable in Plan protos.
// Write
kExchange = 50,
kDdl = 51,
kWrite = 52,
Expand Down Expand Up @@ -79,7 +79,10 @@ struct SymbolInfo {
std::string name;
std::string alias{}; // If present, use this instead of name.
const SymbolInfo* schema{nullptr}; // The related schema symbol if present.
Location location;
Location sourceLocation;
Location permanentLocation{Location::kUnknownLocation};
Location parentQueryLocation{Location::kUnknownLocation};
int parentQueryIndex{-1};
SymbolType type;
std::any subtype;
std::any blob;
Expand All @@ -91,7 +94,7 @@ struct SymbolInfo {
std::any newSubtype,
std::any newBlob)
: name(std::move(newName)),
location(newLocation),
sourceLocation(newLocation),
type(newType),
subtype(std::move(newSubtype)),
blob(std::move(newBlob)){};
Expand Down Expand Up @@ -145,8 +148,14 @@ class SymbolTable {
const std::any& subtype,
const std::any& blob);

// Changes the location for a specified existing symbol.
void updateLocation(const SymbolInfo& symbol, const Location& location);
// Changes the permanent location (the version stored in the symbol table)
// for a specified existing symbol, and records the symbol's index under that
// location in symbolsByLocation_.
void addPermanentLocation(const SymbolInfo& symbol, const Location& location);

// Sets the location of the parent query for a specified existing symbol and
// assigns the symbol a parentQueryIndex one greater than the highest
// parentQueryIndex among the symbols recorded with that location.
void setParentQueryLocation(
const SymbolInfo& symbol,
const Location& location);

// Adds an alias to the given symbol.
void addAlias(const std::string& alias, const SymbolInfo* symbol);
Expand All @@ -165,6 +174,11 @@ class SymbolTable {
const Location& location,
std::unordered_set<SymbolType> types) const;

// Returns the first symbol whose parent query location, parent query index,
// and type all match the arguments, or nullptr when no symbol matches.
[[nodiscard]] const SymbolInfo* lookupSymbolByParentQueryAndType(
const Location& location,
int index,
SymbolType type) const;

[[nodiscard]] const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type)
const;

Expand Down
Loading
Loading