Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diagnostic tool to look for potential evaluation errors #1268

Merged
merged 15 commits into from
Nov 15, 2023
Merged
244 changes: 242 additions & 2 deletions idaes/core/util/model_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from operator import itemgetter
import sys
from inspect import signature
import math
from math import log
from typing import List

import numpy as np
from scipy.linalg import svd
Expand All @@ -29,6 +31,7 @@

from pyomo.environ import (
Binary,
Integers,
Block,
check_optimal_termination,
ConcreteModel,
Expand All @@ -40,9 +43,21 @@
value,
Var,
)
from pyomo.core.expr.numeric_expr import (
DivisionExpression,
NPV_DivisionExpression,
PowExpression,
NPV_PowExpression,
UnaryFunctionExpression,
NPV_UnaryFunctionExpression,
NumericExpression,
)
from pyomo.core.base.block import _BlockData
from pyomo.core.base.var import _VarData
from pyomo.core.base.var import _GeneralVarData, _VarData
from pyomo.core.base.constraint import _ConstraintData
from pyomo.repn.standard_repn import ( # pylint: disable=no-name-in-module
generate_standard_repn,
)
from pyomo.common.collections import ComponentSet
from pyomo.common.config import (
ConfigDict,
Expand All @@ -52,9 +67,10 @@
)
from pyomo.util.check_units import identify_inconsistent_units
from pyomo.contrib.incidence_analysis import IncidenceGraphInterface
from pyomo.core.expr.visitor import identify_variables
from pyomo.core.expr.visitor import identify_variables, StreamBasedExpressionVisitor
from pyomo.contrib.pynumero.interfaces.pyomo_nlp import PyomoNLP
from pyomo.contrib.pynumero.asl import AmplInterface
from pyomo.contrib.fbbt.fbbt import compute_bounds_on_expr
from pyomo.common.deprecation import deprecation_warning
from pyomo.common.errors import PyomoException

Expand Down Expand Up @@ -251,6 +267,14 @@ def svd_sparse(jacobian, number_singular_values):
description="Tolerance for raising a warning for small Jacobian values.",
),
)
CONFIG.declare(
"warn_for_evaluation_error_at_bounds",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Robbybp, @andrewlee94 - what do you think of this name?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome. Thank you both.

ConfigValue(
default=True,
domain=bool,
description="If False, warnings will not be generated for things like log(x) with x >= 0",
),
)


SVDCONFIG = ConfigDict()
Expand Down Expand Up @@ -936,6 +960,13 @@ def _collect_structural_warnings(self):
if any(len(x) > 0 for x in [oc_var, oc_con]):
next_steps.append(self.display_overconstrained_set.__name__ + "()")

eval_warnings = self._collect_potential_eval_errors()
if len(eval_warnings) > 0:
warnings.append(
f"WARNING: Found {len(eval_warnings)} potential evaluation errors."
)
next_steps.append(self.display_potential_evaluation_errors.__name__ + "()")

return warnings, next_steps

def _collect_structural_cautions(self):
Expand Down Expand Up @@ -1303,6 +1334,51 @@ def report_numerical_issues(self, stream=None):
footer="=",
)

def _collect_potential_eval_errors(self) -> List[str]:
warnings = list()
for con in self.model.component_data_objects(
Constraint, active=True, descend_into=True
):
walker = _EvalErrorWalker(self.config)
con_warnings = walker.walk_expression(con.body)
for msg in con_warnings:
msg = f"{con.name}: " + msg
warnings.append(msg)
for obj in self.model.component_data_objects(
Objective, active=True, descend_into=True
):
walker = _EvalErrorWalker(self.config)
obj_warnings = walker.walk_expression(obj.expr)
for msg in obj_warnings:
msg = f"{obj.name}: " + msg
warnings.append(msg)

return warnings

def display_potential_evaluation_errors(self, stream=None):
"""
Prints constraints that may be prone to evaluation errors
(e.g., log of a negative number) based on variable bounds.

Args:
stream: an I/O object to write the output to (default = stdout)

Returns:
None
"""
if stream is None:
stream = sys.stdout

warnings = self._collect_potential_eval_errors()
_write_report_section(
stream=stream,
lines_list=warnings,
title=f"{len(warnings)} WARNINGS",
line_if_empty="No warnings found!",
header="=",
footer="=",
)

