Skip to content

Commit

Permalink
Refactor sort extraction API (#495)
Browse files Browse the repository at this point in the history
* Refactor sort extraction API

* Remove make_expr
  • Loading branch information
RiscInside authored Dec 13, 2024
1 parent 3df83aa commit 197103d
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 155 deletions.
4 changes: 2 additions & 2 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ impl<'a> Extractor<'a> {
let (cost, node) = self.costs.get(&id)?.clone();
Some((cost, node))
} else {
let (cost, node) = sort.extract_expr(self.egraph, value, self, termdag)?;
Some((cost, termdag.expr_to_term(&node)))
let (cost, node) = sort.extract_term(self.egraph, value, self, termdag)?;
Some((cost, node))
}
}

Expand Down
13 changes: 11 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,12 @@ impl EGraph {
if a_type.is_eq_sort() {
children.push(extractor.find_best(a, &mut termdag, a_type).unwrap().1);
} else {
children.push(termdag.expr_to_term(&a_type.make_expr(self, a).1));
children.push(
a_type
.extract_term(self, a, &extractor, &mut termdag)
.unwrap()
.1,
)
};
}

Expand All @@ -721,7 +726,11 @@ impl EGraph {
.unwrap()
.1
} else {
termdag.expr_to_term(&schema.output.make_expr(self, out.value).1)
schema
.output
.extract_term(self, out.value, &extractor, &mut termdag)
.unwrap()
.1
};
terms.push((termdag.app(sym, children), out));
}
Expand Down
13 changes: 11 additions & 2 deletions src/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use ordered_float::NotNan;
use std::collections::VecDeque;

use crate::{util::HashMap, ArcSort, EGraph, Function, Symbol, TupleOutput, Value};
use crate::{
extract::Extractor, util::HashMap, ArcSort, EGraph, Function, Symbol, TermDag, TupleOutput,
Value,
};

pub struct SerializeConfig {
// Maximumum number of functions to include in the serialized graph, any after this will be discarded
Expand Down Expand Up @@ -312,7 +315,13 @@ impl EGraph {
let op = if sort.is_container_sort() {
sort.serialized_name(value).to_string()
} else {
sort.make_expr(self, *value).1.to_string()
let mut termdag = TermDag::default();
let extractor = Extractor::new(self, &mut termdag);
let (_, term) = sort
.extract_term(self, *value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized");

termdag.term_to_expr(&term).to_string()
};
egraph.nodes.insert(
node_id.clone(),
Expand Down
21 changes: 10 additions & 11 deletions src/sort/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,20 @@ impl Sort for BigIntSort {
add_primitives!(eg, "from-string" = |a: Symbol| -> Opt<Z> { a.as_str().parse::<Z>().ok() });
}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
_extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Term)> {
#[cfg(debug_assertions)]
debug_assert_eq!(value.tag, self.name());

let bigint = Z::load(self, &value);
(
1,
Expr::call_no_span(
"from-string",
vec![GenericExpr::Lit(
DUMMY_SPAN.clone(),
Literal::String(bigint.to_string().into()),
)],
),
)

let as_string = termdag.lit(Literal::String(bigint.to_string().into()));
Some((1, termdag.app("from-string".into(), vec![as_string])))
}
}

Expand Down
39 changes: 17 additions & 22 deletions src/sort/bigrat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,30 @@ impl Sort for BigRatSort {
add_primitives!(eg, ">=" = |a: Q, b: Q| -> Opt { if a >= b {Some(())} else {None} });
}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
_extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Term)> {
#[cfg(debug_assertions)]
debug_assert_eq!(value.tag, self.name());

let rat = Q::load(self, &value);
let numer = rat.numer();
let denom = rat.denom();
(

let numer_as_string = termdag.lit(Literal::String(numer.to_string().into()));
let denom_as_string = termdag.lit(Literal::String(denom.to_string().into()));

let numer_term = termdag.app("from-string".into(), vec![numer_as_string]);
let denom_term = termdag.app("from-string".into(), vec![denom_as_string]);

Some((
1,
Expr::call_no_span(
"bigrat",
vec![
Expr::call_no_span(
"from-string",
vec![GenericExpr::Lit(
DUMMY_SPAN.clone(),
Literal::String(numer.to_string().into()),
)],
),
Expr::call_no_span(
"from-string",
vec![GenericExpr::Lit(
DUMMY_SPAN.clone(),
Literal::String(denom.to_string().into()),
)],
),
],
),
)
termdag.app("bigrat".into(), vec![numer_term, denom_term]),
))
}
}

Expand Down
13 changes: 8 additions & 5 deletions src/sort/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@ impl Sort for BoolSort {
add_primitives!(eg, "=>" = |a: bool, b: bool| -> bool { !a || b });
}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
_extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Term)> {
#[cfg(debug_assertions)]
debug_assert_eq!(value.tag, self.name());

(
1,
GenericExpr::Lit(DUMMY_SPAN.clone(), Literal::Bool(value.bits > 0)),
)
Some((1, termdag.lit(Literal::Bool(value.bits > 0))))
}
}

Expand Down
17 changes: 10 additions & 7 deletions src/sort/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,20 @@ impl Sort for F64Sort {

}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
_extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Term)> {
#[cfg(debug_assertions)]
debug_assert_eq!(value.tag, self.name());

(
Some((
1,
GenericExpr::Lit(
DUMMY_SPAN.clone(),
Literal::Float(OrderedFloat(f64::from_bits(value.bits))),
),
)
termdag.lit(Literal::Float(OrderedFloat(f64::from_bits(value.bits)))),
))
}
}

