Skip to content

Commit

Permalink
simplify test case
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Oct 8, 2024
1 parent 2e8f4e1 commit 167af0b
Showing 1 changed file with 84 additions and 64 deletions.
148 changes: 84 additions & 64 deletions tests/cycle_fixpoint.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
/// Minimal example use case for fixpoint iteration cycle resolution.
/// Minimal(ish) example use case for fixpoint iteration cycle resolution.
use salsa::{Database as Db, Setter};
use std::collections::BTreeSet;
use std::iter::IntoIterator;

/// A Use of a symbol.
#[salsa::input]
Expand All @@ -9,13 +11,7 @@ struct Use {

#[salsa::input]
struct Literal {
value: LiteralValue,
}

#[derive(Clone, Debug)]
enum LiteralValue {
Int(usize),
Str(String),
value: usize,
}

/// A Definition of a symbol, either of the form `base + increment` or `0 + increment`.
Expand All @@ -27,21 +23,39 @@ struct Definition {

#[derive(Eq, PartialEq, Clone, Debug)]
enum Type {
Unbound,
LiteralInt(usize),
LiteralStr(String),
Int,
Str,
Union(Vec<Type>),
Bottom,
Values(Box<[usize]>),
Top,
}

impl Type {
fn join(tys: impl IntoIterator<Item = Type>) -> Type {
let mut result = Type::Bottom;
for ty in tys.into_iter() {
result = match (result, ty) {
(result, Type::Bottom) => result,
(_, Type::Top) => Type::Top,
(Type::Top, _) => Type::Top,
(Type::Bottom, ty) => ty,
(Type::Values(a_ints), Type::Values(b_ints)) => {
let mut set = BTreeSet::new();
set.extend(a_ints);
set.extend(b_ints);
Type::Values(set.into_iter().collect())
}
}
}
result
}
}

#[salsa::tracked]
fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type {
let defs = u.reaching_definitions(db);
match defs[..] {
[] => Type::Unbound,
[] => Type::Bottom,
[def] => infer_definition(db, def),
_ => Type::Union(defs.iter().map(|&def| infer_definition(db, def)).collect()),
_ => Type::join(defs.iter().map(|&def| infer_definition(db, def))),
}
}

Expand All @@ -50,111 +64,117 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type {
let increment_ty = infer_literal(db, def.increment(db));
if let Some(base) = def.base(db) {
let base_ty = infer_use(db, base);
match (base_ty, increment_ty) {
(Type::Unbound, _) => panic!("unbound use"),
(Type::LiteralInt(b), Type::LiteralInt(i)) => Type::LiteralInt(b + i),
(Type::LiteralStr(b), Type::LiteralStr(i)) => Type::LiteralStr(format!("{}{}", b, i)),
(Type::Int, Type::LiteralInt(_)) => Type::Int,
(Type::LiteralInt(_), Type::Int) => Type::Int,
(Type::Str, Type::LiteralStr(_)) => Type::Str,
(Type::LiteralStr(_), Type::Str) => Type::Str,
_ => panic!("type error"),
}
add(&base_ty, &increment_ty)
} else {
increment_ty
}
}

fn add(a: &Type, b: &Type) -> Type {
match (a, b) {
(Type::Bottom, _) | (_, Type::Bottom) => panic!("unbound use"),
(Type::Top, _) | (_, Type::Top) => Type::Top,
(Type::Values(a_ints), Type::Values(b_ints)) => {
let mut set = BTreeSet::new();
set.extend(
a_ints
.into_iter()
.flat_map(|a| b_ints.into_iter().map(move |b| a + b)),
);
Type::Values(set.into_iter().collect())
}
}
}

#[salsa::tracked]
fn infer_literal<'db>(db: &'db dyn Db, literal: Literal) -> Type {
match literal.value(db) {
LiteralValue::Int(i) => Type::LiteralInt(i),
LiteralValue::Str(s) => Type::LiteralStr(s),
}
Type::Values(Box::from([literal.value(db)]))
}

/// x = 1
#[test]
fn simple() {
let db = salsa::DatabaseImpl::new();

let def = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1)));
let def = Definition::new(&db, None, Literal::new(&db, 1));
let u = Use::new(&db, vec![def]);

let ty = infer_use(&db, u);

