From 90b6bec083d162080eb5eea1cab291790b4350e2 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 15 Dec 2022 18:10:14 +0100
Subject: [PATCH] Do not treat broadcasted variables as independent in logprob
 inference

This includes two cases:

## Direct valuation of broadcasted RVs

The `naive_bcast_rv_lift` rewrite was included by default and allowed
broadcasted RVs to be valued. This is invalid because it implies that
`logp(broadcast_to(normal(0, 1), (3, 2)), value) == logp(normal(0, 1, size=(3, 2)), value)`,
which is not true. Broadcasting replicates the same RV draws, so the
broadcasted values cannot be treated as independent when evaluating the
logp.

The rewrite is kept, but it is no longer registered anywhere.
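To make the asymmetry concrete, here is a minimal sketch using only
`numpy`/`scipy` (illustrative only, not the actual logprob machinery):

```python
import numpy as np
import scipy.stats as st

value = np.zeros((3, 2))

# Six *independent* standard normals contribute one density term per entry
independent_logp = st.norm(0, 1).logpdf(value).sum()

# A broadcasted normal has a *single* underlying draw that is merely
# replicated to shape (3, 2), so only one density term applies
broadcasted_logp = st.norm(0, 1).logpdf(value[0, 0])

assert not np.isclose(independent_logp, broadcasted_logp)
```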
## Valuation of Mixtures with potentially repeated components

This can happen when advanced indexing is used. As a precaution, the
`mixture_replace` rewrite now fails when advanced integer indexing is
detected, even though some cases may be valid at runtime (e.g., no
repeated indexes).
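The failure mode that the check guards against can be sketched with
plain `numpy` (names and values are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(123)

# One draw per mixture component, stacked along the first axis
components = rng.normal(size=3)

# Advanced integer indexing may select the same component twice, so
# mix[0] and mix[1] are the *same* draw rather than two independent
# ones; a logprob that treats every entry of `mix` as independent
# would silently double count that component
idx = np.array([0, 0, 2])
mix = components[idx]
assert mix[0] == mix[1]

# A scalar (basic) index can never repeat values, which is why such
# mixtures remain supported
single = components[1]
```

Since repetition can only be ruled out at runtime, the rewrite bails
out (returns `None`) whenever it sees a non-scalar integer index.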
---
 pymc/logprob/mixture.py            | 15 ++++++-
 pymc/logprob/tensor.py             |  4 --
 pymc/tests/logprob/test_mixture.py | 63 ++++++++++++++++++++++--------
 pymc/tests/logprob/test_tensor.py  | 20 ++++++++++
 4 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py
index b9bf2283c1..30f271a428 100644
--- a/pymc/logprob/mixture.py
+++ b/pymc/logprob/mixture.py
@@ -58,13 +58,15 @@
 )
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.subtensor import (
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
     as_index_literal,
     as_nontensor_scalar,
     get_canonical_form_slice,
     is_basic_idx,
 )
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
+from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
 from pytensor.tensor.var import TensorVariable
 
 from pymc.logprob.abstract import (
@@ -309,6 +311,17 @@ def mixture_replace(fgraph, node):
 
     mixing_indices = node.inputs[1:]
 
+    # TODO: Add check / test case for Advanced Boolean indexing
+    if isinstance(node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
+        # We don't support (non-scalar) integer array indexing, as it can pick repeated values,
+        # but the Mixture logprob assumes all mixture values are independent
+        if any(
+            indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
+            for indices in mixing_indices
+            if not isinstance(indices, SliceConstant)
+        ):
+            return None
+
     # We loop through mixture components and collect all the array elements
     # that belong to each one (by way of their indices).
     new_mixture_rvs = []
diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py
index 742541de99..f52bbe9492 100644
--- a/pymc/logprob/tensor.py
+++ b/pymc/logprob/tensor.py
@@ -357,10 +357,6 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf
     "find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor"
 )
 
-
-measurable_ir_rewrites_db.register("broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor")
-
-
 measurable_ir_rewrites_db.register(
     "find_measurable_stacks",
     find_measurable_stacks,
diff --git a/pymc/tests/logprob/test_mixture.py b/pymc/tests/logprob/test_mixture.py
index 683e6b2fa1..57d44d89c9 100644
--- a/pymc/tests/logprob/test_mixture.py
+++ b/pymc/tests/logprob/test_mixture.py
@@ -91,7 +91,7 @@ def create_mix_model(size, axis):
     with pytest.raises(RuntimeError, match="could not be derived: {m}"):
         factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv, X_rv: x_vv})
 
-    with pytest.raises(NotImplementedError):
+    with pytest.raises(RuntimeError, match="could not be derived: {m}"):
         axis_at = at.lscalar("axis")
         axis_at.tag.test_value = 0
         env = create_mix_model((2,), axis_at)
@@ -139,17 +139,19 @@ def test_compute_test_value(op_constructor):
 
 @pytest.mark.parametrize(
-    "p_val, size",
+    "p_val, size, supported",
     [
-        (np.array(0.0, dtype=pytensor.config.floatX), ()),
-        (np.array(1.0, dtype=pytensor.config.floatX), ()),
-        (np.array(0.0, dtype=pytensor.config.floatX), (2,)),
-        (np.array(1.0, dtype=pytensor.config.floatX), (2, 1)),
-        (np.array(1.0, dtype=pytensor.config.floatX), (2, 3)),
-        (np.array([0.1, 0.9], dtype=pytensor.config.floatX), (2, 3)),
+        (np.array(0.0, dtype=pytensor.config.floatX), (), True),
+        (np.array(1.0, dtype=pytensor.config.floatX), (), True),
+        (np.array([0.1, 0.9], dtype=pytensor.config.floatX), (), True),
+        # The cases below are not supported because they may pick repeated values via AdvancedIndexing
+        (np.array(0.0, dtype=pytensor.config.floatX), (2,), False),
+        (np.array(1.0, dtype=pytensor.config.floatX), (2, 1), False),
+        (np.array(1.0, dtype=pytensor.config.floatX), (2, 3), False),
+        (np.array([0.1, 0.9], dtype=pytensor.config.floatX), (2, 3), False),
     ],
 )
-def test_hetero_mixture_binomial(p_val, size):
+def test_hetero_mixture_binomial(p_val, size, supported):
     srng = at.random.RandomStream(29833)
 
     X_rv = srng.normal(0, 1, size=size, name="X")
@@ -175,7 +177,12 @@ def test_hetero_mixture_binomial(p_val, size):
     m_vv = M_rv.clone()
     m_vv.name = "m"
 
-    M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+    if supported:
+        M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+    else:
+        with pytest.raises(RuntimeError, match="could not be derived: {m}"):
+            joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+        return
 
     M_logp_fn = pytensor.function([p_at, m_vv, i_vv], M_logp)
 
@@ -204,9 +211,9 @@ def test_hetero_mixture_binomial(p_val, size):
 
 @pytest.mark.parametrize(
-    "X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis",
+    "X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported",
     [
-        # Scalar mixture components, scalar index
+        # Scalar components, scalar index
         (
             (
                 np.array(0, dtype=pytensor.config.floatX),
                 np.array(1, dtype=pytensor.config.floatX),
@@ -225,6 +232,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             (),
             0,
+            True,
         ),
         # Degenerate vector mixture components, scalar index along join axis
         (
@@ -245,6 +253,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             (),
             0,
+            True,
         ),
         # Degenerate vector mixture components, scalar index along join axis (axis=1)
         (
@@ -265,6 +274,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             (slice(None),),
             1,
+            True,
         ),
         # Vector mixture components, scalar index along the join axis
         (
@@ -285,6 +295,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             (),
             0,
+            True,
         ),
         # Vector mixture components, scalar index along the join axis (axis=1)
         (
@@ -305,6 +316,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             (slice(None),),
             1,
+            True,
         ),
         # Vector mixture components, scalar index that mixes across components
         pytest.param(
@@ -325,6 +337,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             (),
             1,
+            True,
             marks=pytest.mark.xfail(
                 AssertionError,
                 match="Arrays are not almost equal to 6 decimals",  # This is ignored, but that's where it should fail!
@@ -350,7 +363,10 @@ def test_hetero_mixture_binomial(p_val, size):
             (),
             (),
             0,
+            True,
         ),
+        # All the tests below rely on AdvancedIndexing, which is not supported at the moment
+        # See https://github.com/pymc-devs/pymc/issues/6398
         # Scalar mixture components, vector index along first axis
         (
             (
@@ -370,6 +386,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (6,),
             (),
             0,
+            False,
         ),
         # Vector mixture components, vector index along first axis
         (
@@ -390,9 +407,10 @@ def test_hetero_mixture_binomial(p_val, size):
             (2,),
             (slice(None),),
             0,
+            False,
         ),
         # Vector mixture components, vector index along last axis
-        pytest.param(
+        (
             (
                 np.array(0, dtype=pytensor.config.floatX),
                 np.array(1, dtype=pytensor.config.floatX),
@@ -410,7 +428,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (4,),
             (slice(None),),
             1,
-            marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
+            False,
         ),
         # Vector mixture components (with degenerate vector parameters), vector index along first axis
         (
@@ -431,6 +449,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (2,),
             (),
             0,
+            False,
         ),
         # Vector mixture components (with vector parameters), vector index along first axis
         (
@@ -451,6 +470,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (2,),
             (),
             0,
+            False,
         ),
         # Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
         (
@@ -471,6 +491,7 @@ def test_hetero_mixture_binomial(p_val, size):
             None,
             (),
             0,
+            False,
         ),
         # Matrix mixture components, matrix index
         (
@@ -491,6 +512,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (2, 3),
             (),
             0,
+            False,
         ),
         # Vector components, matrix indexing (constant along first dimension, then random)
         (
@@ -511,6 +533,7 @@ def test_hetero_mixture_binomial(p_val, size):
             (5,),
             (np.arange(5),),
             0,
+            False,
         ),
         # Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
         (
@@ -531,11 +554,12 @@ def test_hetero_mixture_binomial(p_val, size):
             (5,),
             (np.arange(5), None),
             0,
+            False,
         ),
     ],
 )
 def test_hetero_mixture_categorical(
-    X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis
+    X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported
 ):
     srng = at.random.RandomStream(29833)
 
@@ -561,7 +585,12 @@ def test_hetero_mixture_categorical(
     m_vv = M_rv.clone()
     m_vv.name = "m"
 
-    logp_parts = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+    if supported:
+        logp_parts = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+    else:
+        with pytest.raises(RuntimeError, match="could not be derived: {m}"):
+            factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+        return
 
     I_logp_fn = pytensor.function([p_at, i_vv], logp_parts[i_vv])
     M_logp_fn = pytensor.function([m_vv, i_vv], logp_parts[m_vv])
@@ -854,7 +883,7 @@ def test_mixture_with_DiracDelta():
     Y_rv = dirac_delta(0.0)
     Y_rv.name = "Y"
 
-    I_rv = srng.categorical([0.5, 0.5], size=4)
+    I_rv = srng.categorical([0.5, 0.5], size=1)
 
     i_vv = I_rv.clone()
     i_vv.name = "i"
diff --git a/pymc/tests/logprob/test_tensor.py b/pymc/tests/logprob/test_tensor.py
index 198f06cabd..4b944d64df 100644
--- a/pymc/tests/logprob/test_tensor.py
+++ b/pymc/tests/logprob/test_tensor.py
@@ -81,6 +81,26 @@ def test_naive_bcast_rv_lift_valued_var():
     assert np.allclose(logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0]))
 
 
+@pytest.mark.xfail(raises=RuntimeError, reason="logprob for broadcasted RVs not implemented")
+def test_bcast_rv_logp():
+    """Test that the logp derived for a broadcasted RV is correct."""
+
+    x_rv = at.random.normal(name="x")
+    broadcasted_x_rv = at.broadcast_to(x_rv, (2,))
+    broadcasted_x_rv.name = "broadcasted_x"
+    broadcasted_x_vv = broadcasted_x_rv.clone()
+
+    logp = joint_logprob({broadcasted_x_rv: broadcasted_x_vv}, sum=False)
+    valid_logp = logp.eval({broadcasted_x_vv: [0, 0]})
+    assert valid_logp.shape == ()
+    assert np.isclose(valid_logp, st.norm.logpdf(0))
+
+    # It's not possible for broadcasted dimensions to have different values,
+    # so this should either raise or return -inf
+    invalid_logp = logp.eval({broadcasted_x_vv: [0, 1]})
+    assert invalid_logp == -np.inf
+
+
 def test_measurable_make_vector():
     base1_rv = at.random.normal(name="base1")
     base2_rv = at.random.halfnormal(name="base2")