diff --git a/conjure_oxide/src/lib.rs b/conjure_oxide/src/lib.rs index e7f1a854a..69c2e0bb0 100644 --- a/conjure_oxide/src/lib.rs +++ b/conjure_oxide/src/lib.rs @@ -1,8 +1,8 @@ pub mod error; pub mod find_conjure; pub mod parse; -pub mod rules; -mod solvers; +mod rules; +pub mod solvers; pub use conjure_core::ast; // re-export core::ast as conjure_oxide::ast pub use conjure_core::ast::Model; // rexport core::ast::Model as conjure_oxide::Model diff --git a/conjure_oxide/src/main.rs b/conjure_oxide/src/main.rs index 0699ff379..e382740ce 100644 --- a/conjure_oxide/src/main.rs +++ b/conjure_oxide/src/main.rs @@ -19,7 +19,13 @@ struct Cli { } pub fn main() -> AnyhowResult<()> { - println!("Rules: {:?}", conjure_rules::get_rules()); + println!( + "Rules: {:?}", + conjure_rules::get_rules() + .iter() + .map(|r| r.name) + .collect::>() + ); let cli = Cli::parse(); println!("Input file: {}", cli.input_file.display()); diff --git a/conjure_oxide/src/rules/mod.rs b/conjure_oxide/src/rules/mod.rs index 4d4702c64..09f07fe44 100644 --- a/conjure_oxide/src/rules/mod.rs +++ b/conjure_oxide/src/rules/mod.rs @@ -1,7 +1,81 @@ -use conjure_core::{ast::Expression, rule::RuleApplicationError}; +use conjure_core::{ast::Expression as Expr, rule::RuleApplicationError}; use conjure_rules::register_rule; +// #[register_rule] +// fn identity(expr: &Expr) -> Result { +// Ok(expr.clone()) +// } + +#[register_rule] +fn sum_constants(expr: &Expr) -> Result { + match expr { + Expr::Sum(exprs) => { + let mut sum = 0; + let mut new_exprs = Vec::new(); + let mut changed = false; + for e in exprs { + match e { + Expr::ConstantInt(i) => { + sum += i; + changed = true; + } + _ => new_exprs.push(e.clone()), + } + } + if !changed { + return Err(RuleApplicationError::RuleNotApplicable); + } + new_exprs.push(Expr::ConstantInt(sum)); + Ok(Expr::Sum(new_exprs)) // Let other rules handle only one Expr being contained in the sum + } + _ => Err(RuleApplicationError::RuleNotApplicable), + } +} + +#[register_rule] +fn unwrap_sum(expr: &Expr) -> Result { + match expr { + Expr::Sum(exprs) if (exprs.len() == 1) => Ok(exprs[0].clone()), + _ => Err(RuleApplicationError::RuleNotApplicable), + } +} + +#[register_rule] +fn flatten_sum_geq(expr: &Expr) -> Result { + match expr { + Expr::Geq(a, b) => { + let exprs = match a.as_ref() { + Expr::Sum(exprs) => Ok(exprs), + _ => Err(RuleApplicationError::RuleNotApplicable), + }?; + Ok(Expr::SumGeq(exprs.clone(), b.clone())) + } + _ => Err(RuleApplicationError::RuleNotApplicable), + } +} + +#[register_rule] +fn sum_leq_to_sumleq(expr: &Expr) -> Result { + match expr { + Expr::Leq(a, b) => { + let exprs = match a.as_ref() { + Expr::Sum(exprs) => Ok(exprs), + _ => Err(RuleApplicationError::RuleNotApplicable), + }?; + Ok(Expr::SumLeq(exprs.clone(), b.clone())) + } + _ => Err(RuleApplicationError::RuleNotApplicable), + } +} + #[register_rule] -fn identity(expr: &Expression) -> Result { - Ok(expr.clone()) +fn lt_to_ineq(expr: &Expr) -> Result { + match expr { + Expr::Lt(a, b) => Ok(Expr::Ineq( + a.clone(), + b.clone(), + Box::new(Expr::ConstantInt(-1)), + )), + _ => Err(RuleApplicationError::RuleNotApplicable), + } } diff --git a/conjure_oxide/src/solvers/minion.rs b/conjure_oxide/src/solvers/minion.rs index f3eea6095..76ab90414 100644 --- a/conjure_oxide/src/solvers/minion.rs +++ b/conjure_oxide/src/solvers/minion.rs @@ -7,9 +7,10 @@ use crate::ast::{ DecisionVariable, Domain as ConjureDomain, Expression as ConjureExpression, Model as ConjureModel, Name as ConjureName, Range as ConjureRange, }; +pub use minion_rs::ast::Model as MinionModel; use minion_rs::ast::{ - Constant as MinionConstant, Constraint as MinionConstraint, Model as MinionModel, - Var as MinionVar, VarDomain as MinionDomain, + Constant as MinionConstant, Constraint as MinionConstraint, Var as MinionVar, + VarDomain as MinionDomain, }; const SOLVER: Solver = Solver::Minion; diff --git a/conjure_oxide/tests/rewrite_tests.rs b/conjure_oxide/tests/rewrite_tests.rs index 0061c6efe..865bb5d56 100644 --- a/conjure_oxide/tests/rewrite_tests.rs +++ b/conjure_oxide/tests/rewrite_tests.rs @@ -1,12 +1,18 @@ // Tests for rewriting/simplifying parts of the AST use core::panic; +use std::collections::HashMap; -use conjure_oxide::ast::*; +use conjure_oxide::{ + ast::*, + solvers::{minion, FromConjureModel}, +}; +use conjure_rules::{get_rule_by_name, get_rules}; +use minion_rs::ast::{Constant, VarName}; #[test] fn rules_present() { - let rules = conjure_rules::get_rules(); + let rules = get_rules(); assert!(rules.len() > 0); } @@ -97,3 +103,160 @@ fn simplify_expression(expr: Expression) -> Expression { _ => expr, } } + +#[test] +fn rule_sum_constants() { + let sum_constants = get_rule_by_name("sum_constants").unwrap(); + let unwrap_sum = get_rule_by_name("unwrap_sum").unwrap(); + + let mut expr = Expression::Sum(vec![ + Expression::ConstantInt(1), + Expression::ConstantInt(2), + Expression::ConstantInt(3), + ]); + + expr = sum_constants.apply(&expr).unwrap(); + expr = unwrap_sum.apply(&expr).unwrap(); + + assert_eq!(expr, Expression::ConstantInt(6)); +} + +#[test] +fn rule_sum_mixed() { + let sum_constants = get_rule_by_name("sum_constants").unwrap(); + + let mut expr = Expression::Sum(vec![ + Expression::ConstantInt(1), + Expression::ConstantInt(2), + Expression::Reference(Name::UserName(String::from("a"))), + ]); + + expr = sum_constants.apply(&expr).unwrap(); + + assert_eq!( + expr, + Expression::Sum(vec![ + Expression::Reference(Name::UserName(String::from("a"))), + Expression::ConstantInt(3), + ]) + ); +} + +#[test] +fn rule_sum_geq() { + let flatten_sum_geq = get_rule_by_name("flatten_sum_geq").unwrap(); + + let mut expr = Expression::Geq( + Box::new(Expression::Sum(vec![ + Expression::ConstantInt(1), + Expression::ConstantInt(2), + ])), + Box::new(Expression::ConstantInt(3)), + ); + + expr = flatten_sum_geq.apply(&expr).unwrap(); + + assert_eq!( + expr, + Expression::SumGeq( + vec![Expression::ConstantInt(1), Expression::ConstantInt(2),], + Box::new(Expression::ConstantInt(3)) + ) + ); +} + +fn callback(solution: HashMap) -> bool { + println!("Solution: {:?}", solution); + false +} + +/// +/// Reduce and solve: +/// ```text +/// find a,b,c : int(1..3) +/// such that a + b + c <= 2 + 3 - 1 +/// such that a < b +/// ``` +#[test] +fn reduce_solve_xyz() { + println!("Rules: {:?}", conjure_rules::get_rules()); + let sum_constants = get_rule_by_name("sum_constants").unwrap(); + let unwrap_sum = get_rule_by_name("unwrap_sum").unwrap(); + let lt_to_ineq = get_rule_by_name("lt_to_ineq").unwrap(); + let sum_leq_to_sumleq = get_rule_by_name("sum_leq_to_sumleq").unwrap(); + + // 2 + 3 - 1 + let mut expr1 = Expression::Sum(vec![ + Expression::ConstantInt(2), + Expression::ConstantInt(3), + Expression::ConstantInt(-1), + ]); + + expr1 = sum_constants.apply(&expr1).unwrap(); + expr1 = unwrap_sum.apply(&expr1).unwrap(); + assert_eq!(expr1, Expression::ConstantInt(4)); + + // a + b + c = 4 + expr1 = Expression::Leq( + Box::new(Expression::Sum(vec![ + Expression::Reference(Name::UserName(String::from("a"))), + Expression::Reference(Name::UserName(String::from("b"))), + Expression::Reference(Name::UserName(String::from("c"))), + ])), + Box::new(expr1), + ); + expr1 = sum_leq_to_sumleq.apply(&expr1).unwrap(); + assert_eq!( + expr1, + Expression::SumLeq( + vec![ + Expression::Reference(Name::UserName(String::from("a"))), + Expression::Reference(Name::UserName(String::from("b"))), + Expression::Reference(Name::UserName(String::from("c"))), + ], + Box::new(Expression::ConstantInt(4)) + ) + ); + + // a < b + let mut expr2 = Expression::Lt( + Box::new(Expression::Reference(Name::UserName(String::from("a")))), + Box::new(Expression::Reference(Name::UserName(String::from("b")))), + ); + expr2 = lt_to_ineq.apply(&expr2).unwrap(); + assert_eq!( + expr2, + Expression::Ineq( + Box::new(Expression::Reference(Name::UserName(String::from("a")))), + Box::new(Expression::Reference(Name::UserName(String::from("b")))), + Box::new(Expression::ConstantInt(-1)) + ) + ); + + let mut model = Model { + variables: HashMap::new(), + constraints: vec![expr1, expr2], + }; + model.variables.insert( + Name::UserName(String::from("a")), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + model.variables.insert( + Name::UserName(String::from("b")), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + model.variables.insert( + Name::UserName(String::from("c")), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + + let minion_model = conjure_oxide::solvers::minion::MinionModel::from_conjure(model).unwrap(); + + minion_rs::run_minion(minion_model, callback).unwrap(); +} diff --git a/crates/conjure_rules/src/lib.rs b/crates/conjure_rules/src/lib.rs index 7ce4b1e87..a3064653d 100644 --- a/crates/conjure_rules/src/lib.rs +++ b/crates/conjure_rules/src/lib.rs @@ -64,6 +64,10 @@ pub fn get_rules() -> Vec> { RULES_DISTRIBUTED_SLICE.to_vec() } +pub fn get_rule_by_name(name: &str) -> Option> { + get_rules().iter().find(|rule| rule.name == name).cloned() +} + /// This procedural macro registers a decorated function with `conjure_rules`' global registry. /// It may be used in any downstream crate. For more information on linker magic, see the [`linkme`](https://docs.rs/linkme/latest/linkme/) crate. ///