Skip to content

Commit

Permalink
Support self-referencing ACCESSED_BY foreign keys
Browse files Browse the repository at this point in the history
* Modify sqlengine/create.cc to correctly handle ACCESSED_BY self references
* Throw an error if self reference is OWNED_BY, OWNS, or ACCESSES
* Add tests to ensure self reference works correctly, including anonymization
* Fix visual bug in EXPLAIN COMPLIANCE
* Decrement user count in gdpr_forget.cc
  • Loading branch information
artemagvanian authored and KinanBab committed May 23, 2023
1 parent 6c5ea62 commit 027a9b9
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 15 deletions.
1 change: 1 addition & 0 deletions k9db/explain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ void ReportShardingInformation(const shards::TableName &table_name,
default:
LOG(FATAL) << "UNREACHABLE";
}
last_table = info.next_table();
// << info->next_table() << "(" << tinfo.upcolumn() << ")";
out << std::endl;
}
Expand Down
49 changes: 37 additions & 12 deletions k9db/shards/sqlengine/create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,40 @@ absl::StatusOr<CreateContext::Annotations> CreateContext::DiscoverValidate() {
// Make sure all FK point to existing tables.
const sqlast::ColumnConstraint &fk = col.GetForeignKeyConstraint();
const auto &[foreign_table, foreign_column, fk_type] = fk.ForeignKey();
ASSERT_RET(this->sstate_.TableExists(foreign_table), InvalidArgument,
"FK points to nonexisting table");
const Table &target = this->sstate_.GetTable(foreign_table);

// In case this is a self-referring FK.
bool self_fk = foreign_table == this->stmt_.table_name();
if (!self_fk) {
ASSERT_RET(this->sstate_.TableExists(foreign_table), InvalidArgument,
"FK points to nonexisting table");
}

const Table &target =
self_fk ? this->table_ : this->sstate_.GetTable(foreign_table);
ASSERT_RET(target.schema.HasColumn(foreign_column), InvalidArgument,
"FK points to nonexisting column");

// Check if this points to the PK.
size_t index = target.schema.IndexOf(foreign_column);
bool points_to_pk = target.schema.keys().at(0) == index;
// Check if this points to the PK, handle various annotations.
bool foreign_owned, foreign_accessed, points_to_pk;
if (self_fk) {
foreign_owned = !annotations.explicit_owners.empty() ||
!annotations.implicit_owners.empty() ||
this->stmt_.IsDataSubject();
foreign_accessed = !annotations.accessors.empty() || foreign_owned;
points_to_pk =
this->stmt_.GetColumn(foreign_column)
.HasConstraint(sqlast::ColumnConstraint::Type::PRIMARY_KEY);
} else {
foreign_owned = this->sstate_.IsOwned(foreign_table);
foreign_accessed = this->sstate_.IsAccessed(foreign_table);
points_to_pk =
target.schema.keys().at(0) == target.schema.IndexOf(foreign_column);
}

// Handle various annotations.
bool foreign_owned = this->sstate_.IsOwned(foreign_table);
bool foreign_accessed = this->sstate_.IsAccessed(foreign_table);
if (fk_type == sqlast::ColumnConstraint::FKType::OWNED_BY) {
ASSERT_RET(foreign_owned, InvalidArgument, "OWNER to a non data subject");
ASSERT_RET(points_to_pk, InvalidArgument, "OWNER doesn't point to PK");
ASSERT_RET(!self_fk, InvalidArgument, "OWNER on a self-referencing FK");
annotations.explicit_owners.push_back(i);
} else if (fk_type == sqlast::ColumnConstraint::FKType::ACCESSED_BY) {
ASSERT_RET(foreign_accessed, InvalidArgument,
Expand All @@ -57,6 +75,7 @@ absl::StatusOr<CreateContext::Annotations> CreateContext::DiscoverValidate() {
annotations.accessors.push_back(i);
} else if (fk_type == sqlast::ColumnConstraint::FKType::OWNS) {
ASSERT_RET(points_to_pk, InvalidArgument, "OWNS doesn't point to PK");
ASSERT_RET(!self_fk, InvalidArgument, "OWNS on a self-referencing FK");
annotations.owns.push_back(i);
} else if (fk_type == sqlast::ColumnConstraint::FKType::ACCESSES) {
ASSERT_RET(points_to_pk, InvalidArgument, "ACCESSES doesn't point to PK");
Expand Down Expand Up @@ -99,7 +118,11 @@ std::vector<std::unique_ptr<ShardDescriptor>> CreateContext::MakeFDescriptors(
const std::string &fk_colname = fk_col.column_name();
const sqlast::ColumnConstraint &fk = fk_col.GetForeignKeyConstraint();
const auto &[next_table, next_col, _] = fk.ForeignKey();
Table &tbl = this->sstate_.GetTable(next_table);

Table &tbl = next_table == this->table_name_
? this->table_
: this->sstate_.GetTable(next_table);

size_t next_col_index = tbl.schema.IndexOf(next_col);
const std::vector<std::unique_ptr<ShardDescriptor>> &vec =
owners ? tbl.owners : tbl.accessors;
Expand All @@ -117,11 +140,13 @@ std::vector<std::unique_ptr<ShardDescriptor>> CreateContext::MakeFDescriptors(
// First time we see this shard_kind.
ShardDescriptor descriptor;
descriptor.shard_kind = shard_kind;
if (shard_kind == next_table) { // Direct sharding.
if (shard_kind == next_table && next_table != this->table_name_) {
// Direct sharding.
descriptor.type = InfoType::DIRECT;
descriptor.info = DirectInfo{fk_colname, fk_column_index, fk_column_type,
next_col, next_col_index};
} else { // Transitive sharding.
} else {
// Transitive sharding.
descriptor.type = InfoType::TRANSITIVE;
IndexDescriptor *index = nullptr;
if (create_indices) {
Expand Down
3 changes: 3 additions & 0 deletions k9db/shards/sqlengine/gdpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ absl::Status GDPRContext::RecurseOverDependents() {
*/
absl::Status GDPRContext::RecurseOverAccessDependents(
const std::string &table_name, std::vector<dataflow::Record> &&records) {
if (records.empty()) {
return absl::OkStatus();
}
// Iterate over all the access dependents.
const Table &table = this->sstate_.GetTable(table_name);
for (const auto &[next_table, desc] : table.access_dependents) {
Expand Down
3 changes: 3 additions & 0 deletions k9db/shards/sqlengine/gdpr_forget.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ absl::StatusOr<sql::SqlResult> GDPRForgetContext::Exec() {
this->db_->CommitTransaction();
CHECK_STATUS(this->conn_->ctx->CommitCheckpoint());

// Decrement users.
this->sstate_.DecrementUsers(this->shard_kind_, 1);

// Update dataflow.
for (auto &[table_name, records] : this->records_) {
this->dstate_.ProcessRecords(table_name, std::move(records));
Expand Down
52 changes: 52 additions & 0 deletions k9db/shards/sqlengine/gdpr_forget_anon_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,58 @@ TEST_F(GDPRForgetAnonTest, ComplexVariableAccessorshipAnon) {
db->RollbackTransaction();
}

// GDPR FORGET on a table with a self-referencing ACCESSED_BY foreign key.
// `comments` is owned by `commenters` via commenterHex, and every comment is
// additionally accessed through its parent comment via parentHex. The table is
// created with an ON_DEL / DEL_ROW rule on parentHex (macro expansions are
// defined outside this test -- presumably "delete the accessed row on forget").
TEST_F(GDPRForgetAnonTest, SelfFKTable) {
  // Parse create table statements.
  std::string commenters =
      MakeCreate("commenters", {"commenterHex" STR PK}, true);
  std::string comments = MakeCreate(
      "comments",
      {"commentHex" STR PK, "commenterHex" STR OB "commenters(commenterHex)",
       "parentHex" STR AB "comments(commentHex)"},
      false, ", " ON_DEL "parentHex" DEL_ROW);

  // Make a k9db connection.
  Connection conn = CreateConnection();
  sql::Session *db = conn.session.get();

  // Create the tables.
  EXPECT_SUCCESS(Execute(commenters, &conn));
  EXPECT_SUCCESS(Execute(comments, &conn));

  // Insert three commenters.
  auto &&[cr0_stmt, cr0] = MakeInsert("commenters", {"'0'"});
  auto &&[cr1_stmt, cr1] = MakeInsert("commenters", {"'1'"});
  auto &&[cr2_stmt, cr2] = MakeInsert("commenters", {"'2'"});

  EXPECT_UPDATE(Execute(cr0_stmt, &conn), 1);
  EXPECT_UPDATE(Execute(cr1_stmt, &conn), 1);
  EXPECT_UPDATE(Execute(cr2_stmt, &conn), 1);

  // Insert six comments as (commentHex, commenterHex, parentHex). Comment 0 is
  // the root (NULL parent); every other comment hangs off an earlier one, so
  // all of them are reachable from comment 0 by following parentHex.
  auto &&[c0_stmt, c0] = MakeInsert("comments", {"'0'", "'0'", "NULL"});
  auto &&[c1_stmt, c1] = MakeInsert("comments", {"'1'", "'0'", "'0'"});
  auto &&[c2_stmt, c2] = MakeInsert("comments", {"'2'", "'1'", "'1'"});
  auto &&[c3_stmt, c3] = MakeInsert("comments", {"'3'", "'2'", "'2'"});
  auto &&[c4_stmt, c4] = MakeInsert("comments", {"'4'", "'0'", "'3'"});
  auto &&[c5_stmt, c5] = MakeInsert("comments", {"'5'", "'1'", "'3'"});

  EXPECT_UPDATE(Execute(c0_stmt, &conn), 1);
  EXPECT_UPDATE(Execute(c1_stmt, &conn), 1);
  EXPECT_UPDATE(Execute(c2_stmt, &conn), 1);
  EXPECT_UPDATE(Execute(c3_stmt, &conn), 1);
  EXPECT_UPDATE(Execute(c4_stmt, &conn), 1);
  EXPECT_UPDATE(Execute(c5_stmt, &conn), 1);

  // Validate forget: 7 rows are expected to be affected, which matches one
  // commenter row plus all six comments.
  std::string forget = MakeGDPRForget("commenters", "0");
  EXPECT_UPDATE(Execute(forget, &conn), 7);

  // No comment may remain in any commenter's shard. The shard kind must match
  // the data subject table created above ("commenters"); looking up a kind
  // that was never created would make these checks pass vacuously.
  db->BeginTransaction(false);
  EXPECT_EQ(db->GetShard("comments", SN("commenters", "0")), (V{}));
  EXPECT_EQ(db->GetShard("comments", SN("commenters", "1")), (V{}));
  EXPECT_EQ(db->GetShard("comments", SN("commenters", "2")), (V{}));
  db->RollbackTransaction();
}

} // namespace sqlengine
} // namespace shards
} // namespace k9db
47 changes: 47 additions & 0 deletions k9db/shards/sqlengine/gdpr_get_anon_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,53 @@ TEST_F(GDPRGetAnonTest, ComplexVariableAccessorship) {
(V{fanon1, fanon2})}));
}

// GDPR GET on a table with a self-referencing ACCESSED_BY foreign key and an
// anonymization rule. `comments` is owned by `commenters` via commenterHex and
// accessed through its parent comment via parentHex; the table is created with
// an ON_GET / ANON rule naming parentHex and commenterHex (macro expansions
// are defined outside this test -- presumably "anonymize commenterHex for rows
// reached through parentHex"; confirm against the macro definitions).
TEST_F(GDPRGetAnonTest, SelfFKTable) {
// Parse create table statements.
std::string commenters =
MakeCreate("commenters", {"commenterHex" STR PK}, true);
std::string comments = MakeCreate(
"comments",
{"commentHex" STR PK, "commenterHex" STR OB "commenters(commenterHex)",
"parentHex" STR AB "comments(commentHex)"},
false, ", " ON_GET "parentHex" ANON "(commenterHex)");

// Make a k9db connection.
Connection conn = CreateConnection();

// Create the tables.
EXPECT_SUCCESS(Execute(commenters, &conn));
EXPECT_SUCCESS(Execute(comments, &conn));

// Perform some inserts: three commenters first, so the comments' owner
// references below point at existing rows.
auto &&[cr0_stmt, cr0] = MakeInsert("commenters", {"'0'"});
auto &&[cr1_stmt, cr1] = MakeInsert("commenters", {"'1'"});
auto &&[cr2_stmt, cr2] = MakeInsert("commenters", {"'2'"});

EXPECT_UPDATE(Execute(cr0_stmt, &conn), 1);
EXPECT_UPDATE(Execute(cr1_stmt, &conn), 1);
EXPECT_UPDATE(Execute(cr2_stmt, &conn), 1);

// Comments are (commentHex, commenterHex, parentHex). Comment 0 is the root
// (NULL parent); every other comment's parent is an earlier comment, so all
// of them are reachable from comment 0 by following parentHex.
auto &&[c0_stmt, c0] = MakeInsert("comments", {"'0'", "'0'", "NULL"});
auto &&[c1_stmt, c1] = MakeInsert("comments", {"'1'", "'0'", "'0'"});
auto &&[c2_stmt, c2] = MakeInsert("comments", {"'2'", "'1'", "'1'"});
auto &&[c3_stmt, c3] = MakeInsert("comments", {"'3'", "'2'", "'2'"});
auto &&[c4_stmt, c4] = MakeInsert("comments", {"'4'", "'0'", "'3'"});
auto &&[c5_stmt, c5] = MakeInsert("comments", {"'5'", "'1'", "'3'"});

EXPECT_UPDATE(Execute(c0_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c1_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c2_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c3_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c4_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c5_stmt, &conn), 1);

// Validate get. Two result sets are expected: the commenter row itself, and
// the comments. Only c0 is expected verbatim; comments 1-5 are expected with
// NULL in their second field (commenterHex per the column order above), i.e.
// in anonymized form.
std::string get = MakeGDPRGet("commenters", "0");
EXPECT_EQ(Execute(get, &conn).ResultSets(),
(VV{(V{cr0}), (V{c0, "|1|NULL|0|", "|2|NULL|1|", "|3|NULL|2|",
"|4|NULL|3|", "|5|NULL|3|"})}));
}

} // namespace sqlengine
} // namespace shards
} // namespace k9db
47 changes: 46 additions & 1 deletion k9db/shards/sqlengine/gdpr_get_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ TEST_F(GDPRGetTest, TransitiveAccessorship) {
// Validate get for user with id 5.
std::string get1 = MakeGDPRGet("user", "5");
EXPECT_EQ(Execute(get1, &conn).ResultSets(),
(VV{(V{row4, row3}), (V{u1}), (V{}), (V{})}));
(VV{(V{row4, row3}), (V{u1}), (V{})}));

// Validate get for user with id 10.
std::string get2 = MakeGDPRGet("user", "10");
Expand Down Expand Up @@ -670,6 +670,51 @@ TEST_F(GDPRGetTest, ComplexVariableAccessorship) {
(VV{(V{frow1, frow2}), (V{grow1, grow2}), (V{d0}), (V{farow1, farow2})}));
}

// GDPR GET on a table with a self-referencing ACCESSED_BY foreign key and no
// anonymization rule. `comments` is owned by `commenters` via commenterHex and
// accessed through its parent comment via parentHex.
TEST_F(GDPRGetTest, SelfFKTable) {
// Parse create table statements.
std::string commenters =
MakeCreate("commenters", {"commenterHex" STR PK}, true);
std::string comments =
MakeCreate("comments", {"commentHex" STR PK,
"commenterHex" STR OB "commenters(commenterHex)",
"parentHex" STR AB "comments(commentHex)"});

// Make a k9db connection.
Connection conn = CreateConnection();

// Create the tables.
EXPECT_SUCCESS(Execute(commenters, &conn));
EXPECT_SUCCESS(Execute(comments, &conn));

// Perform some inserts: three commenters first, so the comments' owner
// references below point at existing rows.
auto &&[cr0_stmt, cr0] = MakeInsert("commenters", {"'0'"});
auto &&[cr1_stmt, cr1] = MakeInsert("commenters", {"'1'"});
auto &&[cr2_stmt, cr2] = MakeInsert("commenters", {"'2'"});

EXPECT_UPDATE(Execute(cr0_stmt, &conn), 1);
EXPECT_UPDATE(Execute(cr1_stmt, &conn), 1);
EXPECT_UPDATE(Execute(cr2_stmt, &conn), 1);

// Comments are (commentHex, commenterHex, parentHex). Comment 0 is the root
// (NULL parent); every other comment's parent is an earlier comment, so all
// of them are reachable from comment 0 by following parentHex.
auto &&[c0_stmt, c0] = MakeInsert("comments", {"'0'", "'0'", "NULL"});
auto &&[c1_stmt, c1] = MakeInsert("comments", {"'1'", "'0'", "'0'"});
auto &&[c2_stmt, c2] = MakeInsert("comments", {"'2'", "'1'", "'1'"});
auto &&[c3_stmt, c3] = MakeInsert("comments", {"'3'", "'2'", "'2'"});
auto &&[c4_stmt, c4] = MakeInsert("comments", {"'4'", "'0'", "'3'"});
auto &&[c5_stmt, c5] = MakeInsert("comments", {"'5'", "'1'", "'3'"});

EXPECT_UPDATE(Execute(c0_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c1_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c2_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c3_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c4_stmt, &conn), 1);
EXPECT_UPDATE(Execute(c5_stmt, &conn), 1);

// Validate get. Two result sets are expected: the commenter row itself, and
// all six comments unmodified -- including those written by other commenters,
// which commenter 0 only reaches through the parentHex chain.
std::string get = MakeGDPRGet("commenters", "0");
EXPECT_EQ(Execute(get, &conn).ResultSets(),
(VV{(V{cr0}), (V{c0, c1, c2, c3, c4, c5})}));
}

} // namespace sqlengine
} // namespace shards
} // namespace k9db
12 changes: 10 additions & 2 deletions k9db/shards/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,20 @@ Table *SharderState::AddTable(Table &&table) {
}
case InfoType::TRANSITIVE: {
const TransitiveInfo &info = std::get<TransitiveInfo>(descriptor->info);
parent = &this->tables_.at(info.next_table);
if (info.next_table == table.table_name) {
parent = &table;
} else {
parent = &this->tables_.at(info.next_table);
}
break;
}
case InfoType::VARIABLE: {
const VariableInfo &info = std::get<VariableInfo>(descriptor->info);
parent = &this->tables_.at(info.origin_relation);
if (info.origin_relation == table.table_name) {
parent = &table;
} else {
parent = &this->tables_.at(info.origin_relation);
}
break;
}
}
Expand Down

0 comments on commit 027a9b9

Please sign in to comment.