Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diagnostic tool to look for potential evaluation errors #1268

Merged
merged 15 commits into from
Nov 15, 2023
Merged
214 changes: 213 additions & 1 deletion idaes/core/util/model_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from operator import itemgetter
from sys import stdout
import math
from math import log
from typing import List, Sequence

import numpy as np
from scipy.linalg import svd
Expand All @@ -28,6 +30,7 @@

from pyomo.environ import (
Binary,
Integers,
Block,
check_optimal_termination,
ConcreteModel,
Expand All @@ -38,15 +41,28 @@
SolverFactory,
value,
Var,
is_fixed,
)
from pyomo.core.expr.numeric_expr import (
DivisionExpression,
NPV_DivisionExpression,
PowExpression,
NPV_PowExpression,
UnaryFunctionExpression,
NPV_UnaryFunctionExpression,
NumericExpression,
)
from pyomo.core.base.block import _BlockData
from pyomo.core.base.var import _GeneralVarData
from pyomo.repn.standard_repn import generate_standard_repn
from pyomo.common.collections import ComponentSet
from pyomo.common.config import ConfigDict, ConfigValue, document_kwargs_from_configdict
from pyomo.util.check_units import identify_inconsistent_units
from pyomo.contrib.incidence_analysis import IncidenceGraphInterface
from pyomo.core.expr.visitor import identify_variables
from pyomo.core.expr.visitor import identify_variables, StreamBasedExpressionVisitor
from pyomo.contrib.pynumero.interfaces.pyomo_nlp import PyomoNLP
from pyomo.contrib.pynumero.asl import AmplInterface
from pyomo.contrib.fbbt.fbbt import compute_bounds_on_expr
from pyomo.common.deprecation import deprecation_warning

from idaes.core.util.model_statistics import (
Expand Down Expand Up @@ -1045,6 +1061,202 @@ def report_numerical_issues(self, stream=stdout):
footer="=",
)

def _collect_potential_eval_errors(self):
res = list()
warnings = list()
cautions = list()
for con in self.model.component_data_objects(Constraint, active=True, descend_into=True):
walker = _EvalErrorWalker()
con_warnings, con_cautions = walker.walk_expression(con.body)
for msg in con_warnings:
msg = f'{con.name}: ' + msg
warnings.append(msg)
for msg in con_cautions:
msg = f'{con.name}: ' + msg
cautions.append(msg)
for obj in self.model.component_data_objects(Objective, active=True, descend_into=True):
walker = _EvalErrorWalker()
obj_warnings, obj_cautions = walker.walk_expression(obj.expr)
for msg in obj_warnings:
msg = f'{obj.name}: ' + msg
warnings.append(msg)
for msg in obj_cautions:
msg = f'{obj.name}: ' + msg
cautions.append(msg)

return warnings, cautions

def report_potential_evaluation_errors(self, stream=stdout):
warnings, cautions = self._collect_potential_eval_errors()
_write_report_section(
stream=stream,
lines_list=warnings,
title=f"{len(warnings)} WARNINGS",
line_if_empty="No warnings found!",
header="=",
)
_write_report_section(
stream=stream,
lines_list=cautions,
title=f"{len(cautions)} Cautions",
line_if_empty="No cautions found!",
footer="=",
)


def _get_bounds_with_inf(node: NumericExpression):
lb, ub = compute_bounds_on_expr(node)
if lb is None:
lb = -math.inf
if ub is None:
ub = math.inf
return lb, ub


def _caution_expression_argument(
node: NumericExpression,
args_to_check: Sequence[NumericExpression],
caution_list: List[str]
):
should_caution = False
for arg in args_to_check:
if is_fixed(arg):
continue
if isinstance(arg, _GeneralVarData):
continue
should_caution = True
break
if should_caution:
msg = f'Potential evaluation error in {node}; '
msg += 'arguments are expressions with bounds that are not strictly '
msg += 'enforced; try making the argument a variable'
caution_list.append(msg)


