Skip to content

Commit

Permalink
Add flag to CheckParameterValue to inform whether it can be replace…
Browse files Browse the repository at this point in the history
…d by -inf
  • Loading branch information
ricardoV94 committed Mar 13, 2023
1 parent 5becc49 commit f043ad9
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 19 deletions.
23 changes: 16 additions & 7 deletions pymc/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,21 @@
}


def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = ""):
"""
Wrap a log probability graph in a CheckParameterValue that asserts several
conditions are True. When conditions are not met a ParameterValueError assertion is
raised, with an optional custom message defined by `msg`
def check_parameters(
expr: Variable,
*conditions: Iterable[Variable],
msg: str = "",
can_be_replaced_by_ninf: bool = True,
):
"""Wrap an expression in a CheckParameterValue that asserts several conditions are met.
When conditions are not met a ParameterValueError assertion is raised,
with an optional custom message defined by `msg`.
Note that check_parameter should not be used to enforce the logic of the logp
When the flag `can_be_replaced_by_ninf` is True (default), PyMC is allowed to replace the
assertion by a switch(condition, expr, -inf). This is used for logp graphs!
Note that check_parameter should not be used to enforce the logic of the
expression under the normal parameter support as it can be disabled by the user via
check_bounds = False in pm.Model()
"""
Expand All @@ -65,7 +73,8 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str =
cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions
]
all_true_scalar = at.all([at.all(cond) for cond in conditions_])
return CheckParameterValue(msg)(logp, all_true_scalar)

return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)


def logpow(x, m):
Expand Down
5 changes: 4 additions & 1 deletion pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,11 @@ class CheckParameterValue(CheckAndRaise):
Raises `ParameterValueError` if the check is not True.
"""

def __init__(self, msg=""):
__props__ = ("msg", "exc_type", "can_be_replaced_by_ninf")

def __init__(self, msg: str = "", can_be_replaced_by_ninf: bool = False):
super().__init__(ParameterValueError, msg)
self.can_be_replaced_by_ninf = can_be_replaced_by_ninf

def __str__(self):
return f"Check{{{self.msg}}}"
Expand Down
24 changes: 13 additions & 11 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,19 +913,21 @@ def local_remove_check_parameter(fgraph, node):

@node_rewriter(tracks=[CheckParameterValue])
def local_check_parameter_to_ninf_switch(fgraph, node):
if isinstance(node.op, CheckParameterValue):
logp_expr, *logp_conds = node.inputs
if len(logp_conds) > 1:
logp_cond = at.all(logp_conds)
else:
(logp_cond,) = logp_conds
out = at.switch(logp_cond, logp_expr, -np.inf)
out.name = node.op.msg
if not node.op.can_be_replaced_by_ninf:
return None

logp_expr, *logp_conds = node.inputs
if len(logp_conds) > 1:
logp_cond = at.all(logp_conds)
else:
(logp_cond,) = logp_conds
out = at.switch(logp_cond, logp_expr, -np.inf)
out.name = node.op.msg

if out.dtype != node.outputs[0].dtype:
out = at.cast(out, node.outputs[0].dtype)
if out.dtype != node.outputs[0].dtype:
out = at.cast(out, node.outputs[0].dtype)

return [out]
return [out]


pytensor.compile.optdb["canonicalize"].register(
Expand Down
15 changes: 15 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,21 @@ def test_check_bounds_flag(self):
with m:
assert np.all(compile_pymc([], bound)() == -np.inf)

def test_check_parameters_can_be_replaced_by_ninf(self):
expr = at.vector("expr", shape=(3,))
cond = at.ge(expr, 0)

final_expr = check_parameters(expr, cond, can_be_replaced_by_ninf=True)
fn = compile_pymc([expr], final_expr)
np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3])
np.testing.assert_array_equal(fn(expr=[-1, 2, 3]), [-np.inf, -np.inf, -np.inf])

final_expr = check_parameters(expr, cond, msg="test", can_be_replaced_by_ninf=False)
fn = compile_pymc([expr], final_expr)
np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3])
with pytest.raises(ParameterValueError, match="test"):
fn([-1, 2, 3])

def test_compile_pymc_sets_rng_updates(self):
rng = pytensor.shared(np.random.default_rng(0))
x = pm.Normal.dist(rng=rng)
Expand Down

0 comments on commit f043ad9

Please sign in to comment.