From 0e827c9a458e54e503d85586b8ea8f1ce1f22d22 Mon Sep 17 00:00:00 2001 From: Felix Date: Mon, 23 Oct 2023 16:51:00 +0100 Subject: [PATCH 1/5] Initial builder definitions --- src/main.rs | 48 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index f81f8f17c..d244cc4aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, rc::Rc, collections::HashMap}; fn main() { let a = Name::UserName(String::from("a")); @@ -51,6 +51,52 @@ fn main() { println!("{:#?}", m); } +struct ModelBuilder { + statements: Vec, + variables: HashMap>>, +} + +impl ModelBuilder { + fn new() -> Self { + ModelBuilder { + statements: Vec::new(), + variables: HashMap::new(), + } + } + + fn add_statement(mut self, statement: Statement) -> Self { + self.statements.push(statement); + self + } + + fn find(self, name: String, domain: Domain) -> Self { + let var = Rc::new(RefCell::new(DecisionVariable { + name: Name::UserName(name), + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + })); + let statement: Statement = Statement::Declaration(var); + self.add_statement(statement) + } + + fn such_that(self, expression: Expression) -> Self { + self.add_statement(Statement::Constraint(expression)) + } + + fn build(self) -> Model { + Model { + statements: self.statements, + } + } + + // Return an expression that references the given variable, if previously defined via a `find` statement. + fn var_expr(self, name: String) -> Option { + match self.variables.get(&name) { + Some(var) => Some(Expression::Reference(Rc::clone(var))), + None => None, + } + } +} + #[derive(Debug)] enum Name { UserName(String), From 154614148172355217206f6a9db025c107b565bb Mon Sep 17 00:00:00 2001 From: Felix Date: Fri, 27 Oct 2023 01:12:17 +0100 Subject: [PATCH 2/5] Create ModelBuilder --- src/common/ast/mod.rs | 3 +++ src/common/ast/model_builder.rs | 48 +++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 src/common/ast/model_builder.rs diff --git a/src/common/ast/mod.rs b/src/common/ast/mod.rs index 98f77c11f..962768cb7 100644 --- a/src/common/ast/mod.rs +++ b/src/common/ast/mod.rs @@ -1,2 +1,5 @@ mod ast; pub use ast::*; + +mod model_builder; +pub use model_builder::*; diff --git a/src/common/ast/model_builder.rs b/src/common/ast/model_builder.rs new file mode 100644 index 000000000..5e46335bf --- /dev/null +++ b/src/common/ast/model_builder.rs @@ -0,0 +1,48 @@ +use std::collections::HashMap; +use super::ast::*; + +pub struct ModelBuilder { + pub variables: HashMap, + pub constraints: Vec, +} + +impl ModelBuilder { + pub fn new() -> ModelBuilder { + ModelBuilder { + variables: HashMap::new(), + constraints: Vec::new(), + } + } + + pub fn add_constraint(mut self, constraint: Expression) -> Self { + self.constraints.push(constraint); + self + } + + pub fn add_var(mut self, name: Name, domain: Domain) -> Self { + self.variables.insert( + name, + DecisionVariable { + domain, + }, + ); + self + } + + pub fn add_var_str(mut self, name: &str, domain: Domain) -> Self { + self.variables.insert( + Name::UserName(String::from(name)), + DecisionVariable { + domain, + }, + ); + self + } + + pub fn build(self) -> Model { + Model { + variables: self.variables, + constraints: self.constraints, + } + } +} From 022085282fc5b156bb715c4316ba66cc75881ace Mon Sep 17 00:00:00 2001 From: Felix Date: Fri, 27 Oct 2023 01:12:48 +0100 Subject: [PATCH 3/5] Add builder tests --- src/common/ast/ast.rs | 2 +- tests/builder_tests.rs | 68 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/builder_tests.rs diff --git a/src/common/ast/ast.rs b/src/common/ast/ast.rs index 90f583f13..64c47a107 100644 --- a/src/common/ast/ast.rs +++ b/src/common/ast/ast.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct Model { pub variables: HashMap, pub constraints: Vec, diff --git a/tests/builder_tests.rs b/tests/builder_tests.rs new file mode 100644 index 000000000..7716c0947 --- /dev/null +++ b/tests/builder_tests.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; + +use conjure_oxide::ast::*; + +#[test] +fn abc_equality() { + let a = Name::UserName(String::from("a")); + let b = Name::UserName(String::from("b")); + let c = Name::UserName(String::from("c")); + + let mut variables = HashMap::new(); + variables.insert( + a.clone(), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + variables.insert( + b.clone(), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + variables.insert( + c.clone(), + DecisionVariable { + domain: Domain::IntDomain(vec![Range::Bounded(1, 3)]), + }, + ); + + let m1 = Model { + variables, + constraints: vec![ + Expression::Eq( + Box::new(Expression::Sum(vec![ + Expression::Reference(a.clone()), + Expression::Reference(b.clone()), + Expression::Reference(c.clone()), + ])), + Box::new(Expression::ConstantInt(4)), + ), + Expression::Geq( + Box::new(Expression::Reference(a.clone())), + Box::new(Expression::Reference(b.clone())), + ), + ], + }; + + let m2 = ModelBuilder::new() + .add_var_str("a", Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_var_str("b", Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_var_str("c", Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_constraint(Expression::Eq( + Box::new(Expression::Sum(vec![ + Expression::Reference(a.clone()), + Expression::Reference(b.clone()), + Expression::Reference(c.clone()), + ])), + Box::new(Expression::ConstantInt(4)), + )) + .add_constraint(Expression::Geq( + Box::new(Expression::Reference(a.clone())), + Box::new(Expression::Reference(b.clone())), + )) + .build(); + + assert!(m1 == m2); +} From 262b5cbeb5bf7d184ab77aac079572de25ffb426 Mon Sep 17 00:00:00 2001 From: Felix Date: Fri, 27 Oct 2023 11:56:52 +0100 Subject: [PATCH 4/5] Remove string var addition method --- src/common/ast/model_builder.rs | 10 ---------- tests/builder_tests.rs | 6 +++--- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/common/ast/model_builder.rs b/src/common/ast/model_builder.rs index 5e46335bf..ad0830833 100644 --- a/src/common/ast/model_builder.rs +++ b/src/common/ast/model_builder.rs @@ -29,16 +29,6 @@ impl ModelBuilder { self } - pub fn add_var_str(mut self, name: &str, domain: Domain) -> Self { - self.variables.insert( - Name::UserName(String::from(name)), - DecisionVariable { - domain, - }, - ); - self - } - pub fn build(self) -> Model { Model { variables: self.variables, diff --git a/tests/builder_tests.rs b/tests/builder_tests.rs index 7716c0947..659ad96cb 100644 --- a/tests/builder_tests.rs +++ b/tests/builder_tests.rs @@ -47,9 +47,9 @@ fn abc_equality() { }; let m2 = ModelBuilder::new() - .add_var_str("a", Domain::IntDomain(vec![Range::Bounded(1, 3)])) - .add_var_str("b", Domain::IntDomain(vec![Range::Bounded(1, 3)])) - .add_var_str("c", Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_var(a.clone(), Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_var(b.clone(), Domain::IntDomain(vec![Range::Bounded(1, 3)])) + .add_var(c.clone(), Domain::IntDomain(vec![Range::Bounded(1, 3)])) .add_constraint(Expression::Eq( Box::new(Expression::Sum(vec![ Expression::Reference(a.clone()), From 76f258c5b471e2fc24e11dd59352adfbb0913f67 Mon Sep 17 00:00:00 2001 From: Felix Date: Fri, 27 Oct 2023 11:58:07 +0100 Subject: [PATCH 5/5] Assert correct state when adding variables --- src/common/ast/ast.rs | 1 + src/common/ast/model_builder.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/src/common/ast/ast.rs b/src/common/ast/ast.rs index 64c47a107..f1d298e1e 100644 --- a/src/common/ast/ast.rs +++ b/src/common/ast/ast.rs @@ -9,6 +9,7 @@ pub struct Model { impl Model { // Function to update a DecisionVariable based on its Name pub fn update_domain(&mut self, name: &Name, new_domain: Domain) { + assert!(self.variables.contains_key(name)); if let Some(decision_var) = self.variables.get_mut(name) { decision_var.domain = new_domain; } diff --git a/src/common/ast/model_builder.rs b/src/common/ast/model_builder.rs index ad0830833..ac224f8d3 100644 --- a/src/common/ast/model_builder.rs +++ b/src/common/ast/model_builder.rs @@ -20,6 +20,7 @@ impl ModelBuilder { } pub fn add_var(mut self, name: Name, domain: Domain) -> Self { + assert!(self.variables.get(&name).is_none()); self.variables.insert( name, DecisionVariable {