Skip to content

Commit

Permalink
cleanup thanks to @yihozhang
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Sep 19, 2023
1 parent 83cac74 commit 6ab9f68
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 57 deletions.
21 changes: 1 addition & 20 deletions src/gj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,32 +719,13 @@ impl EGraph {
}
}

let is_rebuilding = cq.query.ruleset.to_string().contains("rebuilding_");
let do_seminaive = self.seminaive && !global_updated;
// for the later atoms, we consider everything
let mut timestamp_ranges = vec![0..u32::MAX; cq.query.atoms.len()];
if do_seminaive {
for (atom_i, atom) in cq.query.atoms.iter().enumerate() {
for (atom_i, _atom) in cq.query.atoms.iter().enumerate() {
timestamp_ranges[atom_i] = timestamp..u32::MAX;

// For rebuilding, we have the invariant
// that new atoms are up-to-date w.r.t.
// old unionfind entries.
// So do the join for new unionfind entries
// and all the other atoms.
// TODO this hack fails on one benchmark, not sure why
let rebuilding_hack = false;
if rebuilding_hack {
let atom_has_parent = format!("{:?}", atom).contains("Parent_");
if is_rebuilding && !atom_has_parent {
continue;
} else if is_rebuilding {
assert_eq!(cq.query.atoms.len(), 2);
timestamp_ranges = vec![0..u32::MAX; cq.query.atoms.len()];
timestamp_ranges[atom_i] = timestamp..u32::MAX;
}
}

self.gj_for_atom(Some(atom_i), &timestamp_ranges, cq, &mut f);
// now we can fix this atom to be "old stuff" only
// range is half-open; timestamp is excluded
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ impl EGraph {
let name = Symbol::from(name);
let mut ctx = typecheck::Context::new(self);
let (query0, action0) = ctx
.typecheck_query(&rule.body, &rule.head, ruleset)
.typecheck_query(&rule.body, &rule.head)
.map_err(Error::TypeErrors)?;
let query = self.compile_gj_query(query0, &ctx.types);
let program = self
Expand Down Expand Up @@ -1048,7 +1048,7 @@ impl EGraph {
let converted_facts = facts.iter().map(|f| f.to_fact()).collect::<Vec<Fact>>();
let empty_actions = vec![];
let (query0, _) = ctx
.typecheck_query(&converted_facts, &empty_actions, Symbol::from(""))
.typecheck_query(&converted_facts, &empty_actions)
.map_err(Error::TypeErrors)?;
let query = self.compile_gj_query(query0, &ctx.types);

Expand Down Expand Up @@ -1221,7 +1221,7 @@ impl EGraph {
let function_type = self
.type_info()
.lookup_user_func(func_name)
.unwrap_or_else(|| panic!("Unrecognzed function name {}", func_name));
.unwrap_or_else(|| panic!("Unrecognized function name {}", func_name));
let func = self.functions.get_mut(&func_name).unwrap();

let mut filename = self.fact_directory.clone().unwrap_or_default();
Expand Down
27 changes: 17 additions & 10 deletions src/sort/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,22 @@ impl Sort for MapSort {
result
}

fn canonicalize(&self, _value: &mut Value, _unionfind: &UnionFind) -> bool {
false
fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool {
let maps = self.maps.lock().unwrap();
let map = maps.get_index(value.bits as usize).unwrap();
let mut changed = false;
let new_map: ValueMap = map
.iter()
.map(|(k, v)| {
let (mut k, mut v) = (*k, *v);
changed |= self.key.canonicalize(&mut k, unionfind);
changed |= self.value.canonicalize(&mut v, unionfind);
(k, v)
})
.collect();
drop(maps);
*value = new_map.store(self).unwrap();
changed
}

fn register_primitives(self: Arc<Self>, typeinfo: &mut TypeInfo) {
Expand Down Expand Up @@ -194,16 +208,9 @@ impl PrimitiveLike for MapRebuild {
fn apply(&self, values: &[Value], egraph: &EGraph) -> Option<Value> {
let maps = self.map.maps.lock().unwrap();
let map = maps.get_index(values[0].bits as usize).unwrap();
let mut changed = false;
let new_map: ValueMap = map
.iter()
.map(|(k, v)| {
let (k, v) = (*k, *v);
let updated_k = egraph.find(k);
let updated_v = egraph.find(v);
changed |= updated_k != k || updated_v != v;
(updated_k, updated_v)
})
.map(|(k, v)| (egraph.find(*k), egraph.find(*v)))
.collect();

drop(maps);
Expand Down
17 changes: 15 additions & 2 deletions src/sort/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,21 @@ impl Sort for SetSort {
result
}

fn canonicalize(&self, _value: &mut Value, _unionfind: &UnionFind) -> bool {
false
fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool {
let sets = self.sets.lock().unwrap();
let set = sets.get_index(value.bits as usize).unwrap();
let mut changed = false;
let new_set: ValueSet = set
.iter()
.map(|e| {
let mut e = *e;
changed |= self.element.canonicalize(&mut e, unionfind);
e
})
.collect();
drop(sets);
*value = new_set.store(self).unwrap();
changed
}

fn register_primitives(self: Arc<Self>, typeinfo: &mut TypeInfo) {
Expand Down
16 changes: 4 additions & 12 deletions src/sort/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl Sort for VecSort {
let vecs = self.vecs.lock().unwrap();
let vec = vecs.get_index(value.bits as usize).unwrap();
let mut changed = false;
let new_set: ValueVec = vec
let new_vec: ValueVec = vec
.iter()
.map(|e| {
let mut e = *e;
Expand All @@ -96,7 +96,7 @@ impl Sort for VecSort {
})
.collect();
drop(vecs);
*value = new_set.store(self).unwrap();
*value = new_vec.store(self).unwrap();
changed
}

Expand Down Expand Up @@ -215,17 +215,9 @@ impl PrimitiveLike for VecRebuild {
fn apply(&self, values: &[Value], egraph: &EGraph) -> Option<Value> {
let vec = ValueVec::load(&self.vec, &values[0]);

let mut changed = false;
let new_set: ValueVec = vec
.iter()
.map(|e| {
let updated = egraph.find(*e);
changed |= updated != *e;
updated
})
.collect();
let new_vec: ValueVec = vec.iter().map(|e| egraph.find(*e)).collect();
drop(vec);
Some(new_set.store(&self.vec).unwrap())
Some(new_vec.store(&self.vec).unwrap())
}
}
struct VecOf {
Expand Down
15 changes: 7 additions & 8 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ impl<T: std::fmt::Display> std::fmt::Display for Atom<T> {
pub struct Query {
pub atoms: Vec<Atom<Symbol>>,
pub filters: Vec<Atom<Primitive>>,
// store the ruleset for hack on rebuilding
pub(crate) ruleset: Symbol,
}

impl std::fmt::Display for Query {
Expand Down Expand Up @@ -131,7 +129,6 @@ impl<'a> Context<'a> {
&mut self,
facts: &'a [Fact],
actions: &'a [Action],
ruleset: Symbol,
) -> Result<(Query, Vec<Action>), Vec<TypeError>> {
for fact in facts {
self.typecheck_fact(fact);
Expand Down Expand Up @@ -215,7 +212,6 @@ impl<'a> Context<'a> {
let mut query = Query {
atoms: Default::default(),
filters: Default::default(),
ruleset,
};
let mut query_eclasses = HashSet::<Id>::default();
// Now we can fill in the nodes with the canonical leaves
Expand Down Expand Up @@ -729,11 +725,15 @@ impl EGraph {
fn perform_set(
&mut self,
table: Symbol,
args: &[Value],
new_value: Value,
stack: &mut Vec<Value>,
) -> Result<(), Error> {
let function = self.functions.get_mut(&table).unwrap();

let new_len = stack.len() - function.schema.input.len();
// TODO would be nice to use slice here
let args = &stack[new_len..];

// We should only have canonical values here: omit the canonicalization step
let old_value = function.get(args);

Expand All @@ -757,6 +757,7 @@ impl EGraph {
}
};
if merged != old_value {
let args = &stack[new_len..];
let function = self.functions.get_mut(&table).unwrap();
function.insert(args, merged, self.timestamp);
}
Expand Down Expand Up @@ -864,10 +865,8 @@ impl EGraph {
// except for setting the parent relation
let new_value = stack.pop().unwrap();
let new_len = stack.len() - function.schema.input.len();
// TODO would be nice to use slice here
let args = stack[new_len..].to_vec();

self.perform_set(*f, &args, new_value, stack)?;
self.perform_set(*f, new_value, stack)?;
stack.truncate(new_len)
}
Instruction::Union(arity) => {
Expand Down
2 changes: 1 addition & 1 deletion src/typechecking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct TypeInfo {
pub presort_names: HashSet<Symbol>,
pub sorts: HashMap<Symbol, Arc<dyn Sort>>,
pub primitives: HashMap<Symbol, Vec<Primitive>>,
func_types: HashMap<Symbol, FuncType>,
pub func_types: HashMap<Symbol, FuncType>,
global_types: HashMap<Symbol, ArcSort>,
pub local_types: HashMap<CommandId, HashMap<Symbol, ArcSort>>,
}
Expand Down
8 changes: 7 additions & 1 deletion tests/levenshtein-distance.egg
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,11 @@

(run 100)

(extract (Unwrap Test1))
(check (= Test1 (Num 3)))

(extract (Unwrap Test3))
(extract (Unwrap Test2))
(check (= Test2 (Num 5)))

(extract (Unwrap Test3))
(check (= Test3 (Num 5)))

0 comments on commit 6ab9f68

Please sign in to comment.