Skip to content

Commit

Permalink
Do not treat broadcasted variables as independent in logprob inference
Browse files Browse the repository at this point in the history
This included two cases:

## Direct valuation of broadcasted RVs

The `naive_bcast_lift` rewrite was included by default and allowed broadcasted RVs to be valued. This is invalid because it implies that `logp(broadcast_to(normal(0, 1), (3, 2), value) == logp(normal(0, 1, size=(3, 2)), value)` which is not true. Broadcast replicates the same RV draws, so these values can't be considered independent when evaluating the logp.

The rewrite is kept but not used anywhere

## Valuation of Mixtures with potential repeated components

This can happen when AdvancedIndexing is used. As a precaution, Mixture replace now fails when Advanced integer indexing is detected, even though some cases may be valid at runtime (e.g., no repated indexes)
  • Loading branch information
ricardoV94 authored and twiecki committed Dec 16, 2022
1 parent 762de98 commit 90b6bec
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 22 deletions.
15 changes: 14 additions & 1 deletion pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@
)
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
AdvancedSubtensor1,
as_index_literal,
as_nontensor_scalar,
get_canonical_form_slice,
is_basic_idx,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.var import TensorVariable

from pymc.logprob.abstract import (
Expand Down Expand Up @@ -309,6 +311,17 @@ def mixture_replace(fgraph, node):

mixing_indices = node.inputs[1:]

# TODO: Add check / test case for Advanced Boolean indexing
if isinstance(node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
# We don't support (non-scalar) integer array indexing as it can pick repeated values,
# but the Mixture logprob assumes all mixture values are independent
if any(
indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
for indices in mixing_indices
if not isinstance(indices, SliceConstant)
):
return None

# We loop through mixture components and collect all the array elements
# that belong to each one (by way of their indices).
new_mixture_rvs = []
Expand Down
4 changes: 0 additions & 4 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,6 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf
"find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor"
)


measurable_ir_rewrites_db.register("broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor")


measurable_ir_rewrites_db.register(
"find_measurable_stacks",
find_measurable_stacks,
Expand Down
63 changes: 46 additions & 17 deletions pymc/tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def create_mix_model(size, axis):
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv, X_rv: x_vv})

with pytest.raises(NotImplementedError):
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
axis_at = at.lscalar("axis")
axis_at.tag.test_value = 0
env = create_mix_model((2,), axis_at)
Expand Down Expand Up @@ -139,17 +139,19 @@ def test_compute_test_value(op_constructor):


@pytest.mark.parametrize(
"p_val, size",
"p_val, size, supported",
[
(np.array(0.0, dtype=pytensor.config.floatX), ()),
(np.array(1.0, dtype=pytensor.config.floatX), ()),
(np.array(0.0, dtype=pytensor.config.floatX), (2,)),
(np.array(1.0, dtype=pytensor.config.floatX), (2, 1)),
(np.array(1.0, dtype=pytensor.config.floatX), (2, 3)),
(np.array([0.1, 0.9], dtype=pytensor.config.floatX), (2, 3)),
(np.array(0.0, dtype=pytensor.config.floatX), (), True),
(np.array(1.0, dtype=pytensor.config.floatX), (), True),
(np.array([0.1, 0.9], dtype=pytensor.config.floatX), (), True),
# The cases belowe are not supported because they may pick repeated values via AdvancedIndexing
(np.array(0.0, dtype=pytensor.config.floatX), (2,), False),
(np.array(1.0, dtype=pytensor.config.floatX), (2, 1), False),
(np.array(1.0, dtype=pytensor.config.floatX), (2, 3), False),
(np.array([0.1, 0.9], dtype=pytensor.config.floatX), (2, 3), False),
],
)
def test_hetero_mixture_binomial(p_val, size):
def test_hetero_mixture_binomial(p_val, size, supported):
srng = at.random.RandomStream(29833)

X_rv = srng.normal(0, 1, size=size, name="X")
Expand All @@ -175,7 +177,12 @@ def test_hetero_mixture_binomial(p_val, size):
m_vv = M_rv.clone()
m_vv.name = "m"

M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
if supported:
M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
else:
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
return

M_logp_fn = pytensor.function([p_at, m_vv, i_vv], M_logp)

Expand Down Expand Up @@ -204,9 +211,9 @@ def test_hetero_mixture_binomial(p_val, size):


