Skip to content

Commit

Permalink
Fix bug in switch mixture logp
Browse files Browse the repository at this point in the history
The True and False branches were being mixed up
  • Loading branch information
ricardoV94 committed Jun 9, 2023
1 parent 8b5f437 commit 864ecb3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def find_measurable_switch_mixture(fgraph, node):
old_mixture_rv.broadcastable,
)
new_mixture_rv = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components)
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1])
).default_output()

if pytensor.config.compute_test_value != "off":
Expand Down
14 changes: 11 additions & 3 deletions tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,12 @@ def test_switch_mixture():
i_vv = I_rv.clone()
i_vv.name = "i"

# When I_rv == True, X_rv flows through otherwise Y_rv does
Z1_rv = pt.switch(I_rv, X_rv, Y_rv)

assert Z1_rv.eval({I_rv: 0}) > 5
assert Z1_rv.eval({I_rv: 1}) < -5

z_vv = Z1_rv.clone()
z_vv.name = "z1"

Expand All @@ -935,7 +940,10 @@ def test_switch_mixture():

# building the identical graph but with a stack to check that mixture computations are identical

Z2_rv = pt.stack((X_rv, Y_rv))[I_rv]
Z2_rv = pt.stack((Y_rv, X_rv))[I_rv]

assert Z2_rv.eval({I_rv: 0}) > 5
assert Z2_rv.eval({I_rv: 1}) < -5

fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})

Expand All @@ -949,8 +957,8 @@ def test_switch_mixture():
# below should follow immediately from the equal_computations assertion above
assert equal_computations([z1_logp_combined], [z2_logp_combined])

np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 0}))
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 0}))
np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1}))
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1}))


def test_ifelse_mixture_one_component():
Expand Down

0 comments on commit 864ecb3

Please sign in to comment.