Skip to content

Commit

Permalink
evaluation error detection
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbynum committed Oct 31, 2023
1 parent e9e8fc8 commit 246b92a
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 17 deletions.
62 changes: 45 additions & 17 deletions idaes/core/util/model_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,14 @@ def svd_sparse(jacobian, number_singular_values):
description="Tolerance for raising a warning for small Jacobian values.",
),
)
CONFIG.declare(
"strict_evaluation_error_detection",
ConfigValue(
default=True,
domain=bool,
description="If False, warnings will not be generated for things like log(x) with x >= 0",
),
)


SVDCONFIG = ConfigDict()
Expand Down Expand Up @@ -1331,15 +1339,15 @@ def _collect_potential_eval_errors(self) -> List[str]:
for con in self.model.component_data_objects(
Constraint, active=True, descend_into=True
):
walker = _EvalErrorWalker()
walker = _EvalErrorWalker(self.config)
con_warnings = walker.walk_expression(con.body)
for msg in con_warnings:
msg = f"{con.name}: " + msg
warnings.append(msg)
for obj in self.model.component_data_objects(
Objective, active=True, descend_into=True
):
walker = _EvalErrorWalker()
walker = _EvalErrorWalker(self.config)
obj_warnings = walker.walk_expression(obj.expr)
for msg in obj_warnings:
msg = f"{obj.name}: " + msg
Expand Down Expand Up @@ -1690,14 +1698,18 @@ def _get_bounds_with_inf(node: NumericExpression):
return lb, ub


def _check_eval_error_division(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_division(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[1])
if lb <= 0 <= ub:
if (config.strict_evaluation_error_detection and (lb <= 0 <= ub)) or (lb < 0 < ub):
msg = f"Potential division by 0 in {node}; Denominator bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_pow(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_pow(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
arg1, arg2 = node.args
lb1, ub1 = _get_bounds_with_inf(arg1)
lb2, ub2 = _get_bounds_with_inf(arg2)
Expand Down Expand Up @@ -1727,15 +1739,18 @@ def _check_eval_error_pow(node: NumericExpression, warn_list: List[str]):
# only integer variables with integer coefficients
integer_exponent = True

if integer_exponent and (lb1 > 0 or ub1 < 0):
if integer_exponent and (
(lb1 > 0 or ub1 < 0)
or (not config.strict_evaluation_error_detection and (lb1 >= 0 or ub1 <= 0))
):
# life is good; the exponent is an integer and the base is nonzero
return None
elif integer_exponent and lb2 >= 0:
# life is good; the exponent is a nonnegative integer
return None

# if the base is positive, there should not be any evaluation errors
if lb1 > 0:
if lb1 > 0 or (not config.strict_evaluation_error_detection and lb1 >= 0):
return None
if lb1 >= 0 and lb2 >= 0:
return None
Expand All @@ -1746,35 +1761,45 @@ def _check_eval_error_pow(node: NumericExpression, warn_list: List[str]):
warn_list.append(msg)


def _check_eval_error_log(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_log(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb <= 0:
if (config.strict_evaluation_error_detection and lb <= 0) or lb < 0:
msg = f"Potential log of a non-positive number in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_tan(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_tan(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node)
if not (math.isfinite(lb) and math.isfinite(ub)):
msg = f"{node} may evaluate to -inf or inf; Argument bounds are {_get_bounds_with_inf(node.args[0])}"
warn_list.append(msg)


def _check_eval_error_asin(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_asin(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of asin outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_acos(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_acos(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of acos outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_sqrt(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_sqrt(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < 0:
msg = f"Potential square root of a negative number in {node}; Argument bounds are ({lb}, {ub})"
Expand All @@ -1790,9 +1815,11 @@ def _check_eval_error_sqrt(node: NumericExpression, warn_list: List[str]):
_unary_eval_err_handler["sqrt"] = _check_eval_error_sqrt


def _check_eval_error_unary(node: NumericExpression, warn_list: List[str]):
def _check_eval_error_unary(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
if node.getname() in _unary_eval_err_handler:
_unary_eval_err_handler[node.getname()](node, warn_list)
_unary_eval_err_handler[node.getname()](node, warn_list, config)


_eval_err_handler = dict()
Expand All @@ -1805,13 +1832,14 @@ def _check_eval_error_unary(node: NumericExpression, warn_list: List[str]):


class _EvalErrorWalker(StreamBasedExpressionVisitor):
def __init__(self):
def __init__(self, config: ConfigDict):
super().__init__()
self._warn_list = list()
self._config = config

def exitNode(self, node, data):
if type(node) in _eval_err_handler:
_eval_err_handler[type(node)](node, self._warn_list)
_eval_err_handler[type(node)](node, self._warn_list, self._config)
return self._warn_list


Expand Down
81 changes: 81 additions & 0 deletions idaes/core/util/tests/test_model_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,6 +2600,18 @@ def test_div(self):
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 0)

m.x.setlb(0)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
w = warnings[0]
self.assertEqual(
w, "c: Potential division by 0 in 1/x; Denominator bounds are (0, inf)"
)

dtb.config.strict_evaluation_error_detection = False
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 0)

m.x.setlb(-1)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
Expand Down Expand Up @@ -2663,6 +2675,62 @@ def test_pow3(self):
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 0)

@pytest.mark.unit
def test_pow4(self):
m = ConcreteModel()
m.x = Var(bounds=(0, None))
m.y = Var()
m.c = Constraint(expr=m.y == m.x ** (-2))
dtb = DiagnosticsToolbox(m)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
w = warnings[0]
self.assertEqual(
w,
"c: Potential evaluation error in x**-2; base bounds are (0, inf); exponent bounds are (-2, -2)",
)

dtb.config.strict_evaluation_error_detection = False
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 0)

m.x.setlb(-1)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
w = warnings[0]
self.assertEqual(
w,
"c: Potential evaluation error in x**-2; base bounds are (-1, inf); exponent bounds are (-2, -2)",
)

@pytest.mark.unit
def test_pow5(self):
m = ConcreteModel()
m.x = Var(bounds=(0, None))
m.y = Var()
m.c = Constraint(expr=m.y == m.x ** (-2.5))
dtb = DiagnosticsToolbox(m)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
w = warnings[0]
self.assertEqual(
w,
"c: Potential evaluation error in x**-2.5; base bounds are (0, inf); exponent bounds are (-2.5, -2.5)",
)

dtb.config.strict_evaluation_error_detection = False
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 0)

m.x.setlb(-1)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
w = warnings[0]
self.assertEqual(
w,
"c: Potential evaluation error in x**-2.5; base bounds are (-1, inf); exponent bounds are (-2.5, -2.5)",
)

@pytest.mark.unit
def test_log(self):
m = ConcreteModel()
Expand All @@ -2673,6 +2741,19 @@ def test_log(self):
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 0)

m.x.setlb(0)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
w = warnings[0]
self.assertEqual(
w,
"c: Potential log of a non-positive number in log(x); Argument bounds are (0, inf)",
)

dtb.config.strict_evaluation_error_detection = False
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 0)

m.x.setlb(-1)
warnings = dtb._collect_potential_eval_errors()
self.assertEqual(len(warnings), 1)
Expand Down

0 comments on commit 246b92a

Please sign in to comment.