def _check_eval_error_division(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
_caution_expression_argument(node, [node.args[1]], caution_list)
lb, ub = _get_bounds_with_inf(node.args[1])
if lb <= 0 <= ub:
msg = f'Potential division by 0 in {node}; Denominator bounds are ({lb}, {ub})'
warn_list.append(msg)


def _check_eval_error_pow(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
arg1, arg2 = node.args

integer_domains = ComponentSet([Binary, Integers])

# if the exponent is an integer, there should not be any evaluation errors
if isinstance(arg2, _GeneralVarData) and arg2.domain in integer_domains:
# life is good. The exponent is an integer variable
return None
michaelbynum marked this conversation as resolved.
Show resolved Hide resolved
lb2, ub2 = _get_bounds_with_inf(arg2)
if lb2 == ub2 and lb2 == round(lb2):
# life is good. The exponent is fixed to an integer
michaelbynum marked this conversation as resolved.
Show resolved Hide resolved
return None
repn = generate_standard_repn(arg2, quadratic=True)
if (
repn.nonlinear_expr is None
and repn.constant == round(repn.constant)
and all(i.domain in integer_domains for i in repn.linear_vars)
and all(i[0].domain in integer_domains for i in repn.quadratic_vars)
and all(i[1].domain in integer_domains for i in repn.quadratic_vars)
and all(i == round(i) for i in repn.linear_coefs)
and all(i == round(i) for i in repn.quadratic_coefs)
):
# Life is good. The exponent is a linear or quadratic expression containing
# only integer variables with integer coefficients
return None
michaelbynum marked this conversation as resolved.
Show resolved Hide resolved

_caution_expression_argument(node, node.args, caution_list)

# if the base is positive, there should not be any evaluation errors
lb1, ub1 = _get_bounds_with_inf(arg1)
if lb1 > 0:
return None
if lb1 >= 0 and lb2 >= 0:
return None

msg = f'Potential evaluation error in {node}; '
msg += f'base bounds are ({lb1}, {ub1}); '
msg += f'exponent bounds are ({lb2}, {ub2})'
warn_list.append(msg)


def _check_eval_error_log(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
_caution_expression_argument(node, node.args, caution_list)
lb, ub = _get_bounds_with_inf(node.args[0])
if lb <= 0:
msg = f'Potential log of a negative number in {node}; Argument bounds are ({lb}, {ub})'
michaelbynum marked this conversation as resolved.
Show resolved Hide resolved
warn_list.append(msg)


def _check_eval_error_tan(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
_caution_expression_argument(node, node.args, caution_list)
lb, ub = _get_bounds_with_inf(node)
if not (math.isfinite(lb) and math.isfinite(ub)):
msg = f'{node} may evaluate to -inf or inf; Argument bounds are {_get_bounds_with_inf(node.args[0])}'
warn_list.append(msg)


def _check_eval_error_asin(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
_caution_expression_argument(node, node.args, caution_list)
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f'Potential evaluation of asin outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})'
warn_list.append(msg)


def _check_eval_error_acos(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
_caution_expression_argument(node, node.args, caution_list)
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f'Potential evaluation of acos outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})'
warn_list.append(msg)


def _check_eval_error_sqrt(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
_caution_expression_argument(node, node.args, caution_list)
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < 0:
msg = f'Potential square root of a negative number in {node}; Argument bounds are ({lb}, {ub})'
warn_list.append(msg)


_unary_eval_err_handler = dict()
_unary_eval_err_handler['log'] = _check_eval_error_log
_unary_eval_err_handler['log10'] = _check_eval_error_log
_unary_eval_err_handler['tan'] = _check_eval_error_tan
_unary_eval_err_handler['asin'] = _check_eval_error_asin
_unary_eval_err_handler['acos'] = _check_eval_error_acos
_unary_eval_err_handler['sqrt'] = _check_eval_error_sqrt


def _check_eval_error_unary(node: NumericExpression, warn_list: List[str], caution_list: List[str]):
if node.getname() in _unary_eval_err_handler:
_unary_eval_err_handler[node.getname()](node, warn_list, caution_list)


_eval_err_handler = dict()
_eval_err_handler[DivisionExpression] = _check_eval_error_division
_eval_err_handler[NPV_DivisionExpression] = _check_eval_error_division
_eval_err_handler[PowExpression] = _check_eval_error_pow
_eval_err_handler[NPV_PowExpression] = _check_eval_error_pow
_eval_err_handler[UnaryFunctionExpression] = _check_eval_error_unary
_eval_err_handler[NPV_UnaryFunctionExpression] = _check_eval_error_unary


class _EvalErrorWalker(StreamBasedExpressionVisitor):
def __init__(self):
super().__init__()
self._warn_list = list()
self._caution_list = list()

def exitNode(self, node, data):
if type(node) in _eval_err_handler:
_eval_err_handler[type(node)](node, self._warn_list, self._caution_list)
return self._warn_list, self._caution_list


class DegeneracyHunter:
"""
Expand Down