Expand Down
20 changes: 5 additions & 15 deletions src/sort/fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,34 +185,24 @@ impl Sort for FunctionSort {
});
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) {
let mut termdag = TermDag::default();
let extractor = Extractor::new(egraph, &mut termdag);
self.extract_expr(egraph, value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized")
}

fn extract_expr(
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
) -> Option<(Cost, Term)> {
let ValueFunction(name, inputs) = ValueFunction::load(self, &value);
let (cost, args) = inputs.into_iter().try_fold(
(
1usize,
vec![GenericExpr::Lit(DUMMY_SPAN.clone(), Literal::String(name))],
),
(1usize, vec![termdag.lit(Literal::String(name))]),
|(cost, mut args), (sort, value)| {
let (new_cost, term) = extractor.find_best(value, termdag, &sort)?;
args.push(termdag.term_to_expr(&term));
args.push(term);
Some((cost.saturating_add(new_cost), args))
},
)?;

Some((cost, Expr::call_no_span("unstable-fn", args)))
Some((cost, termdag.app("unstable-fn".into(), args)))
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/sort/i64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ impl Sort for I64Sort {

}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
#[cfg(debug_assertions)]
debug_assert_eq!(value.tag, self.name());

(
1,
GenericExpr::Lit(DUMMY_SPAN.clone(), Literal::Int(value.bits as _)),
)
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
_extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Term)> {
Some((1, termdag.lit(Literal::Int(value.bits as _))))
}
}

Expand Down
20 changes: 5 additions & 15 deletions src/sort/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,33 +168,23 @@ impl Sort for MapSort {
});
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) {
let mut termdag = TermDag::default();
let extractor = Extractor::new(egraph, &mut termdag);
self.extract_expr(egraph, value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized")
}

fn extract_expr(
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
) -> Option<(Cost, Term)> {
let map = ValueMap::load(self, &value);
let mut expr = Expr::call_no_span("map-empty", []);
let mut term = termdag.app("map-empty".into(), vec![]);
let mut cost = 0usize;
for (k, v) in map.iter().rev() {
let k = extractor.find_best(*k, termdag, &self.key)?;
let v = extractor.find_best(*v, termdag, &self.value)?;
cost = cost.saturating_add(k.0).saturating_add(v.0);
expr = Expr::call_no_span(
"map-insert",
[expr, termdag.term_to_expr(&k.1), termdag.term_to_expr(&v.1)],
)
term = termdag.app("map-insert".into(), vec![term, k.1, v.1]);
}
Some((cost, expr))
Some((cost, term))
}
}

Expand Down
25 changes: 11 additions & 14 deletions src/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,14 @@ pub trait Sort: Any + Send + Sync + Debug {
let _ = info;
}

/// Extracting an expression (with smallest cost) out of a primitive value
fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr);

/// For values like EqSort containers, to make/extract an expression from it
/// requires an extractor. Moreover, the extraction may be unsuccessful if
/// the extractor is not fully initialized.
///
/// The default behavior is to call make_expr
fn extract_expr(
/// Extracting a term (with smallest cost) out of a primitive value
fn extract_term(
&self,
egraph: &EGraph,
value: Value,
_extractor: &Extractor,
_termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
Some(self.make_expr(egraph, value))
}
) -> Option<(Cost, Term)>;
}

// Note: this trait is currently intended to be implemented on the
Expand Down Expand Up @@ -161,8 +152,14 @@ impl Sort for EqSort {
}
}

fn make_expr(&self, _egraph: &EGraph, _value: Value) -> (Cost, Expr) {
unimplemented!("No make_expr for EqSort {}", self.name)
fn extract_term(
&self,
_egraph: &EGraph,
_value: Value,
_extractor: &Extractor,
_termdag: &mut TermDag,
) -> Option<(Cost, Term)> {
unimplemented!("No extract_term for EqSort {}", self.name)
}
}

Expand Down
16 changes: 4 additions & 12 deletions src/sort/multiset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,30 +262,22 @@ impl Sort for MultiSetSort {
}
}

fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) {
let mut termdag = TermDag::default();
let extractor = Extractor::new(egraph, &mut termdag);
self.extract_expr(egraph, value, &extractor, &mut termdag)
.expect("Extraction should be successful since extractor has been fully initialized")
}

fn extract_expr(
fn extract_term(
&self,
_egraph: &EGraph,
value: Value,
extractor: &Extractor,
termdag: &mut TermDag,
) -> Option<(Cost, Expr)> {
) -> Option<(Cost, Term)> {
let multiset = ValueMultiSet::load(self, &value);
let mut children = vec![];
let mut cost = 0usize;
for e in multiset.iter() {
let (child_cost, child_term) = extractor.find_best(*e, termdag, &self.element)?;
cost = cost.saturating_add(child_cost);
children.push(termdag.term_to_expr(&child_term));
children.push(child_term);
}
let expr = Expr::call_no_span("multiset-of", children);
Some((cost, expr))
Some((cost, termdag.app("multiset-of".into(), children)))
}

fn serialized_name(&self, _value: &Value) -> Symbol {
Expand Down
Loading

0 comments on commit 197103d

Please sign in to comment.