diff --git a/src/common/ast/ast.rs b/src/common/ast/ast.rs index 90f583f13..f1d298e1e 100644 --- a/src/common/ast/ast.rs +++ b/src/common/ast/ast.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct Model { pub variables: HashMap, pub constraints: Vec, @@ -9,6 +9,7 @@ pub struct Model { impl Model { // Function to update a DecisionVariable based on its Name pub fn update_domain(&mut self, name: &Name, new_domain: Domain) { + assert!(self.variables.contains_key(name)); if let Some(decision_var) = self.variables.get_mut(name) { decision_var.domain = new_domain; } diff --git a/src/common/ast/mod.rs b/src/common/ast/mod.rs index 98f77c11f..962768cb7 100644 --- a/src/common/ast/mod.rs +++ b/src/common/ast/mod.rs @@ -1,2 +1,5 @@ mod ast; pub use ast::*; + +mod model_builder; +pub use model_builder::*; diff --git a/src/common/ast/model_builder.rs b/src/common/ast/model_builder.rs new file mode 100644 index 000000000..ac224f8d3 --- /dev/null +++ b/src/common/ast/model_builder.rs @@ -0,0 +1,39 @@ +use std::collections::HashMap; +use super::ast::*; + +pub struct ModelBuilder { + pub variables: HashMap, + pub constraints: Vec, +} + +impl ModelBuilder { + pub fn new() -> ModelBuilder { + ModelBuilder { + variables: HashMap::new(), + constraints: Vec::new(), + } + } + + pub fn add_constraint(mut self, constraint: Expression) -> Self { + self.constraints.push(constraint); + self + } + + pub fn add_var(mut self, name: Name, domain: Domain) -> Self { + assert!(self.variables.get(&name).is_none()); + self.variables.insert( + name, + DecisionVariable { + domain, + }, + ); + self + } + + pub fn build(self) -> Model { + Model { + variables: self.variables, + constraints: self.constraints, + } + } +} diff --git a/tests/builder_tests.rs b/tests/builder_tests.rs new file mode 100644 index 000000000..659ad96cb --- /dev/null +++ b/tests/builder_tests.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; + +use conjure_oxide::ast::*; + +#[test] +fn abc_equality() { + let a = Name::UserName(String::from("a")); + let b = Name::UserName(String::from("b")); + let c = Name::UserName(String::from("c")); + + let mut variables = HashMap::new(); + variables.insert( + a.clone(), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + variables.insert( + b.clone(), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + variables.insert( + c.clone(), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + + let m1 = Model { + variables, + constraints: vec![ + Expression::Eq( + Box::new(Expression::Sum(vec![ + Expression::Reference(a.clone()), + Expression::Reference(b.clone()), + Expression::Reference(c.clone()), + ])), + Box::new(Expression::ConstantInt(4)), + ), + Expression::Geq( + Box::new(Expression::Reference(a.clone())), + Box::new(Expression::Reference(b.clone())), + ), + ], + }; + + let m2 = ModelBuilder::new() + .add_var(a.clone(), Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_var(b.clone(), Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_var(c.clone(), Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_constraint(Expression::Eq( + Box::new(Expression::Sum(vec![ + Expression::Reference(a.clone()), + Expression::Reference(b.clone()), + Expression::Reference(c.clone()), + ])), + Box::new(Expression::ConstantInt(4)), + )) + .add_constraint(Expression::Geq( + Box::new(Expression::Reference(a.clone())), + Box::new(Expression::Reference(b.clone())), + )) + .build(); + + assert!(m1 == m2); +}