diff --git a/pymc/distributions/dist_math.py b/pymc/distributions/dist_math.py index fbdea97440..93d77615df 100644 --- a/pymc/distributions/dist_math.py +++ b/pymc/distributions/dist_math.py @@ -50,13 +50,21 @@ } -def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = ""): - """ - Wrap a log probability graph in a CheckParameterValue that asserts several - conditions are True. When conditions are not met a ParameterValueError assertion is - raised, with an optional custom message defined by `msg` +def check_parameters( + expr: Variable, + *conditions: Iterable[Variable], + msg: str = "", + can_be_replaced_by_ninf: bool = True, +): + """Wrap an expression in a CheckParameterValue that asserts several conditions are met. + + When conditions are not met a ParameterValueError assertion is raised, + with an optional custom message defined by `msg`. - Note that check_parameter should not be used to enforce the logic of the logp + When the flag `can_be_replaced_by_ninf` is True (default), PyMC is allowed to replace the + assertion by a switch(condition, expr, -inf). This is used for logp graphs! + + Note that check_parameter should not be used to enforce the logic of the expression under the normal parameter support as it can be disabled by the user via check_bounds = False in pm.Model() """ @@ -65,7 +73,8 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions ] all_true_scalar = at.all([at.all(cond) for cond in conditions_]) - return CheckParameterValue(msg)(logp, all_true_scalar) + + return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar) def logpow(x, m): diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index b88d56d3ee..b93962fd56 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -210,8 +210,11 @@ class CheckParameterValue(CheckAndRaise): Raises `ParameterValueError` if the check is not True. """ - def __init__(self, msg=""): + __props__ = ("msg", "exc_type", "can_be_replaced_by_ninf") + + def __init__(self, msg: str = "", can_be_replaced_by_ninf: bool = False): super().__init__(ParameterValueError, msg) + self.can_be_replaced_by_ninf = can_be_replaced_by_ninf def __str__(self): return f"Check{{{self.msg}}}" diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index bca1c7bdca..033fc8aefa 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -913,19 +913,21 @@ def local_remove_check_parameter(fgraph, node): @node_rewriter(tracks=[CheckParameterValue]) def local_check_parameter_to_ninf_switch(fgraph, node): - if isinstance(node.op, CheckParameterValue): - logp_expr, *logp_conds = node.inputs - if len(logp_conds) > 1: - logp_cond = at.all(logp_conds) - else: - (logp_cond,) = logp_conds - out = at.switch(logp_cond, logp_expr, -np.inf) - out.name = node.op.msg + if not node.op.can_be_replaced_by_ninf: + return None + + logp_expr, *logp_conds = node.inputs + if len(logp_conds) > 1: + logp_cond = at.all(logp_conds) + else: + (logp_cond,) = logp_conds + out = at.switch(logp_cond, logp_expr, -np.inf) + out.name = node.op.msg - if out.dtype != node.outputs[0].dtype: - out = at.cast(out, node.outputs[0].dtype) + if out.dtype != node.outputs[0].dtype: + out = at.cast(out, node.outputs[0].dtype) - return [out] + return [out] pytensor.compile.optdb["canonicalize"].register( diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index ddaf86ed47..0fb90c6a51 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -326,6 +326,21 @@ def test_check_bounds_flag(self): with m: assert np.all(compile_pymc([], bound)() == -np.inf) + def test_check_parameters_can_be_replaced_by_ninf(self): + expr = at.vector("expr", shape=(3,)) + cond = at.ge(expr, 0) + + final_expr = check_parameters(expr, cond, can_be_replaced_by_ninf=True) + fn = compile_pymc([expr], final_expr) + np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3]) + np.testing.assert_array_equal(fn(expr=[-1, 2, 3]), [-np.inf, -np.inf, -np.inf]) + + final_expr = check_parameters(expr, cond, msg="test", can_be_replaced_by_ninf=False) + fn = compile_pymc([expr], final_expr) + np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3]) + with pytest.raises(ParameterValueError, match="test"): + fn([-1, 2, 3]) + def test_compile_pymc_sets_rng_updates(self): rng = pytensor.shared(np.random.default_rng(0)) x = pm.Normal.dist(rng=rng)