@document_kwargs_from_configdict(SVDCONFIG)
def prepare_svd_toolbox(self, **kwargs):
"""
Expand Down Expand Up @@ -1623,6 +1699,170 @@ def display_variables_in_constraint(self, constraint, stream=None):
)


def _get_bounds_with_inf(node: NumericExpression):
lb, ub = compute_bounds_on_expr(node)
if lb is None:
lb = -math.inf
if ub is None:
ub = math.inf
return lb, ub


def _check_eval_error_division(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[1])
if (config.warn_for_evaluation_error_at_bounds and (lb <= 0 <= ub)) or (
lb < 0 < ub
):
msg = f"Potential division by 0 in {node}; Denominator bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_pow(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
arg1, arg2 = node.args
lb1, ub1 = _get_bounds_with_inf(arg1)
lb2, ub2 = _get_bounds_with_inf(arg2)

integer_domains = ComponentSet([Binary, Integers])

integer_exponent = False
# if the exponent is an integer, there should not be any evaluation errors
if isinstance(arg2, _GeneralVarData) and arg2.domain in integer_domains:
# The exponent is an integer variable
# check if the base can be zero
integer_exponent = True
if lb2 == ub2 and lb2 == round(lb2):
# The exponent is fixed to an integer
integer_exponent = True
repn = generate_standard_repn(arg2, quadratic=True)
if (
repn.nonlinear_expr is None
and repn.constant == round(repn.constant)
and all(i.domain in integer_domains for i in repn.linear_vars)
and all(i[0].domain in integer_domains for i in repn.quadratic_vars)
and all(i[1].domain in integer_domains for i in repn.quadratic_vars)
and all(i == round(i) for i in repn.linear_coefs)
and all(i == round(i) for i in repn.quadratic_coefs)
):
# The exponent is a linear or quadratic expression containing
# only integer variables with integer coefficients
integer_exponent = True

if integer_exponent and (
(lb1 > 0 or ub1 < 0)
or (not config.warn_for_evaluation_error_at_bounds and (lb1 >= 0 or ub1 <= 0))
):
# life is good; the exponent is an integer and the base is nonzero
return None
elif integer_exponent and lb2 >= 0:
# life is good; the exponent is a nonnegative integer
return None

# if the base is positive, there should not be any evaluation errors
if lb1 > 0 or (not config.warn_for_evaluation_error_at_bounds and lb1 >= 0):
return None
if lb1 >= 0 and lb2 >= 0:
return None

msg = f"Potential evaluation error in {node}; "
msg += f"base bounds are ({lb1}, {ub1}); "
msg += f"exponent bounds are ({lb2}, {ub2})"
warn_list.append(msg)


def _check_eval_error_log(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if (config.warn_for_evaluation_error_at_bounds and lb <= 0) or lb < 0:
msg = f"Potential log of a non-positive number in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_tan(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node)
if not (math.isfinite(lb) and math.isfinite(ub)):
msg = f"{node} may evaluate to -inf or inf; Argument bounds are {_get_bounds_with_inf(node.args[0])}"
warn_list.append(msg)


def _check_eval_error_asin(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of asin outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_acos(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of acos outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_sqrt(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < 0:
msg = f"Potential square root of a negative number in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


_unary_eval_err_handler = dict()
_unary_eval_err_handler["log"] = _check_eval_error_log
_unary_eval_err_handler["log10"] = _check_eval_error_log
_unary_eval_err_handler["tan"] = _check_eval_error_tan
_unary_eval_err_handler["asin"] = _check_eval_error_asin
_unary_eval_err_handler["acos"] = _check_eval_error_acos
_unary_eval_err_handler["sqrt"] = _check_eval_error_sqrt


def _check_eval_error_unary(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
if node.getname() in _unary_eval_err_handler:
_unary_eval_err_handler[node.getname()](node, warn_list, config)


_eval_err_handler = dict()
_eval_err_handler[DivisionExpression] = _check_eval_error_division
_eval_err_handler[NPV_DivisionExpression] = _check_eval_error_division
_eval_err_handler[PowExpression] = _check_eval_error_pow
_eval_err_handler[NPV_PowExpression] = _check_eval_error_pow
_eval_err_handler[UnaryFunctionExpression] = _check_eval_error_unary
_eval_err_handler[NPV_UnaryFunctionExpression] = _check_eval_error_unary


class _EvalErrorWalker(StreamBasedExpressionVisitor):
def __init__(self, config: ConfigDict):
super().__init__()
self._warn_list = list()
self._config = config

def exitNode(self, node, data):
"""
callback to be called as the visitor moves from the leaf
nodes back to the root node.

Args:
node: a pyomo expression node
data: not used in this walker
"""
if type(node) in _eval_err_handler:
_eval_err_handler[type(node)](node, self._warn_list, self._config)
return self._warn_list


# TODO: Rename and redirect once old DegeneracyHunter is removed
@document_kwargs_from_configdict(DHCONFIG)
class DegeneracyHunter2:
Expand Down
Loading
Loading