assert_eq!(ty, Type::LiteralInt(1));
assert_eq!(ty, Type::Values(Box::from([1])));
}

/// x = "a" if flag else "b"
/// x = 1 if flag else 2
#[test]
fn union() {
let db = salsa::DatabaseImpl::new();

let def1 = Definition::new(
&db,
None,
Literal::new(&db, LiteralValue::Str("a".to_string())),
);
let def2 = Definition::new(
&db,
None,
Literal::new(&db, LiteralValue::Str("b".to_string())),
);
let def1 = Definition::new(&db, None, Literal::new(&db, 1));
let def2 = Definition::new(&db, None, Literal::new(&db, 2));
let u = Use::new(&db, vec![def1, def2]);

let ty = infer_use(&db, u);

assert_eq!(
ty,
Type::Union(vec![
Type::LiteralStr("a".to_string()),
Type::LiteralStr("b".to_string())
])
);
assert_eq!(ty, Type::Values(Box::from([1, 2])));
}

/// x = 1 if flag else 2; y = x + 1
#[test]
fn union_add() {
let db = salsa::DatabaseImpl::new();

let x1 = Definition::new(&db, None, Literal::new(&db, 1));
let x2 = Definition::new(&db, None, Literal::new(&db, 2));
let x_use = Use::new(&db, vec![x1, x2]);
let y_def = Definition::new(&db, Some(x_use), Literal::new(&db, 1));
let y_use = Use::new(&db, vec![y_def]);

let ty = infer_use(&db, y_use);

assert_eq!(ty, Type::Values(Box::from([2, 3])));
}

/// x = 1; loop { x = x + 0 }
#[test]
fn cycle_converges() {
let mut db = salsa::DatabaseImpl::new();

let def1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1)));
let def2 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0)));
let def1 = Definition::new(&db, None, Literal::new(&db, 1));
let def2 = Definition::new(&db, None, Literal::new(&db, 0));
let u = Use::new(&db, vec![def1, def2]);
def2.set_base(&mut db).to(Some(u));

let ty = infer_use(&db, u);

// Loop converges on LiteralInt(1)
assert_eq!(ty, Type::LiteralInt(1));
// Loop converges on 1
assert_eq!(ty, Type::Values(Box::from([1])));
}

/// x = 1; loop { x = x + 1 }
#[test]
fn cycle_diverges() {
let mut db = salsa::DatabaseImpl::new();

let def1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1)));
let def2 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1)));
let def1 = Definition::new(&db, None, Literal::new(&db, 1));
let def2 = Definition::new(&db, None, Literal::new(&db, 1));
let u = Use::new(&db, vec![def1, def2]);
def2.set_base(&mut db).to(Some(u));

let ty = infer_use(&db, u);

// Loop diverges. Cut it off and fallback from "all LiteralInt observed" to Type::Int
assert_eq!(ty, Type::Int);
// Loop diverges. Cut it off and fallback to Type::Top
assert_eq!(ty, Type::Top);
}

/// x = 0; y = 0; loop { x = y + 0; y = x + 0 }
#[test]
fn multi_symbol_cycle_converges() {
let mut db = salsa::DatabaseImpl::new();

let defx0 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0)));
let defy0 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0)));
let defx1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0)));
let defy1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0)));
let defx0 = Definition::new(&db, None, Literal::new(&db, 0));
let defy0 = Definition::new(&db, None, Literal::new(&db, 0));
let defx1 = Definition::new(&db, None, Literal::new(&db, 0));
let defy1 = Definition::new(&db, None, Literal::new(&db, 0));
let use_x = Use::new(&db, vec![defx0, defx1]);
let use_y = Use::new(&db, vec![defy0, defy1]);
defx1.set_base(&mut db).to(Some(use_y));
Expand All @@ -164,6 +184,6 @@ fn multi_symbol_cycle_converges() {
let y_ty = infer_use(&db, use_y);

// Both symbols converge on LiteralInt(0)
assert_eq!(x_ty, Type::LiteralInt(0));
assert_eq!(y_ty, Type::LiteralInt(0));
assert_eq!(x_ty, Type::Values(Box::from([0])));
assert_eq!(y_ty, Type::Values(Box::from([0])));
}

0 comments on commit 167af0b

Please sign in to comment.