diff --git a/src/gj.rs b/src/gj.rs index e6e5e0ae..a2150387 100644 --- a/src/gj.rs +++ b/src/gj.rs @@ -719,32 +719,13 @@ impl EGraph { } } - let is_rebuilding = cq.query.ruleset.to_string().contains("rebuilding_"); let do_seminaive = self.seminaive && !global_updated; // for the later atoms, we consider everything let mut timestamp_ranges = vec![0..u32::MAX; cq.query.atoms.len()]; if do_seminaive { - for (atom_i, atom) in cq.query.atoms.iter().enumerate() { + for (atom_i, _atom) in cq.query.atoms.iter().enumerate() { timestamp_ranges[atom_i] = timestamp..u32::MAX; - // For rebuilding, we have the invariant - // that new atoms are up-to-date w.r.t. - // old unionfind entries. - // So do the join for new unionfind entries - // and all the other atoms. - // TODO this hack fails on one benchmark, not sure why - let rebuilding_hack = false; - if rebuilding_hack { - let atom_has_parent = format!("{:?}", atom).contains("Parent_"); - if is_rebuilding && !atom_has_parent { - continue; - } else if is_rebuilding { - assert_eq!(cq.query.atoms.len(), 2); - timestamp_ranges = vec![0..u32::MAX; cq.query.atoms.len()]; - timestamp_ranges[atom_i] = timestamp..u32::MAX; - } - } - self.gj_for_atom(Some(atom_i), ×tamp_ranges, cq, &mut f); // now we can fix this atom to be "old stuff" only // range is half-open; timestamp is excluded diff --git a/src/lib.rs b/src/lib.rs index b311fbb3..57c29f52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -946,7 +946,7 @@ impl EGraph { let name = Symbol::from(name); let mut ctx = typecheck::Context::new(self); let (query0, action0) = ctx - .typecheck_query(&rule.body, &rule.head, ruleset) + .typecheck_query(&rule.body, &rule.head) .map_err(Error::TypeErrors)?; let query = self.compile_gj_query(query0, &ctx.types); let program = self @@ -1048,7 +1048,7 @@ impl EGraph { let converted_facts = facts.iter().map(|f| f.to_fact()).collect::>(); let empty_actions = vec![]; let (query0, _) = ctx - .typecheck_query(&converted_facts, &empty_actions, Symbol::from("")) + .typecheck_query(&converted_facts, &empty_actions) .map_err(Error::TypeErrors)?; let query = self.compile_gj_query(query0, &ctx.types); @@ -1221,7 +1221,7 @@ impl EGraph { let function_type = self .type_info() .lookup_user_func(func_name) - .unwrap_or_else(|| panic!("Unrecognzed function name {}", func_name)); + .unwrap_or_else(|| panic!("Unrecognized function name {}", func_name)); let func = self.functions.get_mut(&func_name).unwrap(); let mut filename = self.fact_directory.clone().unwrap_or_default(); diff --git a/src/sort/map.rs b/src/sort/map.rs index d270a703..c8abacb6 100644 --- a/src/sort/map.rs +++ b/src/sort/map.rs @@ -87,8 +87,22 @@ impl Sort for MapSort { result } - fn canonicalize(&self, _value: &mut Value, _unionfind: &UnionFind) -> bool { - false + fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool { + let maps = self.maps.lock().unwrap(); + let map = maps.get_index(value.bits as usize).unwrap(); + let mut changed = false; + let new_map: ValueMap = map + .iter() + .map(|(k, v)| { + let (mut k, mut v) = (*k, *v); + changed |= self.key.canonicalize(&mut k, unionfind); + changed |= self.value.canonicalize(&mut v, unionfind); + (k, v) + }) + .collect(); + drop(maps); + *value = new_map.store(self).unwrap(); + changed } fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { @@ -194,16 +208,9 @@ impl PrimitiveLike for MapRebuild { fn apply(&self, values: &[Value], egraph: &EGraph) -> Option { let maps = self.map.maps.lock().unwrap(); let map = maps.get_index(values[0].bits as usize).unwrap(); - let mut changed = false; let new_map: ValueMap = map .iter() - .map(|(k, v)| { - let (k, v) = (*k, *v); - let updated_k = egraph.find(k); - let updated_v = egraph.find(v); - changed |= updated_k != k || updated_v != v; - (updated_k, updated_v) - }) + .map(|(k, v)| (egraph.find(*k), egraph.find(*v))) .collect(); drop(maps); diff --git a/src/sort/set.rs b/src/sort/set.rs index c6a9b437..b25d7fe4 100644 --- a/src/sort/set.rs +++ b/src/sort/set.rs @@ -86,8 +86,21 @@ impl Sort for SetSort { result } - fn canonicalize(&self, _value: &mut Value, _unionfind: &UnionFind) -> bool { - false + fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool { + let sets = self.sets.lock().unwrap(); + let set = sets.get_index(value.bits as usize).unwrap(); + let mut changed = false; + let new_set: ValueSet = set + .iter() + .map(|e| { + let mut e = *e; + changed |= self.element.canonicalize(&mut e, unionfind); + e + }) + .collect(); + drop(sets); + *value = new_set.store(self).unwrap(); + changed } fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { diff --git a/src/sort/vec.rs b/src/sort/vec.rs index 1565f6b0..be97ae7d 100644 --- a/src/sort/vec.rs +++ b/src/sort/vec.rs @@ -87,7 +87,7 @@ impl Sort for VecSort { let vecs = self.vecs.lock().unwrap(); let vec = vecs.get_index(value.bits as usize).unwrap(); let mut changed = false; - let new_set: ValueVec = vec + let new_vec: ValueVec = vec .iter() .map(|e| { let mut e = *e; @@ -96,7 +96,7 @@ impl Sort for VecSort { }) .collect(); drop(vecs); - *value = new_set.store(self).unwrap(); + *value = new_vec.store(self).unwrap(); changed } @@ -215,17 +215,9 @@ impl PrimitiveLike for VecRebuild { fn apply(&self, values: &[Value], egraph: &EGraph) -> Option { let vec = ValueVec::load(&self.vec, &values[0]); - let mut changed = false; - let new_set: ValueVec = vec - .iter() - .map(|e| { - let updated = egraph.find(*e); - changed |= updated != *e; - updated - }) - .collect(); + let new_vec: ValueVec = vec.iter().map(|e| egraph.find(*e)).collect(); drop(vec); - Some(new_set.store(&self.vec).unwrap()) + Some(new_vec.store(&self.vec).unwrap()) } } struct VecOf { diff --git a/src/typecheck.rs b/src/typecheck.rs index 1b9012f4..d0606c37 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -52,8 +52,6 @@ impl std::fmt::Display for Atom { pub struct Query { pub atoms: Vec>, pub filters: Vec>, - // store the ruleset for hack on rebuilding - pub(crate) ruleset: Symbol, } impl std::fmt::Display for Query { @@ -131,7 +129,6 @@ impl<'a> Context<'a> { &mut self, facts: &'a [Fact], actions: &'a [Action], - ruleset: Symbol, ) -> Result<(Query, Vec), Vec> { for fact in facts { self.typecheck_fact(fact); @@ -215,7 +212,6 @@ impl<'a> Context<'a> { let mut query = Query { atoms: Default::default(), filters: Default::default(), - ruleset, }; let mut query_eclasses = HashSet::::default(); // Now we can fill in the nodes with the canonical leaves @@ -729,11 +725,15 @@ impl EGraph { fn perform_set( &mut self, table: Symbol, - args: &[Value], new_value: Value, stack: &mut Vec, ) -> Result<(), Error> { let function = self.functions.get_mut(&table).unwrap(); + + let new_len = stack.len() - function.schema.input.len(); + // TODO would be nice to use slice here + let args = &stack[new_len..]; + // We should only have canonical values here: omit the canonicalization step let old_value = function.get(args); @@ -757,6 +757,7 @@ impl EGraph { } }; if merged != old_value { + let args = &stack[new_len..]; let function = self.functions.get_mut(&table).unwrap(); function.insert(args, merged, self.timestamp); } @@ -864,10 +865,8 @@ impl EGraph { // except for setting the parent relation let new_value = stack.pop().unwrap(); let new_len = stack.len() - function.schema.input.len(); - // TODO would be nice to use slice here - let args = stack[new_len..].to_vec(); - self.perform_set(*f, &args, new_value, stack)?; + self.perform_set(*f, new_value, stack)?; stack.truncate(new_len) } Instruction::Union(arity) => { diff --git a/src/typechecking.rs b/src/typechecking.rs index d5057322..8aa2090b 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -20,7 +20,7 @@ pub struct TypeInfo { pub presort_names: HashSet, pub sorts: HashMap>, pub primitives: HashMap>, - func_types: HashMap, + pub func_types: HashMap, global_types: HashMap, pub local_types: HashMap>, } diff --git a/tests/levenshtein-distance.egg b/tests/levenshtein-distance.egg index 59f234e0..c80862cd 100644 --- a/tests/levenshtein-distance.egg +++ b/tests/levenshtein-distance.egg @@ -60,5 +60,11 @@ (run 100) +(extract (Unwrap Test1)) +(check (= Test1 (Num 3))) -(extract (Unwrap Test3)) \ No newline at end of file +(extract (Unwrap Test2)) +(check (= Test2 (Num 5))) + +(extract (Unwrap Test3)) +(check (= Test3 (Num 5))) \ No newline at end of file