Skip to content

Commit

Permalink
Merge pull request #146 from lixitrixi/manual-rewrite-poc
Browse files Browse the repository at this point in the history
Manual Rewrite Proof of Concept
  • Loading branch information
ozgurakgun authored Jan 26, 2024
2 parents 90e1835 + f0eb6d9 commit 4828e4d
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 10 deletions.
4 changes: 2 additions & 2 deletions conjure_oxide/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pub mod error;
pub mod find_conjure;
pub mod parse;
pub mod rules;
mod solvers;
mod rules;
pub mod solvers;

pub use conjure_core::ast; // re-export core::ast as conjure_oxide::ast
pub use conjure_core::ast::Model; // rexport core::ast::Model as conjure_oxide::Model
Expand Down
8 changes: 7 additions & 1 deletion conjure_oxide/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ struct Cli {
}

pub fn main() -> AnyhowResult<()> {
println!("Rules: {:?}", conjure_rules::get_rules());
println!(
"Rules: {:?}",
conjure_rules::get_rules()
.iter()
.map(|r| r.name)
.collect::<Vec<_>>()
);

let cli = Cli::parse();
println!("Input file: {}", cli.input_file.display());
Expand Down
80 changes: 77 additions & 3 deletions conjure_oxide/src/rules/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,81 @@
use conjure_core::{ast::Expression, rule::RuleApplicationError};
use conjure_core::{ast::Expression as Expr, rule::RuleApplicationError};
use conjure_rules::register_rule;

// #[register_rule]
// fn identity(expr: &Expr) -> Result<Expr, RuleApplicationError> {
// Ok(expr.clone())
// }

#[register_rule]
fn sum_constants(expr: &Expr) -> Result<Expr, RuleApplicationError> {
match expr {
Expr::Sum(exprs) => {
let mut sum = 0;
let mut new_exprs = Vec::new();
let mut changed = false;
for e in exprs {
match e {
Expr::ConstantInt(i) => {
sum += i;
changed = true;
}
_ => new_exprs.push(e.clone()),
}
}
if !changed {
return Err(RuleApplicationError::RuleNotApplicable);
}
new_exprs.push(Expr::ConstantInt(sum));
Ok(Expr::Sum(new_exprs)) // Let other rules handle only one Expr being contained in the sum
}
_ => Err(RuleApplicationError::RuleNotApplicable),
}
}

#[register_rule]
fn unwrap_sum(expr: &Expr) -> Result<Expr, RuleApplicationError> {
match expr {
Expr::Sum(exprs) if (exprs.len() == 1) => Ok(exprs[0].clone()),
_ => Err(RuleApplicationError::RuleNotApplicable),
}
}

#[register_rule]
fn flatten_sum_geq(expr: &Expr) -> Result<Expr, RuleApplicationError> {
match expr {
Expr::Geq(a, b) => {
let exprs = match a.as_ref() {
Expr::Sum(exprs) => Ok(exprs),
_ => Err(RuleApplicationError::RuleNotApplicable),
}?;
Ok(Expr::SumGeq(exprs.clone(), b.clone()))
}
_ => Err(RuleApplicationError::RuleNotApplicable),
}
}

#[register_rule]
fn sum_leq_to_sumleq(expr: &Expr) -> Result<Expr, RuleApplicationError> {
match expr {
Expr::Leq(a, b) => {
let exprs = match a.as_ref() {
Expr::Sum(exprs) => Ok(exprs),
_ => Err(RuleApplicationError::RuleNotApplicable),
}?;
Ok(Expr::SumLeq(exprs.clone(), b.clone()))
}
_ => Err(RuleApplicationError::RuleNotApplicable),
}
}

#[register_rule]
fn identity(expr: &Expression) -> Result<Expression, RuleApplicationError> {
Ok(expr.clone())
fn lt_to_ineq(expr: &Expr) -> Result<Expr, RuleApplicationError> {
match expr {
Expr::Lt(a, b) => Ok(Expr::Ineq(
a.clone(),
b.clone(),
Box::new(Expr::ConstantInt(-1)),
)),
_ => Err(RuleApplicationError::RuleNotApplicable),
}
}
5 changes: 3 additions & 2 deletions conjure_oxide/src/solvers/minion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ use crate::ast::{
DecisionVariable, Domain as ConjureDomain, Expression as ConjureExpression,
Model as ConjureModel, Name as ConjureName, Range as ConjureRange,
};
pub use minion_rs::ast::Model as MinionModel;
use minion_rs::ast::{
Constant as MinionConstant, Constraint as MinionConstraint, Model as MinionModel,
Var as MinionVar, VarDomain as MinionDomain,
Constant as MinionConstant, Constraint as MinionConstraint, Var as MinionVar,
VarDomain as MinionDomain,
};

const SOLVER: Solver = Solver::Minion;
Expand Down
167 changes: 165 additions & 2 deletions conjure_oxide/tests/rewrite_tests.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
// Tests for rewriting/simplifying parts of the AST

use core::panic;
use std::collections::HashMap;

use conjure_oxide::ast::*;
use conjure_oxide::{
ast::*,
solvers::{minion, FromConjureModel},
};
use conjure_rules::{get_rule_by_name, get_rules};
use minion_rs::ast::{Constant, VarName};

#[test]
fn rules_present() {
let rules = conjure_rules::get_rules();
let rules = get_rules();
assert!(rules.len() > 0);
}

Expand Down Expand Up @@ -97,3 +103,160 @@ fn simplify_expression(expr: Expression) -> Expression {
_ => expr,
}
}

#[test]
fn rule_sum_constants() {
let sum_constants = get_rule_by_name("sum_constants").unwrap();
let unwrap_sum = get_rule_by_name("unwrap_sum").unwrap();

let mut expr = Expression::Sum(vec![
Expression::ConstantInt(1),
Expression::ConstantInt(2),
Expression::ConstantInt(3),
]);

expr = sum_constants.apply(&expr).unwrap();
expr = unwrap_sum.apply(&expr).unwrap();

assert_eq!(expr, Expression::ConstantInt(6));
}

#[test]
fn rule_sum_mixed() {
let sum_constants = get_rule_by_name("sum_constants").unwrap();

let mut expr = Expression::Sum(vec![
Expression::ConstantInt(1),
Expression::ConstantInt(2),
Expression::Reference(Name::UserName(String::from("a"))),
]);

expr = sum_constants.apply(&expr).unwrap();

assert_eq!(
expr,
Expression::Sum(vec![
Expression::Reference(Name::UserName(String::from("a"))),
Expression::ConstantInt(3),
])
);
}

#[test]
fn rule_sum_geq() {
let flatten_sum_geq = get_rule_by_name("flatten_sum_geq").unwrap();

let mut expr = Expression::Geq(
Box::new(Expression::Sum(vec![
Expression::ConstantInt(1),
Expression::ConstantInt(2),
])),
Box::new(Expression::ConstantInt(3)),
);

expr = flatten_sum_geq.apply(&expr).unwrap();

assert_eq!(
expr,
Expression::SumGeq(
vec![Expression::ConstantInt(1), Expression::ConstantInt(2),],
Box::new(Expression::ConstantInt(3))
)
);
}

fn callback(solution: HashMap<VarName, Constant>) -> bool {
println!("Solution: {:?}", solution);
false
}

///
/// Reduce and solve:
/// ```text
/// find a,b,c : int(1..3)
/// such that a + b + c <= 2 + 3 - 1
/// such that a < b
/// ```
#[test]
fn reduce_solve_xyz() {
println!("Rules: {:?}", conjure_rules::get_rules());
let sum_constants = get_rule_by_name("sum_constants").unwrap();
let unwrap_sum = get_rule_by_name("unwrap_sum").unwrap();
let lt_to_ineq = get_rule_by_name("lt_to_ineq").unwrap();
let sum_leq_to_sumleq = get_rule_by_name("sum_leq_to_sumleq").unwrap();

// 2 + 3 - 1
let mut expr1 = Expression::Sum(vec![
Expression::ConstantInt(2),
Expression::ConstantInt(3),
Expression::ConstantInt(-1),
]);

expr1 = sum_constants.apply(&expr1).unwrap();
expr1 = unwrap_sum.apply(&expr1).unwrap();
assert_eq!(expr1, Expression::ConstantInt(4));

// a + b + c = 4
expr1 = Expression::Leq(
Box::new(Expression::Sum(vec![
Expression::Reference(Name::UserName(String::from("a"))),
Expression::Reference(Name::UserName(String::from("b"))),
Expression::Reference(Name::UserName(String::from("c"))),
])),
Box::new(expr1),
);
expr1 = sum_leq_to_sumleq.apply(&expr1).unwrap();
assert_eq!(
expr1,
Expression::SumLeq(
vec![
Expression::Reference(Name::UserName(String::from("a"))),
Expression::Reference(Name::UserName(String::from("b"))),
Expression::Reference(Name::UserName(String::from("c"))),
],
Box::new(Expression::ConstantInt(4))
)
);

// a < b
let mut expr2 = Expression::Lt(
Box::new(Expression::Reference(Name::UserName(String::from("a")))),
Box::new(Expression::Reference(Name::UserName(String::from("b")))),
);
expr2 = lt_to_ineq.apply(&expr2).unwrap();
assert_eq!(
expr2,
Expression::Ineq(
Box::new(Expression::Reference(Name::UserName(String::from("a")))),
Box::new(Expression::Reference(Name::UserName(String::from("b")))),
Box::new(Expression::ConstantInt(-1))
)
);

let mut model = Model {
variables: HashMap::new(),
constraints: vec![expr1, expr2],
};
model.variables.insert(
Name::UserName(String::from("a")),
DecisionVariable {
domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]),
},
);
model.variables.insert(
Name::UserName(String::from("b")),
DecisionVariable {
domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]),
},
);
model.variables.insert(
Name::UserName(String::from("c")),
DecisionVariable {
domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]),
},
);

let minion_model = conjure_oxide::solvers::minion::MinionModel::from_conjure(model).unwrap();

minion_rs::run_minion(minion_model, callback).unwrap();
}
4 changes: 4 additions & 0 deletions crates/conjure_rules/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ pub fn get_rules() -> Vec<Rule<'static>> {
RULES_DISTRIBUTED_SLICE.to_vec()
}

pub fn get_rule_by_name(name: &str) -> Option<Rule<'static>> {
get_rules().iter().find(|rule| rule.name == name).cloned()
}

/// This procedural macro registers a decorated function with `conjure_rules`' global registry.
/// It may be used in any downstream crate. For more information on linker magic, see the [`linkme`](https://docs.rs/linkme/latest/linkme/) crate.
///
Expand Down

0 comments on commit 4828e4d

Please sign in to comment.