Skip to content

Commit

Permalink
Merge pull request #233 from oflatt/oflatt-fast-terms
Browse files Browse the repository at this point in the history
Encoding equality saturation, including rebuilding and the union-find
  • Loading branch information
oflatt authored Oct 20, 2023
2 parents c7405d8 + 2337879 commit c741636
Show file tree
Hide file tree
Showing 18 changed files with 1,022 additions and 182 deletions.
28 changes: 27 additions & 1 deletion src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,22 @@ fn add_semi_naive_rule(desugar: &mut Desugar, rule: Rule) -> Option<Rule> {
}
}

/// The Desugar struct stores all the state needed
/// during desugaring a program.
/// While desugaring doesn't need type information, it
/// needs to know what global variables exist.
/// It also needs to know what functions are primitives
/// (it uses the [`TypeInfo`] for that.
/// After desugaring, typechecking happens and the
/// type_info field is used for that.
pub struct Desugar {
next_fresh: usize,
next_command_id: usize,
pub(crate) parser: ast::parse::ProgramParser,
// Store the parser because it takes some time
// on startup for some reason
parser: ast::parse::ProgramParser,
pub(crate) expr_parser: ast::parse::ExprParser,
pub(crate) action_parser: ast::parse::ActionParser,
// TODO fix getting fresh names using modules
pub(crate) number_underscores: usize,
pub(crate) global_variables: HashSet<Symbol>,
Expand All @@ -429,6 +441,8 @@ impl Default for Desugar {
next_command_id: Default::default(),
// these come from lalrpop and don't have default impls
parser: ast::parse::ProgramParser::new(),
expr_parser: ast::parse::ExprParser::new(),
action_parser: ast::parse::ActionParser::new(),
number_underscores: 3,
global_variables: Default::default(),
type_info: TypeInfo::default(),
Expand Down Expand Up @@ -694,6 +708,8 @@ impl Clone for Desugar {
next_fresh: self.next_fresh,
next_command_id: self.next_command_id,
parser: ast::parse::ProgramParser::new(),
expr_parser: ast::parse::ExprParser::new(),
action_parser: ast::parse::ActionParser::new(),
number_underscores: self.number_underscores,
global_variables: self.global_variables.clone(),
type_info: self.type_info.clone(),
Expand Down Expand Up @@ -817,4 +833,14 @@ impl Desugar {
unextractable: fdecl.unextractable,
})]
}

/// Get the name of the parent table for a sort
/// for the term encoding (not related to desugaring)
pub(crate) fn parent_name(&self, sort: Symbol) -> Symbol {
Symbol::from(format!(
"{}_Parent{}",
sort,
"_".repeat(self.number_underscores)
))
}
}
33 changes: 28 additions & 5 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ pub enum Schedule {
Sequence(Vec<Schedule>),
}