@pytest.mark.parametrize(
"X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis",
"X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported",
[
# Scalar mixture components, scalar index
# Scalar components, scalar index
(
(
np.array(0, dtype=pytensor.config.floatX),
Expand All @@ -225,6 +232,7 @@ def test_hetero_mixture_binomial(p_val, size):
(),
(),
0,
True,
),
# Degenerate vector mixture components, scalar index along join axis
(
Expand All @@ -245,6 +253,7 @@ def test_hetero_mixture_binomial(p_val, size):
(),
(),
0,
True,
),
# Degenerate vector mixture components, scalar index along join axis (axis=1)
(
Expand All @@ -265,6 +274,7 @@ def test_hetero_mixture_binomial(p_val, size):
(),
(slice(None),),
1,
True,
),
# Vector mixture components, scalar index along the join axis
(
Expand All @@ -285,6 +295,7 @@ def test_hetero_mixture_binomial(p_val, size):
(),
(),
0,
True,
),
# Vector mixture components, scalar index along the join axis (axis=1)
(
Expand All @@ -305,6 +316,7 @@ def test_hetero_mixture_binomial(p_val, size):
(),
(slice(None),),
1,
True,
),
# Vector mixture components, scalar index that mixes across components
pytest.param(
Expand All @@ -325,6 +337,7 @@ def test_hetero_mixture_binomial(p_val, size):
(),
(),
1,
True,
marks=pytest.mark.xfail(
AssertionError,
match="Arrays are not almost equal to 6 decimals", # This is ignored, but that's where it should fail!
Expand All @@ -350,7 +363,10 @@ def test_hetero_mixture_binomial(p_val, size):
(),
(),
0,
True,
),
# All the tests below rely on AdvancedIndexing, which is not supported at the moment
# See https://github.com/pymc-devs/pymc/issues/6398
# Scalar mixture components, vector index along first axis
(
(
Expand All @@ -370,6 +386,7 @@ def test_hetero_mixture_binomial(p_val, size):
(6,),
(),
0,
False,
),
# Vector mixture components, vector index along first axis
(
Expand All @@ -390,9 +407,10 @@ def test_hetero_mixture_binomial(p_val, size):
(2,),
(slice(None),),
0,
False,
),
# Vector mixture components, vector index along last axis
pytest.param(
(
(
np.array(0, dtype=pytensor.config.floatX),
np.array(1, dtype=pytensor.config.floatX),
Expand All @@ -410,7 +428,7 @@ def test_hetero_mixture_binomial(p_val, size):
(4,),
(slice(None),),
1,
marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
False,
),
# Vector mixture components (with degenerate vector parameters), vector index along first axis
(
Expand All @@ -431,6 +449,7 @@ def test_hetero_mixture_binomial(p_val, size):
(2,),
(),
0,
False,
),
# Vector mixture components (with vector parameters), vector index along first axis
(
Expand All @@ -451,6 +470,7 @@ def test_hetero_mixture_binomial(p_val, size):
(2,),
(),
0,
False,
),
# Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
(
Expand All @@ -471,6 +491,7 @@ def test_hetero_mixture_binomial(p_val, size):
None,
(),
0,
False,
),
# Matrix mixture components, matrix index
(
Expand All @@ -491,6 +512,7 @@ def test_hetero_mixture_binomial(p_val, size):
(2, 3),
(),
0,
False,
),
# Vector components, matrix indexing (constant along first dimension, then random)
(
Expand All @@ -511,6 +533,7 @@ def test_hetero_mixture_binomial(p_val, size):
(5,),
(np.arange(5),),
0,
False,
),
# Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
(
Expand All @@ -531,11 +554,12 @@ def test_hetero_mixture_binomial(p_val, size):
(5,),
(np.arange(5), None),
0,
False,
),
],
)
def test_hetero_mixture_categorical(
X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis
X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported
):
srng = at.random.RandomStream(29833)

Expand All @@ -561,7 +585,12 @@ def test_hetero_mixture_categorical(
m_vv = M_rv.clone()
m_vv.name = "m"

logp_parts = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
if supported:
logp_parts = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
else:
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
return

I_logp_fn = pytensor.function([p_at, i_vv], logp_parts[i_vv])
M_logp_fn = pytensor.function([m_vv, i_vv], logp_parts[m_vv])
Expand Down Expand Up @@ -854,7 +883,7 @@ def test_mixture_with_DiracDelta():
Y_rv = dirac_delta(0.0)
Y_rv.name = "Y"

I_rv = srng.categorical([0.5, 0.5], size=4)
I_rv = srng.categorical([0.5, 0.5], size=1)

i_vv = I_rv.clone()
i_vv.name = "i"
Expand Down
20 changes: 20 additions & 0 deletions pymc/tests/logprob/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@ def test_naive_bcast_rv_lift_valued_var():
assert np.allclose(logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0]))


@pytest.mark.xfail(RuntimeError, reason="logprob for broadcasted RVs not implemented")
def test_bcast_rv_logp():
"""Test that derived logp for broadcasted RV is correct"""

x_rv = at.random.normal(name="x")
broadcasted_x_rv = at.broadcast_to(x_rv, (2,))
broadcasted_x_rv.name = "broadcasted_x"
broadcasted_x_vv = broadcasted_x_rv.clone()

logp = joint_logprob({broadcasted_x_rv: broadcasted_x_vv}, sum=False)
valid_logp = logp.eval({broadcasted_x_vv: [0, 0]})
assert valid_logp.shape == ()
assert np.isclose(valid_logp, st.norm.logpdf(0))

# It's not possible for broadcasted dimensions to have different values
# This shoud either raise or return -inf
invalid_logp = logp.eval({broadcasted_x_vv: [0, 1]})
assert invalid_logp == -np.inf


def test_measurable_make_vector():
base1_rv = at.random.normal(name="base1")
base2_rv = at.random.halfnormal(name="base2")
Expand Down

0 comments on commit 90b6bec

Please sign in to comment.