diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index de1d5a9a16..b9bf2283c1 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -445,6 +445,10 @@ def logprob_MixtureRV( logp_val = at.set_subtensor(logp_val[idx_m_on_axis], logp_m) else: + # FIXME: This logprob implementation does not support mixing across distinct components, + # but we sometimes use it, because MixtureRV does not keep information about at which + # dimension scalar indexing actually starts + # If the stacking operation expands the component RVs, we have # to expand the value and later squeeze the logprob for everything # to work correctly diff --git a/pymc/tests/logprob/test_mixture.py b/pymc/tests/logprob/test_mixture.py index 8deff51982..683e6b2fa1 100644 --- a/pymc/tests/logprob/test_mixture.py +++ b/pymc/tests/logprob/test_mixture.py @@ -306,6 +306,31 @@ def test_hetero_mixture_binomial(p_val, size): (slice(None),), 1, ), + # Vector mixture components, scalar index that mixes across components + pytest.param( + ( + np.array(0, dtype=pytensor.config.floatX), + np.array(1, dtype=pytensor.config.floatX), + ), + ( + np.array(0.5, dtype=pytensor.config.floatX), + np.array(0.5, dtype=pytensor.config.floatX), + ), + ( + np.array(100, dtype=pytensor.config.floatX), + np.array(1, dtype=pytensor.config.floatX), + ), + np.array([0.1, 0.5, 0.1, 0.3], dtype=pytensor.config.floatX), + (4,), + (), + (), + 1, + marks=pytest.mark.xfail( + AssertionError, + match="Arrays are not almost equal to 6 decimals", # This is ignored, but that's where it should fail! + reason="IfElse Mixture logprob fails when indexing mixes across components", + ), + ), # Matrix components, scalar index along first axis ( (