impl Schedule {
pub fn saturate(self) -> Self {
Schedule::Saturate(Box::new(self))
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NormSchedule {
Saturate(Box<NormSchedule>),
Expand Down Expand Up @@ -989,6 +995,18 @@ impl NormFact {
}
}

pub(crate) fn map_use(&self, fvar: &mut impl FnMut(Symbol) -> Expr) -> Fact {
match self {
NormFact::AssignVar(lhs, rhs) => Fact::Eq(vec![Expr::Var(*lhs), fvar(*rhs)]),
NormFact::ConstrainEq(lhs, rhs) => Fact::Eq(vec![fvar(*lhs), fvar(*rhs)]),
NormFact::Compute(lhs, NormExpr::Call(op, children)) => Fact::Eq(vec![
fvar(*lhs),
Expr::Call(*op, children.iter().cloned().map(fvar).collect()),
]),
NormFact::AssignLit(..) | NormFact::Assign(..) => self.to_fact(),
}
}

pub(crate) fn map_def_use(&self, fvar: &mut impl FnMut(Symbol, bool) -> Symbol) -> NormFact {
match self {
NormFact::Assign(symbol, expr) => {
Expand Down Expand Up @@ -1071,11 +1089,14 @@ pub enum Action {
/// (extract (Num 2)); Extracts Num 1
/// ```
Union(Expr, Expr),
/// `extract` the lowest-cost term equal to the one given.
/// Also, extract `n` variants of the term by selecting different
/// terms with unique constructors and children.
/// When `n` is zero, just extract the lowest-cost term.
/// See [`Command::QueryExtract`] for more details.
/// `extract` a datatype from the egraph, choosing
/// the smallest representative.
/// By default, each constructor costs 1 to extract
/// (common subexpressions are not shared in the cost
/// model).
/// The second argument is the number of variants to
/// extract, picking different terms in the
/// same equivalence class.
Extract(Expr, Expr),
Panic(String),
Expr(Expr),
Expand Down Expand Up @@ -1333,6 +1354,8 @@ impl NormRule {
let substituted = new_expr.subst(subst);

// TODO sometimes re-arranging actions is bad
// this is because actions can fail
// halfway through in the current semantics
if substituted.ast_size() > 1 {
head.push(Action::Let(*symbol, substituted));
} else {
Expand Down
29 changes: 11 additions & 18 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,6 @@ pub struct Extractor<'a> {
}

impl EGraph {
pub fn value_to_id(&self, value: Value) -> Option<(Symbol, Id)> {
if let Some(sort) = self.get_sort(&value) {
if sort.is_eq_sort() {
let id = Id::from(value.bits as usize);
return Some((sort.name(), self.find(id)));
}
}
None
}

pub fn extract(&self, value: Value, termdag: &mut TermDag, arcsort: &ArcSort) -> (Cost, Term) {
let extractor = Extractor::new(self, termdag);
extractor
Expand All @@ -45,7 +35,7 @@ impl EGraph {
"{:?}",
inputs
.iter()
.map(|input| extractor.costs.get(&extractor.find(input)))
.map(|input| extractor.costs.get(&extractor.find_id(*input)))
.collect::<Vec<_>>()
);
}
Expand All @@ -62,8 +52,7 @@ impl EGraph {
limit: usize,
termdag: &mut TermDag,
) -> Vec<Term> {
let (tag, id) = self.value_to_id(value).unwrap();
let output_value = &Value::from_id(tag, id);
let output_value = self.find(value);
let ext = &Extractor::new(self, termdag);
ext.ctors
.iter()
Expand All @@ -77,7 +66,7 @@ impl EGraph {
func.nodes
.iter()
.filter_map(|(inputs, output)| {
(&output.value == output_value).then(|| {
(output.value == output_value).then(|| {
let node = Node { sym, inputs };
ext.expr_from_node(&node, termdag).expect(
"extract_variants should be called after extractor initialization",
Expand Down Expand Up @@ -130,7 +119,7 @@ impl<'a> Extractor<'a> {
sort: &ArcSort,
) -> Option<(Cost, Term)> {
if sort.is_eq_sort() {
let id = self.find(&value);
let id = self.find_id(value);
let (cost, node) = self.costs.get(&id)?.clone();
Some((cost, node))
} else {
Expand All @@ -156,8 +145,12 @@ impl<'a> Extractor<'a> {
Some((terms, cost))
}

fn find(&self, value: &Value) -> Id {
self.egraph.find(Id::from(value.bits as usize))
fn find(&self, value: Value) -> Value {
self.egraph.find(value)
}

fn find_id(&self, value: Value) -> Id {
Id::from(self.find(value).bits as usize)
}

fn find_costs(&mut self, termdag: &mut TermDag) {
Expand All @@ -174,7 +167,7 @@ impl<'a> Extractor<'a> {
{
let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));

let id = self.find(&output.value);
let id = self.find_id(output.value);
match self.costs.entry(id) {
Entry::Vacant(e) => {
did_something = true;
Expand Down
3 changes: 2 additions & 1 deletion src/gj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ impl<'b> Context<'b> {
})
}

if let Some(res) = prim.apply(&values) {
if let Some(res) = prim.apply(&values, self.egraph) {
match out {
AtomTerm::Var(v) => {
let i = self.query.vars.get_index_of(v).unwrap();
Expand Down Expand Up @@ -725,6 +725,7 @@ impl EGraph {
if do_seminaive {
for (atom_i, _atom) in cq.query.atoms.iter().enumerate() {
timestamp_ranges[atom_i] = timestamp..u32::MAX;

self.gj_for_atom(Some(atom_i), &timestamp_ranges, cq, &mut f);
// now we can fix this atom to be "old stuff" only
// range is half-open; timestamp is excluded
Expand Down
Loading

0 comments on commit c741636

Please sign in to comment.