Skip to content

Commit

Permalink
Replace front-end type checker with the proposed constrained-based one
Browse files Browse the repository at this point in the history
  • Loading branch information
yihozhang committed Sep 27, 2023
1 parent 53f399a commit 08aef0c
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 97 deletions.
5 changes: 2 additions & 3 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Rule;
use crate::{typecheck::ValueEq, *};
use crate::*;

fn desugar_datatype(name: Symbol, variants: Vec<Variant>) -> Vec<NCommand> {
vec![NCommand::Sort(name, None)]
Expand Down Expand Up @@ -424,8 +424,7 @@ pub struct Desugar {

impl Default for Desugar {
fn default() -> Self {
let mut type_info = TypeInfo::default();
type_info.add_primitive(ValueEq {});
let type_info = TypeInfo::default();
Self {
next_fresh: Default::default(),
next_command_id: Default::default(),
Expand Down
141 changes: 125 additions & 16 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,118 @@ impl UnresolvedCoreRule {
self.head.subst(subst);
}
}

impl UnresolvedCoreRule {
// keep it here while it's untested
#![allow(unused)]
pub(crate) fn to_norm_rule(&self, egraph: &EGraph) -> NormRule {
fn to_fact(body: &Query<Symbol>, egraph: &EGraph) -> Vec<NormFact> {
let mut facts = vec![];
for fact in body.atoms.iter() {
let Atom { args, head } = fact;

let args: Vec<_> = args
.iter()
.map(|arg| match arg {
AtomTerm::Var(v) => v.clone(),
AtomTerm::Literal(lit) => {
let symbol = format!("lit{}", lit).into();
facts.push(NormFact::AssignLit(symbol, lit.clone()));
symbol
}
AtomTerm::Global(v) => v.clone(),
})
.collect();

// TODO: this assumes that a name can refer to a function xor a primitive
// but this does not necessarily need to be true (e.g., we can overload the definition of "+" to work over Exprs)
facts.push(if head == &"value-eq".into() {
assert!(args.len() == 2);
NormFact::ConstrainEq(args[0], args[1])
} else {
let (out, args) = args.split_last().unwrap();
if egraph.functions.contains_key(head) {
NormFact::Compute(out.clone(), NormExpr::Call(head.clone(), args.to_vec()))
} else {
NormFact::Assign(out.clone(), NormExpr::Call(head.clone(), args.to_vec()))
}
});
}
facts
}

let UnresolvedCoreRule { head, body } = self;
let Actions(head) = head.clone();
NormRule {
head,
body: to_fact(body, egraph),
}
}
}

pub(crate) fn facts_to_query(body: &Vec<NormFact>, typeinfo: &TypeInfo) -> Query<Symbol> {
fn to_atom_term(s: Symbol, typeinfo: &TypeInfo) -> AtomTerm {
if typeinfo.global_types.contains_key(&s) {
AtomTerm::Global(s)
} else {
AtomTerm::Var(s)
}
}
let mut atoms = vec![];
for fact in body {
match fact {
NormFact::Assign(symbol, NormExpr::Call(head, args))
| NormFact::Compute(symbol, NormExpr::Call(head, args)) => {
let args = args
.iter()
.chain(once(symbol))
.cloned()
.map(|s| to_atom_term(s, typeinfo))
.collect();
let head = head.clone();
atoms.push(Atom { head, args });
}
NormFact::AssignVar(lhs, rhs) => atoms.push(Atom {
head: "value-eq".into(),
args: vec![
to_atom_term(lhs.clone(), typeinfo),
to_atom_term(rhs.clone(), typeinfo),
AtomTerm::Literal(Literal::Unit),
],
}),
NormFact::ConstrainEq(lhs, rhs) => atoms.push(Atom {
head: "value-eq".into(),
args: vec![
to_atom_term(lhs.clone(), typeinfo),
to_atom_term(rhs.clone(), typeinfo),
AtomTerm::Literal(Literal::Unit),
],
}),
NormFact::AssignLit(symbol, lit) => atoms.push(Atom {
head: "value-eq".into(),
args: vec![
to_atom_term(symbol.clone(), typeinfo),
AtomTerm::Literal(lit.clone()),
AtomTerm::Literal(Literal::Unit),
],
}),
}
}
Query { atoms }
}

impl NormRule {
// keep it here while it's untested
#![allow(unused)]
pub(crate) fn to_core_rule(&self, typeinfo: &TypeInfo) -> UnresolvedCoreRule {
let NormRule { head, body } = self;
UnresolvedCoreRule {
body: facts_to_query(body, typeinfo),
head: Actions(head.clone()),
}
}
}

pub struct Context<'a> {
pub egraph: &'a mut EGraph,
pub types: IndexMap<Symbol, ArcSort>,
Expand Down Expand Up @@ -140,7 +252,7 @@ impl Query<Symbol> {
Ok(constraints)
}

fn atom_terms(&self) -> HashSet<AtomTerm> {
pub(crate) fn atom_terms(&self) -> HashSet<AtomTerm> {
self.atoms
.iter()
.flat_map(|atom| atom.args.iter().cloned())
Expand Down Expand Up @@ -208,23 +320,29 @@ impl<T> Atom<T> {
}
}

pub(crate) struct ValueEq {}
pub(crate) struct ValueEq {
pub unit: Arc<UnitSort>,
}

impl PrimitiveLike for ValueEq {
fn name(&self) -> Symbol {
"value-eq".into()
}

fn get_constraints(&self, arguments: &[AtomTerm]) -> Vec<Constraint<AtomTerm, ArcSort>> {
// TODO: egglog requires value-eq to return
// the value of the first argument upon success which is weird
all_equal_constraints(self.name(), arguments, None, Some(3), None)
all_equal_constraints(
self.name(),
arguments,
None,
Some(3),
Some(self.unit.clone()),
)
}

fn apply(&self, values: &[Value]) -> Option<Value> {
assert_eq!(values.len(), 2);
if values[0] == values[1] {
Some(values[0])
Some(Value::unit())
} else {
None
}
Expand Down Expand Up @@ -287,8 +405,7 @@ impl<'a> Context<'a> {
args: vec![
var_atoms[0].0.clone(),
va.0.clone(),
// TODO: format this
AtomTerm::Var(Symbol::from(format!("$v{}", self.unionfind.make_set()))),
AtomTerm::Literal(Literal::Unit),
],
})
.collect();
Expand All @@ -314,7 +431,6 @@ impl<'a> Context<'a> {
for fact in facts {
query += self.flatten_fact(fact);
}
// let desugar = ;
Ok(UnresolvedCoreRule {
body: query,
head: Actions(flatten_actions(actions, &mut self.egraph.desugar)),
Expand Down Expand Up @@ -436,13 +552,8 @@ impl<'a> Context<'a> {
head: actions.to_vec(),
body: facts.to_vec(),
};
// dbg!(&rule);
let rule = self.lower(&rule)?;
// dbg!(&rule);
// let rule = self.discover_primitives(rule)?;
// dbg!(&rule);
let rule = self.canonicalize(rule);
// dbg!(&rule);
let assignment = self.typecheck(&rule).map_err(|e| vec![e])?;
self.types = assignment
.clone()
Expand All @@ -457,8 +568,6 @@ impl<'a> Context<'a> {
})
.collect();
let rule = self.resolve_rule(rule, assignment).map_err(|e| vec![e])?;
// dbg!(&rule);
// let rule = self.congruence(&rule);
Ok(rule)
}
}
Expand Down
105 changes: 27 additions & 78 deletions src/typechecking.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::*;
use crate::{
typecheck::{facts_to_query, ValueEq},
*,
};

pub const RULE_PROOF_KEYWORD: &str = "rule-proof";

Expand Down Expand Up @@ -51,6 +54,10 @@ impl Default for TypeInfo {
res.presorts.insert("Set".into(), SetSort::make_sort);
res.presorts.insert("Vec".into(), VecSort::make_sort);

res.add_primitive(ValueEq {
unit: res.get_sort(),
});

res
}
}
Expand Down Expand Up @@ -259,8 +266,25 @@ impl TypeInfo {
}

fn typecheck_facts(&mut self, ctx: CommandId, facts: &Vec<NormFact>) -> Result<(), TypeError> {
for fact in facts {
self.typecheck_fact(ctx, fact)?;
// ROUND TRIP TO CORE RULE AND BACK
let query = facts_to_query(facts, self);
let constraints = query.get_constraints(self)?;
let problem = Problem { constraints };
let range = query.atom_terms();
let assignment = problem
.solve(range.iter(), |sort: &ArcSort| sort.name())
.map_err(|e| e.to_type_error())?;

for (at, ty) in assignment.0.iter() {
match at {
AtomTerm::Var(v) => {
self.introduce_binding(ctx, *v, ty.clone(), false)?;
}
// All the globals should have been introduced
AtomTerm::Global(_) => {}
// No need to bind literals as well
AtomTerm::Literal(_) => {}
}
}
Ok(())
}
Expand Down Expand Up @@ -441,81 +465,6 @@ impl TypeInfo {
Ok(())
}

fn typecheck_fact(&mut self, ctx: CommandId, fact: &NormFact) -> Result<(), TypeError> {
match fact {
NormFact::Compute(var, expr) => {
let expr_type = self.typecheck_expr(ctx, expr, true)?;
if let Some(_existing) = self
.local_types
.get_mut(&ctx)
.unwrap()
.insert(*var, expr_type.output.clone())
{
return Err(TypeError::AlreadyDefined(*var));
}
}
NormFact::Assign(var, expr) => {
let expr_type = self.typecheck_expr(ctx, expr, false)?;
if let Some(_existing) = self
.local_types
.get_mut(&ctx)
.unwrap()
.insert(*var, expr_type.output.clone())
{
return Err(TypeError::AlreadyDefined(*var));
}
}
NormFact::AssignVar(lhs, rhs) => {
let rhs_type = self.lookup(ctx, *rhs)?;
if let Some(_existing) = self
.local_types
.get_mut(&ctx)
.unwrap()
.insert(*lhs, rhs_type.clone())
{
return Err(TypeError::AlreadyDefined(*lhs));
}
}
NormFact::AssignLit(var, lit) => {
let lit_type = self.infer_literal(lit);
if let Some(existing) = self
.local_types
.get_mut(&ctx)
.unwrap()
.insert(*var, lit_type.clone())
{
if lit_type.name() != existing.name() {
return Err(TypeError::TypeMismatch(lit_type, existing));
}
}
}
NormFact::ConstrainEq(var1, var2) => {
let l1 = self.lookup(ctx, *var1);
let l2 = self.lookup(ctx, *var2);
if let Ok(v1type) = l1 {
if let Ok(v2type) = l2 {
if v1type.name() != v2type.name() {
return Err(TypeError::TypeMismatch(v1type, v2type));
}
} else {
self.local_types
.get_mut(&ctx)
.unwrap()
.insert(*var2, v1type);
}
} else if let Ok(v2type) = l2 {
self.local_types
.get_mut(&ctx)
.unwrap()
.insert(*var1, v2type);
} else {
return Err(TypeError::Unbound(*var1));
}
}
}
Ok(())
}

pub fn reserved_type(&self, sym: Symbol) -> Option<ArcSort> {
if sym == RULE_PROOF_KEYWORD.into() {
Some(self.sorts.get::<Symbol>(&"Proof__".into()).unwrap().clone())
Expand Down

0 comments on commit 08aef0c

Please sign in to comment.