Skip to content

Commit

Permalink
Merge pull request #99 from ozgurakgun/constraint-parsing
Browse files Browse the repository at this point in the history
Constraint parsing
  • Loading branch information
ozgurakgun authored Nov 22, 2023
2 parents cfb1060 + 718d1bc commit e045d91
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 39 deletions.
12 changes: 9 additions & 3 deletions conjure_oxide/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::HashMap;
use std::fmt::Display;

#[serde_as]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Model {
#[serde_as(as = "Vec<(_, _)>")]
pub variables: HashMap<Name, DecisionVariable>,
Expand Down Expand Up @@ -36,13 +36,13 @@ impl Default for Model {
}
}

#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum Name {
UserName(String),
MachineName(i32),
}

#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DecisionVariable {
pub domain: Domain,
}
Expand Down Expand Up @@ -87,9 +87,15 @@ pub enum Range<A> {
pub enum Expression {
ConstantInt(i32),
Reference(Name),

Sum(Vec<Expression>),

Eq(Box<Expression>, Box<Expression>),
Neq(Box<Expression>, Box<Expression>),
Geq(Box<Expression>, Box<Expression>),
Leq(Box<Expression>, Box<Expression>),
Gt(Box<Expression>, Box<Expression>),
Lt(Box<Expression>, Box<Expression>),

// Flattened Constraints
SumGeq(Vec<Expression>, Box<Expression>),
Expand Down
130 changes: 110 additions & 20 deletions conjure_oxide/src/parse.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use crate::ast::{DecisionVariable, Domain, Model, Name, Range};
use std::collections::HashMap;

use serde_json::Value;

use crate::ast::{DecisionVariable, Domain, Expression, Model, Name, Range};
use crate::error::{Error, Result};
use serde_json::Value as JsonValue;

pub fn parse_json(str: &str) -> Result<Model> {
let mut m = Model::new();
let v: JsonValue = serde_json::from_str(str)?;
let constraints = v["mStatements"]
let statements = v["mStatements"]
.as_array()
.ok_or(Error::Parse("mStatements is not an array".to_owned()))?;

for con in constraints {
let entry = con
for statement in statements {
let entry = statement
.as_object()
.ok_or(Error::Parse("mStatements contains a non-object".to_owned()))?
.iter()
Expand All @@ -23,12 +27,18 @@ pub fn parse_json(str: &str) -> Result<Model> {
let (name, var) = parse_variable(entry.1)?;
m.add_variable(name, var);
}
"SuchThat" => parse_constraint(entry.1)?,
_ => {
return Err(Error::Parse(
"mStatements contains an unknown object".to_owned(),
))
"SuchThat" => {
let constraints: Vec<Expression> = entry
.1
.as_array()
.unwrap()
.iter()
.flat_map(parse_expression)
.collect();
m.constraints.extend(constraints);
// println!("Nb constraints {}", m.constraints.len());
}
otherwise => panic!("Unhandled Statement {:#?}", otherwise),
}
}

Expand Down Expand Up @@ -90,14 +100,13 @@ fn parse_int_domain(v: &JsonValue) -> Result<Domain> {
.as_array()
.ok_or(Error::Parse("RangeBounded is not an array".to_owned()))?;
let mut nums = Vec::new();
for i in 0..2 {
let num =
&arr[i]["Constant"]["ConstantInt"][1]
.as_i64()
.ok_or(Error::Parse(
"Could not parse int domain constant".to_owned(),
))?;
let num32 = i32::try_from(*num).map_err(|_| {
for item in arr.iter() {
let num = item["Constant"]["ConstantInt"][1]
.as_i64()
.ok_or(Error::Parse(
"Could not parse int domain constant".to_owned(),
))?;
let num32 = i32::try_from(num).map_err(|_| {
Error::Parse("Could not parse int domain constant".to_owned())
})?;
nums.push(num32);
Expand All @@ -124,12 +133,93 @@ fn parse_int_domain(v: &JsonValue) -> Result<Domain> {
Ok(Domain::IntDomain(ranges))
}

fn parse_constraint(_obj: &JsonValue) -> Result<()> {
Ok(())
// this needs an explicit type signature to force the closures to have the same type
type BinOp = Box<dyn Fn(Box<Expression>, Box<Expression>) -> Expression>;

fn parse_expression(obj: &JsonValue) -> Option<Expression> {
let binary_operators: HashMap<&str, BinOp> = [
("MkOpEq", Box::new(Expression::Eq) as Box<dyn Fn(_, _) -> _>),
(
"MkOpNeq",
Box::new(Expression::Neq) as Box<dyn Fn(_, _) -> _>,
),
(
"MkOpGeq",
Box::new(Expression::Geq) as Box<dyn Fn(_, _) -> _>,
),
(
"MkOpLeq",
Box::new(Expression::Leq) as Box<dyn Fn(_, _) -> _>,
),
("MkOpGt", Box::new(Expression::Gt) as Box<dyn Fn(_, _) -> _>),
("MkOpLt", Box::new(Expression::Lt) as Box<dyn Fn(_, _) -> _>),
]
.into_iter()
.collect();

let mut binary_operator_names = binary_operators.iter().map(|x| x.0);

match obj {
Value::Object(op) if op.contains_key("Op") => match &op["Op"] {
Value::Object(bin_op) if binary_operator_names.any(|key| bin_op.contains_key(*key)) => {
parse_bin_op(bin_op, binary_operators)
}
Value::Object(op_sum) if op_sum.contains_key("MkOpSum") => parse_sum(op_sum),
otherwise => panic!("Unhandled Op {:#?}", otherwise),
},
Value::Object(refe) if refe.contains_key("Reference") => {
let name = refe["Reference"].as_array()?[0].as_object()?["Name"].as_str()?;
Some(Expression::Reference(Name::UserName(name.to_string())))
}
Value::Object(constant) if constant.contains_key("Constant") => parse_constant(constant),
otherwise => panic!("Unhandled Expression {:#?}", otherwise),
}
}

fn parse_sum(op_sum: &serde_json::Map<String, Value>) -> Option<Expression> {
let args = &op_sum["MkOpSum"]["AbstractLiteral"]["AbsLitMatrix"][1];
let args_parsed: Vec<Expression> = args
.as_array()?
.iter()
.map(|x| parse_expression(x).unwrap())
.collect();
Some(Expression::Sum(args_parsed))
}

fn parse_bin_op(
bin_op: &serde_json::Map<String, Value>,
binary_operators: HashMap<&str, BinOp>,
) -> Option<Expression> {
// we know there is a single key value pair in this object
// extract the value, ignore the key
let (key, value) = bin_op.into_iter().next()?;

let constructor = binary_operators.get(key.as_str())?;

match &value {
Value::Array(bin_op_args) if bin_op_args.len() == 2 => {
let arg1 = parse_expression(&bin_op_args[0])?;
let arg2 = parse_expression(&bin_op_args[1])?;
Some(constructor(Box::new(arg1), Box::new(arg2)))
}
otherwise => panic!("Unhandled parse_bin_op {:#?}", otherwise),
}
}

fn parse_constant(constant: &serde_json::Map<String, Value>) -> Option<Expression> {
match &constant["Constant"] {
Value::Object(int) if int.contains_key("ConstantInt") => Some(Expression::ConstantInt(
int["ConstantInt"].as_array()?[1]
.as_i64()?
.try_into()
.unwrap(),
)),
otherwise => panic!("Unhandled parse_constant {:#?}", otherwise),
}
}

impl Model {
pub fn from_json(str: &String) -> Result<Model> {
pub fn from_json(str: &str) -> Result<Model> {
parse_json(str)
}
}
67 changes: 61 additions & 6 deletions conjure_oxide/tests/generated_tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use conjure_oxide::ast::Model;
use serde_json::Value;
use std::env;
use std::error::Error;
use std::fs::File;
use std::io::prelude::*;

use conjure_oxide::ast::Model;

use std::path::Path;

fn main() {
Expand Down Expand Up @@ -37,10 +37,15 @@ fn integration_test(path: &str, essence_base: &str) -> Result<(), Box<dyn Error>
// "parsing" astjson as Model
let generated_mdl = Model::from_json(&astjson)?;

// a consistent sorting of the keys of json objects
// only required for the generated version
// since the expected version will already be sorted
let generated_json = sort_json_object(&serde_json::to_value(generated_mdl.clone())?);

// serialise to file
let generated_json = serde_json::to_string_pretty(&generated_mdl)?;
let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
File::create(format!("{path}/{essence_base}.generated.serialised.json"))?
.write_all(generated_json.as_bytes())?;
.write_all(generated_json_str.as_bytes())?;

if std::env::var("ACCEPT").map_or(false, |v| v == "true") {
std::fs::copy(
Expand All @@ -55,8 +60,7 @@ fn integration_test(path: &str, essence_base: &str) -> Result<(), Box<dyn Error>
let expected_str =
std::fs::read_to_string(format!("{path}/{essence_base}.expected.serialised.json"))?;

let mut expected_mdl: Model = serde_json::from_str(&expected_str)?;
expected_mdl.constraints = Vec::new(); // TODO - remove this line once we parse constraints
let expected_mdl: Model = serde_json::from_str(&expected_str)?;

// --------------------------------------------------------------------------------
// assert that they are the same model
Expand All @@ -66,4 +70,55 @@ fn integration_test(path: &str, essence_base: &str) -> Result<(), Box<dyn Error>
Ok(())
}

/// Recursively sorts the keys of all JSON objects within the provided JSON value.
///
/// serde_json will output JSON objects in an arbitrary key order.
/// this is normally fine, except in our use case we wouldn't want to update the expected output again and again.
/// so a consistent (sorted) ordering of the keys is desirable.
fn sort_json_object(value: &Value) -> Value {
match value {
Value::Object(obj) => {
let mut ordered: Vec<(String, Value)> = obj
.iter()
.map(|(k, v)| {
if k == "variables" {
(k.clone(), sort_json_variables(v))
} else {
(k.clone(), sort_json_object(v))
}
})
// .map(|(k, v)| (k.clone(), sort_json_object(v)))
.collect();
ordered.sort_by(|a, b| a.0.cmp(&b.0));

Value::Object(ordered.into_iter().collect())
}
Value::Array(arr) => Value::Array(arr.iter().map(sort_json_object).collect()),
_ => value.clone(),
}
}

/// Sort the "variables" field by name.
/// We have to do this separately becasue that field is not a JSON object, instead it's an array of tuples.
fn sort_json_variables(value: &Value) -> Value {
match value {
Value::Array(vars) => {
let mut vars_sorted = vars.clone();
vars_sorted.sort_by(|a, b| {
let a_obj = &a.as_array().unwrap()[0];
let a_name: conjure_oxide::ast::Name =
serde_json::from_value(a_obj.clone()).unwrap();

let b_obj = &b.as_array().unwrap()[0];
let b_name: conjure_oxide::ast::Name =
serde_json::from_value(b_obj.clone()).unwrap();

a_name.cmp(&b_name)
});
Value::Array(vars_sorted)
}
_ => value.clone(),
}
}

include!(concat!(env!("OUT_DIR"), "/gen_tests.rs"));
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"constraints": [],
"variables": [
[
{
Expand All @@ -8,6 +9,5 @@
"domain": "BoolDomain"
}
]
],
"constraints": []
]
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
{
"constraints": [],
"variables": [
[
{
"UserName": "y"
"UserName": "x"
},
{
"domain": "BoolDomain"
}
],
[
{
"UserName": "x"
"UserName": "y"
},
{
"domain": "BoolDomain"
}
]
],
"constraints": []
]
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
{
"constraints": [
{
"Neq": [
{
"Reference": {
"UserName": "x"
}
},
{
"Reference": {
"UserName": "y"
}
}
]
}
],
"variables": [
[
{
Expand All @@ -16,6 +32,5 @@
"domain": "BoolDomain"
}
]
],
"constraints": []
]
}
Loading

0 comments on commit e045d91

Please sign in to comment.