From 246b92a72db38a4fbcb3a749409e6be37a2669b5 Mon Sep 17 00:00:00 2001 From: Michael Bynum Date: Tue, 31 Oct 2023 07:39:59 -0600 Subject: [PATCH] evaluation error detection --- idaes/core/util/model_diagnostics.py | 62 ++++++++++---- .../core/util/tests/test_model_diagnostics.py | 81 +++++++++++++++++++ 2 files changed, 126 insertions(+), 17 deletions(-) diff --git a/idaes/core/util/model_diagnostics.py b/idaes/core/util/model_diagnostics.py index 5d20656119..27bfa68a60 100644 --- a/idaes/core/util/model_diagnostics.py +++ b/idaes/core/util/model_diagnostics.py @@ -266,6 +266,14 @@ def svd_sparse(jacobian, number_singular_values): description="Tolerance for raising a warning for small Jacobian values.", ), ) +CONFIG.declare( + "strict_evaluation_error_detection", + ConfigValue( + default=True, + domain=bool, + description="If False, warnings will not be generated for things like log(x) with x >= 0", + ), +) SVDCONFIG = ConfigDict() @@ -1331,7 +1339,7 @@ def _collect_potential_eval_errors(self) -> List[str]: for con in self.model.component_data_objects( Constraint, active=True, descend_into=True ): - walker = _EvalErrorWalker() + walker = _EvalErrorWalker(self.config) con_warnings = walker.walk_expression(con.body) for msg in con_warnings: msg = f"{con.name}: " + msg @@ -1339,7 +1347,7 @@ def _collect_potential_eval_errors(self) -> List[str]: for obj in self.model.component_data_objects( Objective, active=True, descend_into=True ): - walker = _EvalErrorWalker() + walker = _EvalErrorWalker(self.config) obj_warnings = walker.walk_expression(obj.expr) for msg in obj_warnings: msg = f"{obj.name}: " + msg @@ -1690,14 +1698,18 @@ def _get_bounds_with_inf(node: NumericExpression): return lb, ub -def _check_eval_error_division(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_division( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): lb, ub = _get_bounds_with_inf(node.args[1]) - if lb <= 0 <= ub: + if (config.strict_evaluation_error_detection and (lb <= 0 <= ub)) or (lb < 0 < ub): msg = f"Potential division by 0 in {node}; Denominator bounds are ({lb}, {ub})" warn_list.append(msg) -def _check_eval_error_pow(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_pow( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): arg1, arg2 = node.args lb1, ub1 = _get_bounds_with_inf(arg1) lb2, ub2 = _get_bounds_with_inf(arg2) @@ -1727,7 +1739,10 @@ def _check_eval_error_pow(node: NumericExpression, warn_list: List[str]): # only integer variables with integer coefficients integer_exponent = True - if integer_exponent and (lb1 > 0 or ub1 < 0): + if integer_exponent and ( + (lb1 > 0 or ub1 < 0) + or (not config.strict_evaluation_error_detection and (lb1 >= 0 or ub1 <= 0)) + ): # life is good; the exponent is an integer and the base is nonzero return None elif integer_exponent and lb2 >= 0: @@ -1735,7 +1750,7 @@ def _check_eval_error_pow(node: NumericExpression, warn_list: List[str]): return None # if the base is positive, there should not be any evaluation errors - if lb1 > 0: + if lb1 > 0 or (not config.strict_evaluation_error_detection and lb1 >= 0): return None if lb1 >= 0 and lb2 >= 0: return None @@ -1746,35 +1761,45 @@ def _check_eval_error_pow(node: NumericExpression, warn_list: List[str]): warn_list.append(msg) -def _check_eval_error_log(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_log( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): lb, ub = _get_bounds_with_inf(node.args[0]) - if lb <= 0: + if (config.strict_evaluation_error_detection and lb <= 0) or lb < 0: msg = f"Potential log of a non-positive number in {node}; Argument bounds are ({lb}, {ub})" warn_list.append(msg) -def _check_eval_error_tan(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_tan( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): lb, ub = _get_bounds_with_inf(node) if not (math.isfinite(lb) and math.isfinite(ub)): msg = f"{node} may evaluate to -inf or inf; Argument bounds are {_get_bounds_with_inf(node.args[0])}" warn_list.append(msg) -def _check_eval_error_asin(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_asin( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): lb, ub = _get_bounds_with_inf(node.args[0]) if lb < -1 or ub > 1: msg = f"Potential evaluation of asin outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})" warn_list.append(msg) -def _check_eval_error_acos(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_acos( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): lb, ub = _get_bounds_with_inf(node.args[0]) if lb < -1 or ub > 1: msg = f"Potential evaluation of acos outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})" warn_list.append(msg) -def _check_eval_error_sqrt(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_sqrt( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): lb, ub = _get_bounds_with_inf(node.args[0]) if lb < 0: msg = f"Potential square root of a negative number in {node}; Argument bounds are ({lb}, {ub})" @@ -1790,9 +1815,11 @@ def _check_eval_error_sqrt(node: NumericExpression, warn_list: List[str]): _unary_eval_err_handler["sqrt"] = _check_eval_error_sqrt -def _check_eval_error_unary(node: NumericExpression, warn_list: List[str]): +def _check_eval_error_unary( + node: NumericExpression, warn_list: List[str], config: ConfigDict +): if node.getname() in _unary_eval_err_handler: - _unary_eval_err_handler[node.getname()](node, warn_list) + _unary_eval_err_handler[node.getname()](node, warn_list, config) _eval_err_handler = dict() @@ -1805,13 +1832,14 @@ def _check_eval_error_unary(node: NumericExpression, warn_list: List[str]): class _EvalErrorWalker(StreamBasedExpressionVisitor): - def __init__(self): + def __init__(self, config: ConfigDict): super().__init__() self._warn_list = list() + self._config = config def exitNode(self, node, data): if type(node) in _eval_err_handler: - _eval_err_handler[type(node)](node, self._warn_list) + _eval_err_handler[type(node)](node, self._warn_list, self._config) return self._warn_list diff --git a/idaes/core/util/tests/test_model_diagnostics.py b/idaes/core/util/tests/test_model_diagnostics.py index 663136c194..b86f35dd1b 100644 --- a/idaes/core/util/tests/test_model_diagnostics.py +++ b/idaes/core/util/tests/test_model_diagnostics.py @@ -2600,6 +2600,18 @@ def test_div(self): warnings = dtb._collect_potential_eval_errors() self.assertEqual(len(warnings), 0) + m.x.setlb(0) + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 1) + w = warnings[0] + self.assertEqual( + w, "c: Potential division by 0 in 1/x; Denominator bounds are (0, inf)" + ) + + dtb.config.strict_evaluation_error_detection = False + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 0) + m.x.setlb(-1) warnings = dtb._collect_potential_eval_errors() self.assertEqual(len(warnings), 1) @@ -2663,6 +2675,62 @@ def test_pow3(self): warnings = dtb._collect_potential_eval_errors() self.assertEqual(len(warnings), 0) + @pytest.mark.unit + def test_pow4(self): + m = ConcreteModel() + m.x = Var(bounds=(0, None)) + m.y = Var() + m.c = Constraint(expr=m.y == m.x ** (-2)) + dtb = DiagnosticsToolbox(m) + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 1) + w = warnings[0] + self.assertEqual( + w, + "c: Potential evaluation error in x**-2; base bounds are (0, inf); exponent bounds are (-2, -2)", + ) + + dtb.config.strict_evaluation_error_detection = False + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 0) + + m.x.setlb(-1) + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 1) + w = warnings[0] + self.assertEqual( + w, + "c: Potential evaluation error in x**-2; base bounds are (-1, inf); exponent bounds are (-2, -2)", + ) + + @pytest.mark.unit + def test_pow5(self): + m = ConcreteModel() + m.x = Var(bounds=(0, None)) + m.y = Var() + m.c = Constraint(expr=m.y == m.x ** (-2.5)) + dtb = DiagnosticsToolbox(m) + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 1) + w = warnings[0] + self.assertEqual( + w, + "c: Potential evaluation error in x**-2.5; base bounds are (0, inf); exponent bounds are (-2.5, -2.5)", + ) + + dtb.config.strict_evaluation_error_detection = False + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 0) + + m.x.setlb(-1) + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 1) + w = warnings[0] + self.assertEqual( + w, + "c: Potential evaluation error in x**-2.5; base bounds are (-1, inf); exponent bounds are (-2.5, -2.5)", + ) + @pytest.mark.unit def test_log(self): m = ConcreteModel() @@ -2673,6 +2741,19 @@ def test_log(self): warnings = dtb._collect_potential_eval_errors() self.assertEqual(len(warnings), 0) + m.x.setlb(0) + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 1) + w = warnings[0] + self.assertEqual( + w, + "c: Potential log of a non-positive number in log(x); Argument bounds are (0, inf)", + ) + + dtb.config.strict_evaluation_error_detection = False + warnings = dtb._collect_potential_eval_errors() + self.assertEqual(len(warnings), 0) + m.x.setlb(-1) warnings = dtb._collect_potential_eval_errors() self.assertEqual(len(warnings), 1)