Skip to content

Commit

Permalink
feat: add support for all subquery types (substrait-io#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime authored Feb 7, 2024
1 parent 66c8d97 commit 4f6ed2f
Show file tree
Hide file tree
Showing 44 changed files with 13,702 additions and 264 deletions.
4 changes: 4 additions & 0 deletions src/substrait/textplan/ParseResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class ParseResult {
return errors;
}

// Appends every message in `errors`, in order, after the entries already held
// in syntaxErrors_.
void addErrors(const std::vector<std::string>& errors) {
  syntaxErrors_.insert(
      syntaxErrors_.cend(), errors.cbegin(), errors.cend());
}

// Add the capability for ::testing::PrintToString to print ParseResult.
friend std::ostream& operator<<(std::ostream& os, const ParseResult& result);

Expand Down
6 changes: 6 additions & 0 deletions src/substrait/textplan/StructuredSymbolData.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ struct RelationData {
const SymbolInfo* continuingPipeline{nullptr};
// The next nodes in the pipelines that this node starts.
std::vector<const SymbolInfo*> newPipelines;
// Expressions in this relation consume subqueries with these symbols.
std::vector<const SymbolInfo*> subQueryPipelines;

// The information corresponding to the relation without any references to
// other relations or inputs.
Expand All @@ -46,6 +48,10 @@ struct RelationData {
// references to the symbol would use the alias.)
std::map<size_t, std::string> generatedFieldReferenceAlternativeExpression;

// Temporary storage for global aliases for expressions. Used during the
// construction of a relation.
std::map<size_t, std::string> generatedFieldReferenceAliases;

// If populated, supersedes the combination of fieldReferences and
// generatedFieldReferences for the field symbols exposed by this relation.
std::vector<const SymbolInfo*> outputFieldReferences;
Expand Down
4 changes: 4 additions & 0 deletions src/substrait/textplan/SubstraitErrorListener.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class SubstraitErrorListener {
addError(-1, -1, msg);
};

void addErrorInstances(const std::vector<ErrorInstance>& errors) {
errors_.insert(errors_.end(), errors.begin(), errors.end());
}

// Returns a read-only reference to the stored error instances.
const std::vector<ErrorInstance>& getErrors() {
  return errors_;
}
Expand Down
42 changes: 39 additions & 3 deletions src/substrait/textplan/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ const SymbolInfo SymbolInfo::kUnknown = {
std::nullopt};

// Two symbols compare equal when their name, source location, and type all
// match.  No other SymbolInfo field takes part in the comparison.
bool operator==(const SymbolInfo& left, const SymbolInfo& right) {
  return (left.name == right.name) &&
      (left.sourceLocation == right.sourceLocation) &&
      (left.type == right.type);
}

Expand Down Expand Up @@ -118,13 +119,31 @@ size_t SymbolTable::findSymbolIndex(const SymbolInfo& symbol) {
return symbols_.size();
}

void SymbolTable::updateLocation(
void SymbolTable::addPermanentLocation(
const SymbolInfo& symbol,
const Location& location) {
auto index = findSymbolIndex(symbol);
symbols_[index]->permanentLocation = location;
symbolsByLocation_.insert(std::make_pair(location, index));
}

// Records `location` as the parent query location of an existing symbol and
// assigns the symbol a parentQueryIndex one greater than the highest index
// currently held by any symbol with that same parent query location.
//
// The scan runs after the location has been stored, so it includes the symbol
// itself: setting the same location again on a symbol moves it to a new,
// higher index.  Does nothing when the symbol cannot be found in the table.
void SymbolTable::setParentQueryLocation(
    const io::substrait::textplan::SymbolInfo& symbol,
    const io::substrait::textplan::Location& location) {
  auto index = findSymbolIndex(symbol);
  if (index >= symbols_.size()) {
    // findSymbolIndex falls through to returning symbols_.size() (presumably
    // when nothing matches); indexing symbols_ with it would be out of range.
    return;
  }
  symbols_[index]->parentQueryLocation = location;

  int highestIndex = -1;
  for (const auto& sym : symbols_) {
    if (sym->parentQueryLocation == location &&
        sym->parentQueryIndex > highestIndex) {
      highestIndex = sym->parentQueryIndex;
    }
  }
  symbols_[index]->parentQueryIndex = highestIndex + 1;
}

void SymbolTable::addAlias(const std::string& alias, const SymbolInfo* symbol) {
auto index = findSymbolIndex(*symbol);
symbols_[index]->alias = alias;
Expand Down Expand Up @@ -169,6 +188,19 @@ const SymbolInfo* SymbolTable::lookupSymbolByLocationAndTypes(
return nullptr;
}

// Scans the table in order and returns the first symbol whose parent query
// location, parent query index, and type all equal the arguments.  Returns
// nullptr when no symbol satisfies all three conditions.
const SymbolInfo* SymbolTable::lookupSymbolByParentQueryAndType(
    const Location& location,
    int index,
    SymbolType type) const {
  for (const auto& candidate : symbols_) {
    if (!(candidate->parentQueryLocation == location)) {
      continue;
    }
    if (candidate->parentQueryIndex != index) {
      continue;
    }
    if (candidate->type != type) {
      continue;
    }
    return candidate.get();
  }
  return nullptr;
}

const SymbolInfo& SymbolTable::nthSymbolByType(uint32_t n, SymbolType type)
const {
int count = 0;
Expand Down Expand Up @@ -200,7 +232,11 @@ std::string SymbolTable::toDebugString() const {
}
auto relationData = ANY_CAST(std::shared_ptr<RelationData>, symbol->blob);
result << std::left << std::setw(4) << relationCount++;
result << std::left << std::setw(20) << symbol->name << std::endl;
result << std::left << std::setw(20) << symbol->name;
if (!relationData->subQueryPipelines.empty()) {
result << " SQC=" << relationData->subQueryPipelines.size();
}
result << std::endl;

int32_t fieldNum = 0;
for (const auto& field : relationData->fieldReferences) {
Expand Down
28 changes: 21 additions & 7 deletions src/substrait/textplan/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class SymbolType {
};

enum class RelationType {
// Logical plans
// Logical
kUnknown = 0,
kRead = 1,
kProject = 2,
Expand All @@ -45,11 +45,11 @@ enum class RelationType {
kFilter = 8,
kSet = 9,

// Physical plans
// Physical
kHashJoin = 31,
kMergeJoin = 32,

// Write relations, currently unreachable in Plan protos.
// Write
kExchange = 50,
kDdl = 51,
kWrite = 52,
Expand Down Expand Up @@ -79,7 +79,10 @@ struct SymbolInfo {
std::string name;
std::string alias{}; // If present, use this instead of name.
const SymbolInfo* schema{nullptr}; // The related schema symbol if present.
Location location;
Location sourceLocation;
Location permanentLocation{Location::kUnknownLocation};
Location parentQueryLocation{Location::kUnknownLocation};
int parentQueryIndex{-1};
SymbolType type;
std::any subtype;
std::any blob;
Expand All @@ -91,7 +94,7 @@ struct SymbolInfo {
std::any newSubtype,
std::any newBlob)
: name(std::move(newName)),
location(newLocation),
sourceLocation(newLocation),
type(newType),
subtype(std::move(newSubtype)),
blob(std::move(newBlob)){};
Expand Down Expand Up @@ -145,8 +148,14 @@ class SymbolTable {
const std::any& subtype,
const std::any& blob);

// Changes the location for a specified existing symbol.
void updateLocation(const SymbolInfo& symbol, const Location& location);
// Changes the permanent location (the version stored in the symbol table)
// for a specified existing symbol.
void addPermanentLocation(const SymbolInfo& symbol, const Location& location);

// Records the given location as the parent query location of an existing
// symbol.  The symbol is also given a parentQueryIndex one greater than the
// highest index held by any symbol with that same parent query location.
void setParentQueryLocation(
const SymbolInfo& symbol,
const Location& location);

// Adds an alias to the given symbol.
void addAlias(const std::string& alias, const SymbolInfo* symbol);
Expand All @@ -165,6 +174,11 @@ class SymbolTable {
const Location& location,
std::unordered_set<SymbolType> types) const;

// Returns the first symbol of the given type whose parentQueryLocation and
// parentQueryIndex equal `location` and `index`, or nullptr if there is none.
[[nodiscard]] const SymbolInfo* lookupSymbolByParentQueryAndType(
const Location& location,
int index,
SymbolType type) const;

[[nodiscard]] const SymbolInfo& nthSymbolByType(uint32_t n, SymbolType type)
const;

Expand Down
Loading

0 comments on commit 4f6ed2f

Please sign in to comment.