From 6d84fc70b67e3a3c3c7783ac28a17f957fb4174a Mon Sep 17 00:00:00 2001 From: Matthew A Johnson Date: Fri, 19 Jan 2024 17:29:17 +0000 Subject: [PATCH] Fixes for regressions in local variable behavior. (#98) As a result of the optimization passes, some incorrect behavior was introduced for local variables. This commit restores correct functionality. It also fixes a crash when accessing certain nodes with an undefined key. Signed-off-by: Matthew Johnson --- CHANGELOG | 15 ++ VERSION | 2 +- examples/rust/Cargo.toml | 2 +- src/internal.hh | 1 + src/passes/init.cc | 206 ++++++++++++++++++++++------ src/passes/locals.cc | 8 ++ src/passes/rulebody.cc | 1 - src/rego.cc | 1 + src/unifier.cc | 10 +- tests/regocpp.yaml | 28 ++++ wrappers/python/docs/source/conf.py | 2 +- wrappers/rust/regorust/Cargo.toml | 2 +- 12 files changed, 227 insertions(+), 51 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index c8eeac14..772412e2 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,20 @@ # Changelog +## 2024-01-19 - Version 0.3.11 +Minor improvements and bug fixes. + +**New Features** +- Updated to more recent Trieste version +- More sophisticated logging + +**Bug fixes** +- Comprehensions over local variables were not properly capturing the local (regression due to optimization) +- Local variable initializations were order-dependent (regression due to optimization) +- In some circumstances, indexing the data object with an undefined key caused a segfault. + +**Other** +- Various CI changes due to issues with Github actions. + ## 2023-09-21 - Version 0.3.10 Instrumentation and optimization. diff --git a/VERSION b/VERSION index 81de5c57..99a89b94 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.3.10 \ No newline at end of file +0.3.11 \ No newline at end of file diff --git a/examples/rust/Cargo.toml b/examples/rust/Cargo.toml index 8481795b..46a85578 100644 --- a/examples/rust/Cargo.toml +++ b/examples/rust/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -regorust = "0.3.10" +regorust = "0.3.11" clap = { version = "4.0", features = ["derive"] } \ No newline at end of file diff --git a/src/internal.hh b/src/internal.hh index 54b85eec..2043ad8a 100644 --- a/src/internal.hh +++ b/src/internal.hh @@ -146,6 +146,7 @@ namespace rego PassDef explicit_enums(); PassDef body_locals(const BuiltIns& builtins); PassDef value_locals(const BuiltIns& builtins); + PassDef compr_locals(const BuiltIns& builtins); PassDef rules_to_compr(); PassDef compr(); PassDef absolute_refs(); diff --git a/src/passes/init.cc b/src/passes/init.cc index 7e1f3fd5..7a5c0b55 100644 --- a/src/passes/init.cc +++ b/src/passes/init.cc @@ -1,14 +1,26 @@ #include "internal.hh" +#include +#include +#include +#include + namespace { using namespace rego; using namespace wf::ops; + struct InitSide + { + std::set vars; + std::set inits; + }; + struct InitInfo { - std::set lhs_vars; - std::set rhs_vars; + std::size_t index; + InitSide lhs; + InitSide rhs; }; Node to_init( @@ -37,11 +49,12 @@ namespace return LiteralInit << lhs_vars << rhs_vars << (AssignInfix << lhs << rhs); } - void vars_from(Node node, std::set& vars) + void inits_from( + Node node, const std::set& locals, std::set& inits) { - if (node->type() == Var) + if (node->type() == Var && contains(locals, node->location())) { - vars.insert(node->location()); + inits.insert(node->location()); return; } @@ -78,13 +91,102 @@ namespace for (Node child : *node) { - vars_from(child, vars); + inits_from(child, locals, inits); + } + } + + void vars_from( + Node node, const std::set& locals, std::set& vars) + { + if (node->type() == Var && contains(locals, node->location())) + { + vars.insert(node->location()); + } + + for (Node child : *node) + { + vars_from(child, locals, vars); + } + } + + InitSide side_from(Node node, const std::set& locals) + { + InitSide side; + inits_from(node, locals, side.inits); + vars_from(node, locals, side.vars); + return side; + } + + bool any_compiler_inits(const InitSide& lhs) + { + return std::any_of(lhs.inits.begin(), lhs.inits.end(), [](auto& loc) { + std::string name = loc.str(); + return starts_with(name, "unify$") || starts_with(name, "out$") || + starts_with(name, "value$"); + }); + } + + void remove_locals( + std::deque& init_deque, const std::set& to_remove) + { + std::size_t count = init_deque.size(); + for (std::size_t i = 0; i < count; ++i) + { + InitInfo& init_stmt = init_deque.front(); + for (auto& loc : to_remove) + { + init_stmt.lhs.vars.erase(loc); + init_stmt.lhs.inits.erase(loc); + init_stmt.rhs.vars.erase(loc); + init_stmt.rhs.inits.erase(loc); + } + if (!init_stmt.lhs.inits.empty() || !init_stmt.rhs.inits.empty()) + { + init_deque.push_back(init_stmt); + } + + init_deque.pop_front(); } } + std::vector sort_init_stmts( + const std::set& locals, std::deque& init_deque) + { + std::set initialized; + std::vector init_stmts; + while (!init_deque.empty() && initialized != locals) + { + // find all strict init statements + auto it = + std::find_if(init_deque.begin(), init_deque.end(), [](auto& init_stmt) { + return init_stmt.lhs.vars.empty() || init_stmt.rhs.vars.empty(); + }); + + if (it == init_deque.end()) + { + // we have a cycle, so we use the first statement + it = init_deque.begin(); + init_stmts.push_back(*it); + } + else + { + init_stmts.push_back(*it); + } + + std::set to_remove; + to_remove.insert(it->lhs.inits.begin(), it->lhs.inits.end()); + to_remove.insert(it->rhs.inits.begin(), it->rhs.inits.end()); + init_deque.erase(it); + remove_locals(init_deque, to_remove); + initialized.insert(to_remove.begin(), to_remove.end()); + } + + return init_stmts; + } + void find_init_stmts(Node unifybody, std::set& locals) { - // gather all locals + std::deque potential_init_stmts; for (std::size_t i = 0; i < unifybody->size(); ++i) { Node stmt = unifybody->at(i); @@ -95,15 +197,6 @@ namespace else if (stmt->type() == LiteralEnum) { locals.erase((stmt / Item)->location()); - find_init_stmts(stmt / UnifyBody, locals); - } - else if (stmt->type() == LiteralWith) - { - find_init_stmts(stmt / UnifyBody, locals); - } - else if (stmt->type() == LiteralNot) - { - find_init_stmts(stmt / UnifyBody, locals); } else if (stmt->type() == Literal) { @@ -115,42 +208,65 @@ namespace Node lhs = expr->front(); Node rhs = expr->back(); - std::set lhs_vars; - vars_from(lhs, lhs_vars); - std::set lhs_found; - std::set_intersection( - lhs_vars.begin(), - lhs_vars.end(), - locals.begin(), - locals.end(), - std::inserter(lhs_found, lhs_found.begin())); - - std::set rhs_vars; - vars_from(rhs, rhs_vars); - std::set rhs_found; - std::set_intersection( - rhs_vars.begin(), - rhs_vars.end(), - locals.begin(), - locals.end(), - std::inserter(rhs_found, rhs_found.begin())); - - if (lhs_found.empty() && rhs_found.empty()) - { - continue; - } - for (auto& loc : lhs_found) + InitSide lhs_side = side_from(lhs, locals); + InitSide rhs_side = side_from(rhs, locals); + + if (any_compiler_inits(lhs_side)) { - locals.erase(loc); + // compiler statements will never be right-assign, so we can + // use this fact later to help resolve some ambiguities + rhs_side.inits.clear(); } - for (auto& loc : rhs_found) + if (lhs_side.inits.empty() && rhs_side.inits.empty()) { - locals.erase(loc); + continue; } - unifybody->replace_at(i, to_init(lhs, lhs_found, rhs, rhs_found)); + potential_init_stmts.push_back({i, lhs_side, rhs_side}); + } + } + + std::vector init_stmts = + sort_init_stmts(locals, potential_init_stmts); + for (std::size_t i = 0; i < init_stmts.size(); ++i) + { + InitInfo& init_stmt = init_stmts[i]; + Node expr = unifybody->at(init_stmt.index)->front()->front(); + + Node lhs = expr->front(); + Node rhs = expr->back(); + + for (auto& loc : init_stmt.lhs.inits) + { + locals.erase(loc); + } + + for (auto& loc : init_stmt.rhs.inits) + { + locals.erase(loc); + } + + unifybody->replace_at( + init_stmt.index, + to_init(lhs, init_stmt.lhs.inits, rhs, init_stmt.rhs.inits)); + } + + // where appropriate, recurse with the updated locals + for (Node stmt : *unifybody) + { + if (stmt->type() == LiteralEnum) + { + find_init_stmts(stmt / UnifyBody, locals); + } + else if (stmt->type() == LiteralWith) + { + find_init_stmts(stmt / UnifyBody, locals); + } + else if (stmt->type() == LiteralNot) + { + find_init_stmts(stmt / UnifyBody, locals); } } } diff --git a/src/passes/locals.cc b/src/passes/locals.cc index 18ccac7a..eeabe3fb 100644 --- a/src/passes/locals.cc +++ b/src/passes/locals.cc @@ -134,6 +134,14 @@ namespace rego RuleObj, [builtins](Node n) { return preprocess_body(n, builtins); }); locals.pre( RuleSet, [builtins](Node n) { return preprocess_body(n, builtins); }); + + return locals; + } + + PassDef compr_locals(const BuiltIns& builtins) + { + PassDef locals = { + "compr_locals", wf_pass_locals, dir::bottomup | dir::once}; locals.pre( ArrayCompr, [builtins](Node n) { return preprocess_body(n, builtins); }); locals.pre( diff --git a/src/passes/rulebody.cc b/src/passes/rulebody.cc index bc84a1cf..486c89bf 100644 --- a/src/passes/rulebody.cc +++ b/src/passes/rulebody.cc @@ -138,7 +138,6 @@ namespace rego << (T(Var)[Lhs] * T(Var)[Rhs] * T(UnifyBody)[UnifyBody])) >> [](Match& _) { ACTION(); - logging::Debug() << "enum"; Location value = _.fresh({"value"}); return Seq << (Lift << UnifyBody << (Local << (Var ^ value) << Undefined)) diff --git a/src/rego.cc b/src/rego.cc index f06e2f25..934e5879 100644 --- a/src/rego.cc +++ b/src/rego.cc @@ -30,6 +30,7 @@ namespace rego explicit_enums(), body_locals(builtins), value_locals(builtins), + compr_locals(builtins), rules_to_compr(), compr(), absolute_refs(), diff --git a/src/unifier.cc b/src/unifier.cc index 2c008dcf..dbaa3f84 100644 --- a/src/unifier.cc +++ b/src/unifier.cc @@ -975,7 +975,15 @@ namespace rego } else { - auto maybe_nodes = Resolver::apply_access(container, args[1]->node()); + Node index = args[1]->node(); + if (index->type() == Undefined) + { + values.push_back( + ValueDef::create(var, Undefined ^ "undefined", sources)); + return values; + } + + auto maybe_nodes = Resolver::apply_access(container, index); if (maybe_nodes) { Nodes defs = maybe_nodes.value(); diff --git a/tests/regocpp.yaml b/tests/regocpp.yaml index 81ddf272..88f5ada3 100644 --- a/tests/regocpp.yaml +++ b/tests/regocpp.yaml @@ -1094,3 +1094,31 @@ cases: query: data.every_some.output = x want_result: - x: true +- note: regocpp/bug95 + modules: + - | + package test + + x = c { + a = b + b = c + a = 12 + } + query: data.test.x = x + want_result: + - x: 12 +- note: regocpp/bug97 + modules: + - | + package test + + x = y { + a = [1, 2, 3] + y = {z | z = a[_]} + } + query: data.test.x = x + want_result: + - x: + - 1 + - 2 + - 3 diff --git a/wrappers/python/docs/source/conf.py b/wrappers/python/docs/source/conf.py index 85732524..0b7cc879 100644 --- a/wrappers/python/docs/source/conf.py +++ b/wrappers/python/docs/source/conf.py @@ -9,7 +9,7 @@ project = 'regopy' copyright = '2023, Microsoft' author = 'Microsoft' -release = '0.3.10' +release = '0.3.11' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/wrappers/rust/regorust/Cargo.toml b/wrappers/rust/regorust/Cargo.toml index 94b8a3f5..67352ac3 100644 --- a/wrappers/rust/regorust/Cargo.toml +++ b/wrappers/rust/regorust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "regorust" -version = "0.3.10" +version = "0.3.11" edition = "2021" description = "Rust bindings for the rego-cpp Rego compiler and interpreter" license = "MIT"